From 544dd14b4301beb47136f273deff3f532cdde181 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 7 Feb 2025 14:17:18 -0800 Subject: [PATCH] Update main branch with TE 2.0 code, update version to 2.1.0.dev0 Signed-off-by: Przemek Tredak --- .github/workflows/build.yml | 20 - .github/workflows/lint.yml | 27 - .gitignore | 1 - 3rdparty/cudnn-frontend | 2 +- README.rst | 4 +- build_tools/VERSION.txt | 2 +- build_tools/build_ext.py | 67 +- build_tools/paddle.py | 92 - build_tools/pytorch.py | 1 - build_tools/utils.py | 16 +- build_tools/wheel_utils/build_wheels.sh | 36 - docs/api/common.rst | 2 +- docs/api/framework.rst | 1 - docs/api/paddle.rst | 34 - docs/api/pytorch.rst | 2 - docs/examples/attention/attention.ipynb | 54 +- docs/installation.rst | 2 +- examples/README.md | 5 +- examples/paddle/mnist/README.md | 7 - .../paddle/mnist/test_single_gpu_mnist.py | 291 -- pylintrc | 1 - qa/L0_jax_unittest/test.sh | 2 +- qa/L0_paddle_lint/test.sh | 24 - qa/L0_paddle_unittest/test.sh | 10 - qa/L0_paddle_wheel/test.sh | 37 - qa/L0_pytorch_unittest/test.sh | 3 +- qa/L1_pytorch_distributed_unittest/test.sh | 4 +- qa/L1_pytorch_onnx_test/test.sh | 16 - qa/L3_pytorch_FA_versions_test/test.sh | 13 +- setup.py | 25 +- tests/cpp/CMakeLists.txt | 6 +- tests/cpp/operator/CMakeLists.txt | 18 +- tests/cpp/operator/test_act.cu | 88 +- tests/cpp/operator/test_cast.cu | 130 + tests/cpp/operator/test_cast_dbias.cu | 181 ++ tests/cpp/operator/test_cast_dbias_dgelu.cu | 196 ++ tests/cpp/operator/test_cast_gated_swiglu.cu | 165 ++ tests/cpp/operator/test_cast_mxfp8.cu | 636 ++++ .../operator/test_cast_mxfp8_gated_swiglu.cu | 470 +++ tests/cpp/operator/test_cast_transpose.cu | 26 +- .../cpp/operator/test_cast_transpose_dbias.cu | 53 +- .../test_cast_transpose_dbias_dgelu.cu | 39 +- .../operator/test_cast_transpose_dgeglu.cu | 26 +- tests/cpp/operator/test_causal_softmax.cu | 18 +- tests/cpp/operator/test_dequantize_mxfp8.cu | 452 +++ .../cpp/operator/test_multi_cast_transpose.cu | 37 +- tests/cpp/operator/test_multi_padding.cu | 10 +- tests/cpp/operator/test_normalization.cu | 107 +- .../cpp/operator/test_normalization_mxfp8.cu | 337 +++ tests/cpp/operator/test_qdq.cu | 22 +- tests/cpp/operator/test_swizzle.cu | 165 ++ tests/cpp/operator/test_transpose.cu | 8 +- tests/cpp/test_common.cu | 670 ++++- tests/cpp/test_common.h | 345 ++- tests/cpp/util/CMakeLists.txt | 7 +- tests/jax/conftest.py | 3 - tests/jax/test_layer.py | 39 +- tests/jax/utils.py | 23 +- tests/paddle/dist_launcher.py | 145 - tests/paddle/parallel_tests/amax_reduction.py | 87 - tests/paddle/parallel_tests/attention_tp.py | 234 -- tests/paddle/parallel_tests/group_sharding.py | 188 -- .../parallel_tests/layernorm_linear_tp.py | 182 -- .../paddle/parallel_tests/layernorm_mlp_tp.py | 197 -- tests/paddle/parallel_tests/linear_pp.py | 235 -- tests/paddle/parallel_tests/linear_tp.py | 222 -- tests/paddle/parallel_tests/transformer_tp.py | 250 -- .../recompute_transformer_encoder.py | 71 - tests/paddle/test_install.py | 11 - tests/paddle/test_layers.py | 1663 ----------- tests/paddle/test_master_grad.py | 92 - tests/paddle/test_operators.py | 1201 -------- tests/paddle/test_parallel.py | 99 - tests/paddle/test_recompute.py | 56 - tests/paddle/utils.py | 221 -- tests/pytorch/custom_ort_ops/.gitignore | 3 - tests/pytorch/custom_ort_ops/CMakeLists.txt | 29 - tests/pytorch/custom_ort_ops/README.md | 22 - tests/pytorch/custom_ort_ops/build.sh | 17 - .../custom_ort_ops/custom_op_library.cc | 102 - .../distributed/run_gemm_with_overlap.py | 284 +- 
.../distributed/run_layer_with_overlap.py | 94 +- tests/pytorch/distributed/run_numerics.py | 81 +- .../distributed/test_comm_gemm_overlap.py | 181 +- tests/pytorch/distributed/test_fusible_ops.py | 172 +- tests/pytorch/distributed/test_numerics.py | 22 +- tests/pytorch/distributed/test_torch_fsdp2.py | 45 +- .../fused_attn/run_fused_attn_with_cp.py | 11 +- tests/pytorch/fused_attn/test_fused_attn.py | 404 +-- tests/pytorch/test_cpu_offloading.py | 57 + tests/pytorch/test_cuda_graphs.py | 22 +- tests/pytorch/test_float8tensor.py | 165 +- tests/pytorch/test_fused_optimizer.py | 3 +- tests/pytorch/test_fusible_ops.py | 577 ++-- tests/pytorch/test_numerics.py | 241 +- tests/pytorch/test_onnx_export.py | 1562 ---------- tests/pytorch/test_permutation.py | 38 +- tests/pytorch/test_recipe.py | 74 +- tests/pytorch/test_sanity.py | 100 +- tests/pytorch/test_torch_save_load.py | 474 --- transformer_engine/__init__.py | 10 - transformer_engine/common/CMakeLists.txt | 9 +- .../common/activation/activation_template.h | 130 +- transformer_engine/common/activation/gelu.cu | 29 +- transformer_engine/common/activation/relu.cu | 28 +- .../common/activation/swiglu.cu | 14 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 471 +-- .../userbuffers/userbuffers.cu | 1 + transformer_engine/common/common.cu | 121 +- transformer_engine/common/common.h | 207 +- .../common/fused_attn/fused_attn.cpp | 55 +- .../fused_attn_f16_arbitrary_seqlen.cu | 6 +- .../common/fused_attn/fused_attn_fp8.cu | 215 +- .../common/gemm/cublaslt_gemm.cu | 215 +- .../include/transformer_engine/activation.h | 165 +- .../common/include/transformer_engine/cast.h | 199 +- .../transformer_engine/cast_transpose_noop.h | 19 +- .../transformer_engine/comm_gemm_overlap.h | 155 +- .../include/transformer_engine/recipe.h | 19 +- .../include/transformer_engine/swizzle.h | 37 + .../transformer_engine/transformer_engine.h | 246 +- .../include/transformer_engine/transpose.h | 291 +- .../common/normalization/common.cpp | 166 +- .../common/normalization/common.h | 44 +- .../common/normalization/layernorm/ln_api.cpp | 39 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 46 +- transformer_engine/common/recipe/__init__.py | 52 +- .../common/recipe/delayed_scaling.cu | 100 +- transformer_engine/common/swizzle/swizzle.cu | 338 +++ .../common/transformer_engine.cpp | 334 ++- .../common/transpose/cast_transpose.cu | 256 +- .../common/transpose/cast_transpose.h | 28 + .../common/transpose/cast_transpose_fusion.cu | 418 +-- .../common/transpose/multi_cast_transpose.cu | 68 +- .../transpose/rtc/cast_transpose_fusion.cu | 29 +- .../common/transpose/transpose.cu | 11 +- .../common/transpose/transpose_fusion.cu | 31 +- transformer_engine/common/util/cast.cu | 180 +- .../common/util/cast_gated_kernels.cuh | 1091 +++++++ .../common/util/cast_kernels.cuh | 1251 ++++++++ .../common/util/cuda_runtime.cpp | 20 + transformer_engine/common/util/cuda_runtime.h | 10 + .../common/util/dequantize_kernels.cuh | 360 +++ transformer_engine/common/util/ptx.cuh | 300 ++ .../common/util/pybind_helper.h | 152 +- transformer_engine/common/util/system.h | 2 - .../common/util/vectorized_pointwise.h | 112 +- transformer_engine/common/utils.cuh | 111 + .../jax/csrc/extensions/activation.cpp | 112 +- .../jax/csrc/extensions/quantization.cpp | 8 +- .../jax/csrc/extensions/transpose.cpp | 53 +- transformer_engine/jax/fp8.py | 5 - transformer_engine/paddle/MANIFEST.in | 3 - transformer_engine/paddle/__init__.py | 60 - transformer_engine/paddle/constants.py | 74 - 
transformer_engine/paddle/cpp_extensions.py | 1199 -------- transformer_engine/paddle/csrc/common.cpp | 84 - transformer_engine/paddle/csrc/common.h | 185 -- transformer_engine/paddle/csrc/custom_ops.cu | 1776 ----------- transformer_engine/paddle/csrc/extensions.cpp | 63 - transformer_engine/paddle/distributed.py | 213 -- transformer_engine/paddle/fp8.py | 370 --- transformer_engine/paddle/fp8_buffer.py | 350 --- transformer_engine/paddle/layer/__init__.py | 12 - transformer_engine/paddle/layer/attention.py | 1161 -------- transformer_engine/paddle/layer/base.py | 571 ---- transformer_engine/paddle/layer/layernorm.py | 197 -- .../paddle/layer/layernorm_linear.py | 721 ----- .../paddle/layer/layernorm_mlp.py | 1010 ------- transformer_engine/paddle/layer/linear.py | 919 ------ transformer_engine/paddle/layer/rmsnorm.py | 175 -- transformer_engine/paddle/layer/softmax.py | 254 -- .../paddle/layer/transformer.py | 375 --- transformer_engine/paddle/profile.py | 19 - transformer_engine/paddle/recompute.py | 63 - transformer_engine/paddle/setup.py | 64 - transformer_engine/paddle/utils.py | 149 - transformer_engine/pytorch/__init__.py | 15 - transformer_engine/pytorch/attention.py | 2608 ++++++----------- transformer_engine/pytorch/constants.py | 4 + .../pytorch/cpp_extensions/__init__.py | 5 - .../pytorch/cpp_extensions/_common.py | 87 - .../pytorch/cpp_extensions/activation.py | 237 -- .../pytorch/cpp_extensions/cast.py | 93 - .../pytorch/cpp_extensions/fused_attn.py | 970 +----- .../pytorch/cpp_extensions/gemm.py | 544 +--- .../pytorch/cpp_extensions/normalization.py | 260 -- .../pytorch/cpp_extensions/padding.py | 29 - .../pytorch/cpp_extensions/transpose.py | 230 -- transformer_engine/pytorch/cpu_offload.py | 18 +- transformer_engine/pytorch/csrc/common.cpp | 148 +- transformer_engine/pytorch/csrc/common.h | 169 +- transformer_engine/pytorch/csrc/extensions.h | 518 +--- .../pytorch/csrc/extensions/activation.cpp | 298 +- .../pytorch/csrc/extensions/apply_rope.cpp | 8 +- .../pytorch/csrc/extensions/attention.cu | 965 +----- .../pytorch/csrc/extensions/bias.cpp | 51 + .../pytorch/csrc/extensions/cast.cpp | 147 +- .../csrc/extensions/comm_gemm_overlap.cpp | 431 +-- .../pytorch/csrc/extensions/gemm.cpp | 488 ++- .../pytorch/csrc/extensions/normalization.cpp | 295 +- .../pytorch/csrc/extensions/padding.cpp | 1 + .../pytorch/csrc/extensions/permutation.cu | 3 + .../pytorch/csrc/extensions/pybind.cpp | 348 +-- .../pytorch/csrc/extensions/quantizer.cpp | 227 ++ .../pytorch/csrc/extensions/recipe.cpp | 23 +- .../pytorch/csrc/extensions/softmax.cpp | 16 +- .../pytorch/csrc/extensions/swizzle.cpp | 120 + .../pytorch/csrc/extensions/transpose.cpp | 482 +-- .../csrc/extensions/type_converters.cpp | 79 + .../pytorch/csrc/extensions/util.cpp | 14 +- transformer_engine/pytorch/csrc/pybind.h | 73 + transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 414 --- transformer_engine/pytorch/csrc/util.h | 12 + transformer_engine/pytorch/distributed.py | 252 +- transformer_engine/pytorch/export.py | 40 - transformer_engine/pytorch/float8_tensor.py | 2 +- transformer_engine/pytorch/fp8.py | 238 +- transformer_engine/pytorch/graph.py | 16 +- transformer_engine/pytorch/module/_common.py | 150 +- transformer_engine/pytorch/module/base.py | 385 +-- .../pytorch/module/fp8_padding.py | 7 +- .../pytorch/module/fp8_unpadding.py | 9 +- .../pytorch/module/grouped_linear.py | 528 ++-- .../pytorch/module/layernorm_linear.py | 1040 +++---- .../pytorch/module/layernorm_mlp.py | 1581 +++++----- 
transformer_engine/pytorch/module/linear.py | 1017 +++---- transformer_engine/pytorch/ops/_common.py | 53 +- .../pytorch/ops/basic/activation.py | 161 +- .../pytorch/ops/basic/all_gather.py | 56 +- .../pytorch/ops/basic/basic_linear.py | 1024 +++---- .../pytorch/ops/basic/layer_norm.py | 76 +- .../pytorch/ops/basic/quantize.py | 30 +- .../pytorch/ops/basic/reduce_scatter.py | 52 +- .../pytorch/ops/basic/reshape.py | 5 +- .../pytorch/ops/basic/rmsnorm.py | 72 +- .../pytorch/ops/fused/backward_linear_add.py | 12 +- .../fused/forward_linear_bias_activation.py | 47 +- .../ops/fused/forward_linear_bias_add.py | 43 +- .../ops/fused/userbuffers_backward_linear.py | 13 +- .../ops/fused/userbuffers_forward_linear.py | 9 +- transformer_engine/pytorch/ops/op.py | 266 +- .../pytorch/optimizers/fused_adam.py | 38 +- transformer_engine/pytorch/permutation.py | 33 +- transformer_engine/pytorch/setup.py | 5 +- transformer_engine/pytorch/softmax.py | 155 +- .../pytorch/te_onnx_extensions.py | 519 ---- transformer_engine/pytorch/tensor/__init__.py | 18 +- .../pytorch/tensor/_internal/__init__.py | 5 +- .../tensor/_internal/float8_tensor_base.py | 139 + .../tensor/_internal/mxfp8_tensor_base.py | 136 + .../pytorch/tensor/float8_tensor.py | 1157 +++----- .../pytorch/tensor/mxfp8_tensor.py | 582 ++++ .../pytorch/tensor/quantized_tensor.py | 322 +- transformer_engine/pytorch/transformer.py | 4 +- transformer_engine/pytorch/utils.py | 47 +- 256 files changed, 20152 insertions(+), 34070 deletions(-) delete mode 100644 build_tools/paddle.py delete mode 100644 docs/api/paddle.rst delete mode 100644 examples/paddle/mnist/README.md delete mode 100644 examples/paddle/mnist/test_single_gpu_mnist.py delete mode 100644 qa/L0_paddle_lint/test.sh delete mode 100644 qa/L0_paddle_unittest/test.sh delete mode 100644 qa/L0_paddle_wheel/test.sh delete mode 100644 qa/L1_pytorch_onnx_test/test.sh create mode 100644 tests/cpp/operator/test_cast.cu create mode 100644 tests/cpp/operator/test_cast_dbias.cu create mode 100644 tests/cpp/operator/test_cast_dbias_dgelu.cu create mode 100644 tests/cpp/operator/test_cast_gated_swiglu.cu create mode 100644 tests/cpp/operator/test_cast_mxfp8.cu create mode 100644 tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu create mode 100644 tests/cpp/operator/test_dequantize_mxfp8.cu create mode 100644 tests/cpp/operator/test_normalization_mxfp8.cu create mode 100644 tests/cpp/operator/test_swizzle.cu delete mode 100644 tests/paddle/dist_launcher.py delete mode 100644 tests/paddle/parallel_tests/amax_reduction.py delete mode 100644 tests/paddle/parallel_tests/attention_tp.py delete mode 100644 tests/paddle/parallel_tests/group_sharding.py delete mode 100644 tests/paddle/parallel_tests/layernorm_linear_tp.py delete mode 100644 tests/paddle/parallel_tests/layernorm_mlp_tp.py delete mode 100644 tests/paddle/parallel_tests/linear_pp.py delete mode 100644 tests/paddle/parallel_tests/linear_tp.py delete mode 100644 tests/paddle/parallel_tests/transformer_tp.py delete mode 100644 tests/paddle/recompute_tests/recompute_transformer_encoder.py delete mode 100644 tests/paddle/test_install.py delete mode 100644 tests/paddle/test_layers.py delete mode 100644 tests/paddle/test_master_grad.py delete mode 100644 tests/paddle/test_operators.py delete mode 100644 tests/paddle/test_parallel.py delete mode 100644 tests/paddle/test_recompute.py delete mode 100644 tests/paddle/utils.py delete mode 100644 tests/pytorch/custom_ort_ops/.gitignore delete mode 100644 tests/pytorch/custom_ort_ops/CMakeLists.txt delete mode 
100644 tests/pytorch/custom_ort_ops/README.md delete mode 100644 tests/pytorch/custom_ort_ops/build.sh delete mode 100755 tests/pytorch/custom_ort_ops/custom_op_library.cc create mode 100644 tests/pytorch/test_cpu_offloading.py delete mode 100644 tests/pytorch/test_onnx_export.py delete mode 100644 tests/pytorch/test_torch_save_load.py create mode 100644 transformer_engine/common/include/transformer_engine/swizzle.h create mode 100644 transformer_engine/common/swizzle/swizzle.cu create mode 100644 transformer_engine/common/transpose/cast_transpose.h create mode 100644 transformer_engine/common/util/cast_gated_kernels.cuh create mode 100644 transformer_engine/common/util/cast_kernels.cuh create mode 100644 transformer_engine/common/util/dequantize_kernels.cuh create mode 100644 transformer_engine/common/util/ptx.cuh delete mode 100644 transformer_engine/paddle/MANIFEST.in delete mode 100644 transformer_engine/paddle/__init__.py delete mode 100644 transformer_engine/paddle/constants.py delete mode 100644 transformer_engine/paddle/cpp_extensions.py delete mode 100644 transformer_engine/paddle/csrc/common.cpp delete mode 100644 transformer_engine/paddle/csrc/common.h delete mode 100644 transformer_engine/paddle/csrc/custom_ops.cu delete mode 100644 transformer_engine/paddle/csrc/extensions.cpp delete mode 100644 transformer_engine/paddle/distributed.py delete mode 100644 transformer_engine/paddle/fp8.py delete mode 100644 transformer_engine/paddle/fp8_buffer.py delete mode 100644 transformer_engine/paddle/layer/__init__.py delete mode 100644 transformer_engine/paddle/layer/attention.py delete mode 100644 transformer_engine/paddle/layer/base.py delete mode 100644 transformer_engine/paddle/layer/layernorm.py delete mode 100644 transformer_engine/paddle/layer/layernorm_linear.py delete mode 100644 transformer_engine/paddle/layer/layernorm_mlp.py delete mode 100644 transformer_engine/paddle/layer/linear.py delete mode 100644 transformer_engine/paddle/layer/rmsnorm.py delete mode 100644 transformer_engine/paddle/layer/softmax.py delete mode 100644 transformer_engine/paddle/layer/transformer.py delete mode 100644 transformer_engine/paddle/profile.py delete mode 100644 transformer_engine/paddle/recompute.py delete mode 100644 transformer_engine/paddle/setup.py delete mode 100644 transformer_engine/paddle/utils.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/_common.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/activation.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/cast.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/normalization.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/padding.py delete mode 100644 transformer_engine/pytorch/cpp_extensions/transpose.py create mode 100644 transformer_engine/pytorch/csrc/extensions/bias.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/quantizer.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/swizzle.cpp create mode 100644 transformer_engine/pytorch/csrc/extensions/type_converters.cpp rename tests/pytorch/custom_ort_ops/custom_op_library.h => transformer_engine/pytorch/csrc/extensions/util.cpp (53%) mode change 100755 => 100644 create mode 100644 transformer_engine/pytorch/csrc/pybind.h delete mode 100644 transformer_engine/pytorch/csrc/ts_fp8_op.cpp create mode 100644 transformer_engine/pytorch/csrc/util.h delete mode 100755 transformer_engine/pytorch/export.py delete mode 100755 transformer_engine/pytorch/te_onnx_extensions.py rename 
tests/paddle/test_sanity_import.py => transformer_engine/pytorch/tensor/_internal/__init__.py (69%) create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/mxfp8_tensor.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 964e71fa8c..4be7a30a86 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -73,23 +73,3 @@ jobs: MAX_JOBS: 1 - name: 'Sanity check' run: python tests/jax/test_sanity_import.py - paddle: - name: 'PaddlePaddle' - runs-on: ubuntu-latest - container: - image: nvcr.io/nvidia/paddlepaddle:24.10-py3 - options: --user root - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - name: 'Build' - run: | - apt-get update - apt-get install -y libgoogle-glog-dev - pip install . -v - env: - NVTE_FRAMEWORK: paddle - - name: 'Sanity check' - run: python tests/paddle/test_sanity_import.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f98fc9aa3a..ee6433d484 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -61,30 +61,3 @@ jobs: export PYTHON_ONLY=1 export TE_PATH=. bash ./qa/L0_jax_lint/test.sh - paddle_cpplint: - name: 'PaddlePaddle C++' - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - export CPP_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh - paddle_pylint: - name: 'PaddlePaddle Python' - runs-on: ubuntu-latest - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - pip install paddlepaddle-gpu - export PYTHON_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh diff --git a/.gitignore b/.gitignore index 9b61454e21..f491b21f43 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ *.nsys-rep *.ncu-rep *.sqlite -*.onnx *.eggs build/ *.so diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index cc5632eda7..91b7532f33 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699 +Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a diff --git a/README.rst b/README.rst index fbcf05f3c9..8fea8c9d94 100644 --- a/README.rst +++ b/README.rst @@ -174,7 +174,7 @@ To install the latest stable version of Transformer Engine, pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle). +This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. 
@@ -182,7 +182,7 @@ Alternatively, the package can be directly installed from `Transformer Engine's pip install transformer_engine[pytorch] -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions. From source ^^^^^^^^^^^ diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 809a0327d8..eb5820cd2d 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.14.0.dev0 +2.1.0.dev0 diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 5744439c1b..a3243d087b 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -129,63 +129,6 @@ def run(self) -> None: super().run() self.extensions = all_extensions - paddle_ext = None - if "paddle" in get_frameworks(): - for ext in self.extensions: - if "paddle" in ext.name: - paddle_ext = ext - break - - # Manually write stub file for Paddle extension - if paddle_ext is not None: - # Load libtransformer_engine.so to avoid linker errors - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Source compilation from top-level (--editable) - search_paths = list(Path(__file__).resolve().parent.parent.iterdir()) - # Source compilation from top-level - search_paths.extend(list(Path(self.build_lib).iterdir())) - - # Dynamically load required_libs. - from transformer_engine.common import _load_cudnn, _load_nvrtc - - _load_cudnn() - _load_nvrtc() - else: - # Only during release bdist build for paddlepaddle. - import transformer_engine - - search_paths = list(Path(transformer_engine.__path__[0]).iterdir()) - del transformer_engine - - common_so_path = "" - for path in search_paths: - if path.name.startswith("libtransformer_engine."): - common_so_path = str(path) - assert common_so_path, "Could not find libtransformer_engine" - ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL) - - # Figure out stub file path - module_name = paddle_ext.name - assert module_name.endswith( - "_pd_" - ), "Expected Paddle extension module to end with '_pd_'" - stub_name = module_name[:-4] # remove '_pd_' - stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py") - Path(stub_path).parent.mkdir(exist_ok=True, parents=True) - - # Figure out library name - # Note: This library doesn't actually exist. Paddle - # internally reinserts the '_pd_' suffix. - so_path = self.get_ext_fullpath(module_name) - _, so_ext = os.path.splitext(so_path) - lib_name = stub_name + so_ext - - # Write stub file - print(f"Writing Paddle stub for {lib_name} into file {stub_path}") - from paddle.utils.cpp_extension.extension_utils import custom_write_stub - - custom_write_stub(lib_name, stub_path) - # Ensure that binaries are not in global package space. target_dir = install_dir / "transformer_engine" target_dir.mkdir(exist_ok=True, parents=True) @@ -194,16 +137,10 @@ def run(self) -> None: self.copy_file(ext, target_dir) os.remove(ext) - # For paddle, the stub file needs to be copied to the install location. 
- if paddle_ext is not None: - stub_path = Path(self.build_lib) / "transformer_engine" - for stub in stub_path.glob("transformer_engine_paddle.py"): - self.copy_file(stub, target_dir) - def build_extensions(self): - # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly + # BuildExtensions from PyTorch already handle CUDA files correctly # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed. - if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks(): + if "pytorch" not in get_frameworks(): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict. for ext in self.extensions: diff --git a/build_tools/paddle.py b/build_tools/paddle.py deleted file mode 100644 index f0fcdb8f25..0000000000 --- a/build_tools/paddle.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Paddle-paddle related extensions.""" -from pathlib import Path - -import setuptools -import os - -from .utils import cuda_version - -import paddle - -paddle_version = paddle.__version__.replace(".", "") - - -def setup_paddle_extension( - csrc_source_files, - csrc_header_files, - common_header_files, -) -> setuptools.Extension: - """Setup CUDA extension for Paddle support""" - - # Source files - csrc_source_files = Path(csrc_source_files) - sources = [ - csrc_source_files / "extensions.cpp", - csrc_source_files / "common.cpp", - csrc_source_files / "custom_ops.cu", - ] - - # Header files - include_dirs = [ - common_header_files, - common_header_files / "common", - common_header_files / "common" / "include", - csrc_header_files, - ] - - # Compiler flags - cxx_flags = ["-O3"] - nvcc_flags = [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - f"-DPADDLE_VERSION={paddle_version}", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - - # Version-dependent CUDA options - try: - version = cuda_version() - except FileNotFoundError: - print("Could not determine CUDA Toolkit version") - else: - if version < (12, 0): - raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - nvcc_flags.extend( - ( - "--threads", - os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_90,code=sm_90", - ) - ) - - # Construct Paddle CUDA extension - sources = [str(path) for path in sources] - include_dirs = [str(path) for path in include_dirs] - from paddle.utils.cpp_extension import CUDAExtension - - ext = CUDAExtension( - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "cxx": cxx_flags, - "nvcc": nvcc_flags, - }, - ) - ext.name = "transformer_engine_paddle_pd_" - return ext diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index f060e99dff..b8501e1008 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -27,7 +27,6 @@ def setup_pytorch_extension( extensions_dir = csrc_source_files / "extensions" sources = [ csrc_source_files / "common.cpp", - csrc_source_files / "ts_fp8_op.cpp", ] + all_files_in_dir(extensions_dir) # Header files diff --git a/build_tools/utils.py b/build_tools/utils.py index f2a4200685..723f2f200c 100644 --- 
a/build_tools/utils.py +++ b/build_tools/utils.py @@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90") + version = cuda_version() + if os.getenv("NVTE_CUDA_ARCHS") is None: + os.environ["NVTE_CUDA_ARCHS"] = ( + "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90" + ) + return os.getenv("NVTE_CUDA_ARCHS") def cuda_version() -> Tuple[int, ...]: @@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]: def get_frameworks() -> List[str]: """DL frameworks to build support for""" _frameworks: List[str] = [] - supported_frameworks = ["pytorch", "jax", "paddle"] + supported_frameworks = ["pytorch", "jax"] # Check environment variable if os.getenv("NVTE_FRAMEWORK"): @@ -237,12 +242,6 @@ def get_frameworks() -> List[str]: pass else: _frameworks.append("jax") - try: - import paddle - except ImportError: - pass - else: - _frameworks.append("paddle") # Special framework names if "all" in _frameworks: @@ -311,7 +310,6 @@ def uninstall_te_wheel_packages(): "-y", "transformer_engine_cu12", "transformer_engine_torch", - "transformer_engine_paddle", "transformer_engine_jax", ] ) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index ceebe626f4..9acb22aee6 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -9,7 +9,6 @@ BUILD_METAPACKAGE=${2:-true} BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} BUILD_JAX=${5:-true} -BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -63,38 +62,3 @@ if $BUILD_JAX ; then /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi - -if $BUILD_PADDLE ; then - if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then - dnf -y remove --allowerasing cudnn9-cuda-12 - dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 - cd /TransformerEngine/transformer_engine/paddle - - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - 
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - mv dist/* /wheelhouse/ - fi -fi diff --git a/docs/api/common.rst b/docs/api/common.rst index 85201aee5d..5e0a660ae6 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -8,4 +8,4 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.Format -.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False)) +.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) diff --git a/docs/api/framework.rst b/docs/api/framework.rst index acd54fe3b1..0ac1a0e34e 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -10,4 +10,3 @@ Framework-specific API pytorch jax - paddle diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst deleted file mode 100644 index 3b3ecf55c6..0000000000 --- a/docs/api/paddle.rst +++ /dev/null @@ -1,34 +0,0 @@ -.. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -paddle -====== - -.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs) - -.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapifunction:: transformer_engine.paddle.fp8_autocast - -.. autoapifunction:: transformer_engine.paddle.recompute diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 43001feeb3..6d5fe6761d 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -42,8 +42,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.checkpoint -.. autoapifunction:: transformer_engine.pytorch.onnx_export - .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 27017b4773..16a3b05466 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -14,11 +14,10 @@ "
Figure 1: Dot product attention.
\n", "\n", "\n", - "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n", + "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n", "\n", "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", - "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n", - "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)" + "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)" ] }, { @@ -56,15 +55,6 @@ " \n", " JAX-native attention (`_UnfusedDotProductAttention`)\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention (`_te_forward`) \n", - " [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n", - " \n", - " \n", - " \n", - " PaddlePaddle-native attention (`_pd_forward`)\n", - " \n", " \n", "" ] @@ -87,7 +77,7 @@ "
\n", "Note: \n", " \n", - "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n", "
\n" ] }, @@ -102,13 +92,13 @@ "\n", "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", + "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", "\n", "\n", " \n", @@ -153,9 +143,9 @@ " \n", "
\n", "\n", - "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n", + "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n", "\n", - "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", + "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", @@ -244,10 +234,6 @@ " JAX\n", " cuDNN attention > JAX-native attention\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention > PaddlePaddle-native attention \n", - " \n", "" ] }, @@ -266,7 +252,7 @@ "
\n", "Note:\n", " \n", - "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n", + "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n", "
" ] }, @@ -382,7 +368,7 @@ "
\n", "Note\n", " \n", - "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n", + "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX in the future.\n", "
\n", "\n", "### 2.3 Example Tests\n", @@ -399,7 +385,7 @@ "source": [ "## 3. Backend Support\n", "\n", - "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n", + "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n", "\n", "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", @@ -442,7 +428,7 @@ "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", - "As of v1.10, Transformer Engine has the following support matrix.\n", + "As of v2.0, Transformer Engine has the following support matrix.\n", "\n", "\n", " \n", @@ -462,13 +448,13 @@ " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", "
\n", - " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", + " JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", "
Framework-native attention`bshd`, `sbhd`PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layoutsPyTorch, JAX: 2 formats, i.e. 10 layouts
\n", "\n", @@ -492,7 +478,7 @@ "\n", "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n", "\n", - "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n", + "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n", "\n", "\n", " \n", @@ -512,21 +498,21 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", "
Framework-native attention
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax, PaddlePaddle)
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax)
  • \n", "\n", - "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n", "\n", "\n", - "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", + "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", @@ -566,7 +552,7 @@ "\n", "### 3.3 Attention Bias\n", "\n", - "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n", + "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n", "\n", "\n", " \n", @@ -591,7 +577,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,7 +606,7 @@ "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. 
Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", diff --git a/docs/installation.rst b/docs/installation.rst index fae01c64fa..ee7afa9006 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -37,7 +37,7 @@ Transformer Engine can be directly installed from `our PyPI desired_test_accuracy - - @unittest.skipIf( - paddle.device.cuda.get_device_capability() < (8, 0), - "BF16 MNIST example requires Ampere+ GPU", - ) - def test_te_bf16(self): - """Test Transformer Engine with BF16""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8(self): - """Test Transformer Engine with FP8""" - self.args.use_te = True - self.args.use_fp8 = True - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8_calibration(self): - """Test Transformer Engine with FP8 calibration""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.use_fp8_infer = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - -if __name__ == "__main__": - train_and_evaluate(mnist_parser(None)) diff --git a/pylintrc b/pylintrc index b80679d72c..4af0c6b427 100644 --- a/pylintrc +++ b/pylintrc @@ -2,7 +2,6 @@ extension-pkg-whitelist=flash_attn_2_cuda, torch, transformer_engine_torch, - transformer_engine_paddle, transformer_engine_jax extension-pkg-allow-list=transformer_engine.transformer_engine_jax diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 6eff047721..8e2e540293 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -8,7 +8,7 @@ pip install "nltk>=3.8.2" pip install pytest==8.2.1 : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py # Test without custom calls NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh deleted file mode 100644 index 1c26bd265b..0000000000 --- a/qa/L0_paddle_lint/test.sh +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 
2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -pip install cpplint==1.6.0 pylint==3.3.1 -if [ -z "${PYTHON_ONLY}" ] -then - cd $TE_PATH - echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include - echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/paddle -fi -if [ -z "${CPP_ONLY}" ] -then - cd $TE_PATH - echo "Checking Python files" - pylint --recursive=y transformer_engine/common transformer_engine/paddle -fi diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh deleted file mode 100644 index 9312f22ba4..0000000000 --- a/qa/L0_paddle_unittest/test.sh +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -xe - -pip install pytest==8.2.1 -: ${TE_PATH:=/opt/transformerengine} -pytest -Wignore -v $TE_PATH/tests/paddle -pytest -Wignore -v $TE_PATH/examples/paddle/mnist diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh deleted file mode 100644 index 5116bdb5cf..0000000000 --- a/qa/L0_paddle_wheel/test.sh +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -# Install dependencies -# Note: Need to install wheel locally since PaddlePaddle container -# already contains APT install. -pip install pydantic -pip install --user wheel==0.44.0 - -cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle - -VERSION=`cat $TE_PATH/build_tools/VERSION.txt` -WHL_BASE="transformer_engine-${VERSION}" - -# Core wheel. 
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -python -m wheel unpack dist/* -sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -python -m wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel -pip install dist/*.whl --no-deps - -cd transformer_engine/paddle -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -pip install dist/* - -python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 793fa47259..dd7f95bce0 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index ee7c28ca5f..8ee0be1af5 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -8,8 +8,8 @@ set -e pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py +# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh deleted file mode 100644 index 8e4ef03b8e..0000000000 --- a/qa/L1_pytorch_onnx_test/test.sh +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install pytest==8.2.1 onnxruntime==1.19.2 - -# Build custom ONNX Runtime operators -export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops -bash $CUSTOM_ORT_OPS_PATH/build.sh - -# Run tests -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index e63ba358a5..8ed3002214 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -12,7 +12,14 @@ pip install pytest==8.2.1 export MAX_JOBS=4 # Iterate over Flash Attention versions -FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1) +sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` +if [ $sm_arch -gt 90 ] +then + FA_versions=(2.7.3) +else + FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1) +fi + for fa_version in "${FA_versions[@]}" do @@ -21,10 +28,10 @@ do then pip install flash-attn==${fa_version} else - pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" + pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" python_path=`python -c "import site; print(site.getsitepackages()[0])"` mkdir -p $python_path/flashattn_hopper - wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py + wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py fi # Run tests diff --git a/setup.py b/setup.py index 643dd7a908..1d9818458e 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ """Installation script.""" import os +import sys import time from pathlib import Path from typing import List, Tuple @@ -35,14 +36,13 @@ if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension -elif "paddle" in frameworks: - from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension CMakeBuildExtension = get_build_ext(BuildExtension) +archs = cuda_archs() class TimedBdist(bdist_wheel): @@ -57,7 +57,7 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" - cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None @@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch"]) - test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + test_reqs.extend(["numpy", "torchvision", "prettytable"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) - test_reqs.extend(["numpy", "praxis"]) - if "paddle" in frameworks: - install_reqs.append("paddlepaddle-gpu") - test_reqs.append("numpy") + # test_reqs.extend(["numpy", "praxis"]) + test_reqs.extend(["numpy"]) return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] @@ -135,7 +133,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": 
[f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], } else: setup_requires, install_requires, test_requires = setup_requirements() @@ -169,16 +166,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: current_file_path / "transformer_engine", ) ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", - ) - ) # Configure package setuptools.setup( diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index d8c8d99fac..081cd14eb4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -5,7 +5,11 @@ cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 178dc5e8dd..ce78fcaae2 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,23 +3,33 @@ # See LICENSE for license information. add_executable(test_operator + test_cast.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu test_qdq.cu - test_cast_transpose.cu + test_cast_mxfp8.cu + test_dequantize_mxfp8.cu test_transpose.cu + test_cast_transpose.cu test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu test_normalization.cu + test_normalization_mxfp8.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu + test_swizzle.cu ../test_common.cu) +find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) -target_compile_options(test_operator PRIVATE -O2) +target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) +target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_operator) +gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index cec997d078..4224f199f4 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -21,58 +21,6 @@ using namespace transformer_engine; -namespace { - -// forward - -float gelu(const float x) { - return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x))); -} - -float silu(const float x) { - return x / (1 + expf(-x)); -} - -float relu(const float x) { - return x > 0 ? x : 0; -} - -float srelu(const float x) { - return x > 0 ? x * x : 0; -} - -float qgelu(const float x) { - return x / (1 + expf(-1.702f * x)); -} - -// backward - -float dgelu(const float x) { - const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x)); - return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + - 0.5f * (1.f + tanh_out); -} - -float dsilu(const float x) { - const float sigmoid = 1.f / (1 + expf(-x)); - return x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -float drelu(const float x) { - return x > 0.f ? 
1.f : 0.f; -} - -float dsrelu(const float x) { - return fmaxf(2.f * x, 0.f); -} - -float dqgelu(const float x) { - const float sigmoid = 1.f / (1 + expf(-1.702f * x)); - return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -} // namespace - template void compute_ref_act_cast(const IT *input_h, OT *output_h, @@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h, const size_t H) { CT amax = 0.; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h, const size_t N, const size_t H) { using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C const int col = H * 2; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT gelu_elt = static_cast(input_h[i * col + j]); @@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h const int col = H * 2; using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT grad = static_cast(grad_h[i * H + j]); @@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype); - Tensor igrad({ N, H }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype); + Tensor igrad("igrad", { N, H }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) { nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); @@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H * 2}, itype); - Tensor output({N, H}, otype); - Tensor igrad({ N, H * 2 }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H}, otype); + Tensor igrad("igrad", { N, H * 2 }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) { ASSERT_EQ(err, cudaSuccess) << 
cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol, rtol); + if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + const float ref_scale = 1.f / output.scale(); + compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol); + } } auto [atol, rtol] = getTolerances(otype); compareResults("output_gelu", output, ref_output.get(), atol, rtol); nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dglu_act_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu new file mode 100644 index 0000000000..f57d1f035d --- /dev/null +++ b/tests/cpp/operator/test_cast.cu @@ -0,0 +1,130 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } + *amax = current_max; +} + +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + setRandomScale(&output_c); + + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, &ref_amax, output_c.scale()); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastTestSuite, TestCast) { + using namespace transformer_engine; + using namespace test; + + const DType 
input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu new file mode 100644 index 0000000000..1f0a9305d8 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias.cu @@ -0,0 +1,181 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref_cast_dbias(const IT *input_h, + const CT scale, + OT *output_c_h, + CT *amax_h, + IT *dbias_h, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT elt = static_cast(input_h[i * H + j]); + + // update amax + amax = std::abs(elt) > amax ? 
std::abs(elt) : amax; + + output_c_h[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias_h[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + + Tensor output_c("output_c", shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias(input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasTestSuite, TestCastDBias) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu 
b/tests/cpp/operator/test_cast_dbias_dgelu.cu new file mode 100644 index 0000000000..20ea5c31f1 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -0,0 +1,196 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dbias_dgelu(const IT *input, + const IT *grad, + const CT scale, + OT *output_c, + CT *amax_h, + IT *dbias, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT in_elt = static_cast(input[i * H + j]); + const CT in_grad = static_cast(grad[i * H + j]); + + const CT elt = in_grad * static_cast(dgelu(static_cast(in_elt))); + const CT elt_abs = std::abs(elt); + + // update amax + if (elt_abs > amax) { + amax = elt_abs; + } + + output_c[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + + Tensor output_c("output_c", shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + fillUniform(&grad); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + 
{1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasDGeluTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasDGeluTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu new file mode 100644 index 0000000000..35ae462106 --- /dev/null +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dgated_swiglu(const IType * const grad, + const IType * const input, + const float scale, + OType * const output, + float * const amax_ptr, + const size_t rows, + const size_t cols) { + float amax = 0; + const size_t stride = cols * 2; + + #pragma omp parallel for reduction(max: amax) proc_bind(spread) + for (size_t i = 0; i < rows; i++) { + for (size_t j = 0; j < cols; j++) { + float grad_elt = static_cast(grad[i * cols + j]); + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + float after_dgate = grad_elt * silu(silu_elt); + + if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); } + if (abs(after_dgate) > amax) { amax = abs(after_dgate); } + + output[i * stride + j] = static_cast(scale * after_dsilu); + output[i * stride + cols + j] = static_cast(scale * after_dgate); + } + } + + *amax_ptr = amax; +} + +template +void performTest(const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + std::vector input_shape = shape; + input_shape[input_shape.size() - 1] *= 2; + + const size_t input_size = product(input_shape); + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + Tensor grad("grad", shape, itype); + Tensor input("input", input_shape, itype); + Tensor output_c("output_c", input_shape, otype); + + fillUniform(&grad); + 
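+  // Layout note: each row of "input" stores the activation half in columns
+  // [0, cols) and the gate half in columns [cols, 2*cols), matching the
+  // indexing used by compute_ref_cast_dgated_swiglu above.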
fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(input_size); + + nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0); + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax; + compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + rows, + cols); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {217, 256}, + {1296}, + {5, 4, 3, 160}, +}; + +} // namespace + +class CastSwiGLUTestSuite + : public ::testing::TestWithParam>> {}; + +TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + if (size.back() % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, performTest(size););); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, CastSwiGLUTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo &info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu new file mode 100644 index 0000000000..cb38a5a74a --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -0,0 +1,636 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +template +void scale_block(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + float* dbias, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + float amax = 0.0f; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + dbias[j] += elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + amax = std::max(amax, std::abs(elt)); + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + output_c[idx] = static_cast(elt * scale_reciprocal); + } + } +} + +template +void compute_ref_x1(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + fp8e8m0* output_scales, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + std::vector output_dbias_fp32(cols, 0); + + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + scale_block( + processing_method, input, grad, output_c, output_dbias_fp32.data(), + output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = 
static_cast(output_dbias_fp32[j]); + } +} + +template +void compute_ref_x2(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, + rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + processing_method, input, grad, output_colwise, scales_colwise, output_dbias, + rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ + +template +void performTest_x1(const ProcessingMethod processing_method, + const std::vector& shape, + const bool rowwise, + const bool colwise, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + if (shape.size() < 2 && colwise) { + GTEST_SKIP(); + } + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c = std::make_unique(rows * cols); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output_c.data(), 0); + break; + } + } + + 
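+  // The CPU reference (compute_ref_x1 above) mirrors the quantization done by
+  // the nvte_* call selected in the switch: for each 1x32 (or 32x1) block it
+  // takes the block amax, multiplies it by the reciprocal of the FP8 format
+  // maximum, converts the result to an E8M0 biased exponent with
+  // float_to_e8m0(), stores that exponent in the scale-inv tensor, and scales
+  // the block by exp2f_rcp() of that exponent before casting to FP8.
+  // Illustrative numbers (assuming float_to_e8m0() rounds the exponent up so
+  // the scaled values stay in range): with E4M3 output (max 448) and a block
+  // amax of 3.0, 3.0 / 448 ~= 0.0067, so the stored biased exponent is 120
+  // (i.e. 2^-7) and the block is multiplied by 2^7 = 128 before the cast,
+  // keeping 3.0 * 128 = 384 within the E4M3 range.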
cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x1(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output_c.rowwise_cpu_scale_inv_ptr() + : output_c.columnwise_cpu_scale_inv_ptr(); + + compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const ProcessingMethod processing_method, + const std::vector& shape, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if (shape.size() < 2) { + GTEST_SKIP(); + } + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + 
nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +std::vector> matrix_sizes = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {256, 65536}, + {2048, 6144}, + {16384, 128}, + {32768, 160}, + {4096, 1632}, + {1024}, + {8, 32, 1024}, + {16, 8, 4, 512}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + +} // namespace + +class FusedCastMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase>> {}; + +#define 
DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ +} + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ +} + +TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const auto block_size = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + const InputsFillCase fill_case = std::get<6>(GetParam()); + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + GTEST_SKIP(); + } + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + if (processing_method == ProcessingMethod::CAST_ACT) { + // Forward activations + ACT_FUNC_SWITCH(Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } else { + DACT_FUNC_SWITCH(Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + 
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "Identity"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)) + "X" + + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + std::to_string(std::get<3>(info.param).first) + + "X" + std::to_string(std::get<3>(info.param).second) + + "X" + test::typeName(std::get<4>(info.param)) + + "X" + test::typeName(std::get<5>(info.param)) + + "X" + test::caseName(std::get<6>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu new file mode 100644 index 0000000000..6acbdefeab --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -0,0 +1,470 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void scale_block(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t scale_idx_gate, + float& thread_amax, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + + float block_amax = 0.0f; + float block_amax_gate = 0.0f; + const size_t stride = cols * 2; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + float gated_amax_act = 0; + float gated_amax_gate = 0; + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + gated_amax_act = abs(after_dsilu); + gated_amax_gate = abs(after_dgate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + gated_amax_act = abs(after_silu); + } + + if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } + if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * + Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + float scale_reciprocal_gate = 1; + if constexpr (IS_DGATED) { + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * + Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent); + output_scales[scale_idx_gate] = biased_exponent; + } + + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); + output[i * stride + cols + j] = static_cast(after_dgate * + scale_reciprocal_gate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + output[i * cols + j] = static_cast(after_silu * scale_reciprocal); + } + + } + } + thread_amax = std::max(thread_amax, block_amax); + thread_amax = std::max(thread_amax, block_amax_gate); +} + +template +void compute_ref_x1(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) { + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; 
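+  // Tiling sketch: the rows x cols output is split into tiles of at least
+  // 32 x 64 elements, which the OpenMP loop below distributes across threads;
+  // each tile is then walked in block_size_Y x block_size_X scaling blocks.
+  // For example, rows=65, cols=96 with row-wise 1x32 blocks gives
+  // tile_size 32x64, tiles_num_Y=3, tiles_num_X=2, blocks_per_tile_Y=32 and
+  // blocks_per_tile_X=2.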
+ + float amax = 0; + #pragma omp parallel reduction(max: amax) proc_bind(spread) + { + float thread_amax = 0; + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + if (i_min >= rows) continue; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + if (j_min >= cols) continue; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; + const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + + cols / block_size_X; + scale_block( + grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, + thread_amax, i_min, i_max, j_min, j_max, cols); + } + } + } + if (thread_amax > amax) { + amax = thread_amax; + } + } + ref_amax = amax; +} + +template +void compute_ref_x2(const IType* grad, + const IType* input, + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x1(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); + NVTE_CHECK(rowwise || colwise); + + // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; + // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; + // std::cout << "blocks_Y: " << blocks_Y << std::endl; + // std::cout << "blocks_X: " << blocks_X << std::endl; + // std::cout << "scales_stride: " << scales_stride << std::endl; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + const std::array scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + for (size_t i = 0; i < blocks_Y * blocks_X; ++i) { + ref_output_scales[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + compute_ref_x1(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_scales.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output.rowwise_cpu_scale_inv_ptr() + : output.columnwise_cpu_scale_inv_ptr(); + if (rowwise) { + compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } else { + compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + true, true, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + + for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) { + ref_scales_rowwise[i] = 0; + } + for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) { + ref_scales_colwise[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + compute_ref_x2(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); +} + +std::vector> matrix_sizes = { + {1, 32}, + {16, 64}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + {65536, 128}, + {16384, 1632}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector is_dgated_op = { + true, + false +}; + +} // namespace + +class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase, + bool>> {}; + 
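+// The generated case names join the matrix size, block size, input/output
+// types, fill case and DGATED/GATED flag with "X" (see the name generator
+// passed to INSTANTIATE_TEST_SUITE_P below), so a single configuration can be
+// selected from the test_operator binary with a gtest filter, for example:
+//   ./test_operator --gtest_filter='*CastMXFP8_GatedActTestSuite*128X128X1X32*DGATED'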
+TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto matrix_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const InputsFillCase fill_case = std::get<4>(GetParam()); + const bool IS_DGATED = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, + if (block_size.first == 1 || block_size.second == 1) { + if (IS_DGATED) { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } else { + if (IS_DGATED) { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastMXFP8_GatedActTestSuite, + ::testing::Combine( + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), + ::testing::ValuesIn(is_dgated_op)), + [](const testing::TestParamInfo& info) { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + test::caseName(std::get<4>(info.param)) + "X" + + (std::get<5>(info.param) ? 
"DGATED" : "GATED"); + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 05fcafb0b1..830682eec3 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -14,7 +14,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output_c({ N, H }, otype); - Tensor output_t({ H, N }, otype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); std::unique_ptr ref_output_c = std::make_unique(N * H); std::unique_ptr ref_output_t = std::make_unique(N * H); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); - nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref(input.cpu_dptr(), ref_output_c.get(), + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), ref_output_t.get(), N, H, &ref_amax, - output_c.scale()); + output.scale()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } std::vector> test_cases = {{2048, 12288}, diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 72d890f8e9..53918e2699 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -15,7 +15,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -64,26 +64,23 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); + Tensor input("input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias(input.cpu_dptr(), - output_c.scale(), + 
compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -92,22 +89,20 @@ void performTest(const size_t N, const size_t H) { Tensor workspace; - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); - workspace = Tensor(workspace.shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -115,17 +110,17 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index d3ba31fa53..15c7d8d665 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -75,29 +75,26 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); - Tensor gelu_input({N, H}, itype); + Tensor input("input", {N, H}, itype); + Tensor gelu_input("gelu_input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); fillUniform(&gelu_input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr(), - gelu_input.cpu_dptr(), - output_c.scale(), + compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr(), + gelu_input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -108,19 +105,17 @@ void performTest(const size_t 
N, const size_t H) { nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); @@ -131,18 +126,18 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index 03cec4e658..ae2da7bad2 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -74,24 +74,22 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output_c({N, H * 2}, otype); - Tensor output_t({H * 2, N}, otype); + Tensor grad("grad", {N, H}, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H * 2}, otype, true, true); fillUniform(&grad); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N * H * 2); std::unique_ptr ref_output_t = std::make_unique(N * H * 2); - nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0); + nvte_dgeglu_cast_transpose(grad.data(), input.data(), output.data(), 0); CType ref_amax; - compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr(), input.cpu_dptr(), - output_c.scale(), ref_output_c.get(), ref_output_t.get(), + compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -100,14 +98,14 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), 
ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } std::vector> test_cases = {{64, 400}, {4096, 2048}, {768, 2816}, diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index 5401b03296..2fdc0a524d 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -153,11 +153,11 @@ void performTest( DType itype = TypeInfo::dtype; - Tensor data_in({ batches, heads, rows, cols }, itype); - Tensor softmax_out({ batches, heads, rows, cols }, itype); - Tensor softmax_in({ batches, heads, rows, cols }, itype); - Tensor grads_in({ batches, heads, rows, cols }, itype); - Tensor grads_out({ batches, heads, rows, cols }, itype); + Tensor data_in("data_in", { batches, heads, rows, cols }, itype); + Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype); + Tensor softmax_in("softmax_in", { batches, heads, rows, cols }, itype); + Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype); + Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype); const size_t elements_total = batches * heads * rows * cols; std::unique_ptr softmax_out_ref = std::make_unique(elements_total); @@ -175,9 +175,9 @@ void performTest( // Reference implementations - compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr(), + compute_fwd_ref(softmax_out_ref.get(), data_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); - compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr(), softmax_in.cpu_dptr(), + compute_bwd_ref(grads_out_ref.get(), grads_in.rowwise_cpu_dptr(), softmax_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); cudaDeviceSynchronize(); @@ -187,8 +187,8 @@ void performTest( if(itype == DType::kBFloat16) { atol = 1e-3; } - compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol); - compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol); + compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), true, atol, rtol); + compareResults("softmax_bwd", grads_out, grads_out_ref.get(), true, atol, rtol); } // [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor] diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu new file mode 100644 index 0000000000..701deb38bb --- /dev/null +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -0,0 +1,452 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void dequantize_block(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) +{ + const fp8e8m0 biased_exponent = scales[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float elem_scale = block_scale; + + // Dequantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elt = static_cast(input[idx]); + output[idx] = static_cast(elt * elem_scale); + } + } +} + +template +void compute_ref_x1(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + dequantize_block( + input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } +} + +template +void compute_ref_x2(const InputType* input, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + compute_ref_x1(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +void generate_scales(fp8e8m0 * const scales_ref, + fp8e8m0 * const scales, + const size_t blocks_num, + std::mt19937& gen, + std::uniform_int_distribution dis) +{ + for (size_t i = 0; i < blocks_num; ++i) { + const fp8e8m0 val = dis(gen); + scales_ref[i] = val; + scales[i] = val; + } +} + +template +void generate_data(InputType * const data, + const size_t rows, + const size_t cols, + std::mt19937& gen, + std::uniform_real_distribution<>& dis, + std::uniform_real_distribution<>& dis_sign) +{ + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(gen) < 0.0); + double val = dis(gen); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } +} + +template +void fill_tensor_data(Tensor& input, + fp8e8m0 * const scales_rowwise, + fp8e8m0 * const scales_colwise, + const bool is_rowwise_scaling, + const bool is_colwise_scaling, + const size_t rows, + const size_t cols, + const size_t blocks_num_rowwise, + const size_t blocks_num_colwise) +{ + const double minAbs = Numeric_Traits::minNorm; + 
const double maxAbs = Numeric_Traits::maxNorm; + static std::mt19937 gen(12345); + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + std::uniform_int_distribution int_dis(0, 255); + + if (is_rowwise_scaling) { + generate_scales(scales_rowwise, input.rowwise_cpu_scale_inv_ptr(), blocks_num_rowwise, gen, int_dis); + generate_data(input.rowwise_cpu_dptr(), rows, cols, gen, dis, dis_sign); + } + + if (is_colwise_scaling) { + generate_scales(scales_colwise, input.columnwise_cpu_scale_inv_ptr(), blocks_num_colwise, gen, int_dis); + generate_data(input.columnwise_cpu_dptr(), rows, cols, gen, dis, dis_sign); + } + + input.from_cpu(); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_x1(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise; + const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, otype, true, false); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr scales = std::make_unique(blocks_num); + + fill_tensor_data(input, scales.get(), scales.get(), rowwise, colwise, rows, cols, + blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + InputType * data_ptr = rowwise + ? 
input.rowwise_cpu_dptr() + : input.columnwise_cpu_dptr(); + + compute_ref_x1(data_ptr, + ref_output.get(), + scales.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), true, atol, rtol); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_quantize_then_dequantize(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType in_type = TypeInfo::dtype; + DType intermed_type = TypeInfo::dtype; + DType out_type = TypeInfo::dtype; + + std::unique_ptr input_cpu = std::make_unique(rows * cols); + std::unique_ptr quantized_cpu = std::make_unique(rows * cols); + std::unique_ptr output_cpu = std::make_unique(rows * cols); + + // input --> quantized --> output (dequantized) + // input == output + Tensor input("input", { rows, cols }, in_type); + Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, out_type, true, false); + + // fillCase(&input, InputsFillCase::minNorm_to_maxNorm); + fillCase(&input, InputsFillCase::uniform); + + const size_t copy_size = sizeof(InputType) * rows * cols; + cudaMemcpy(input_cpu.get(), input.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + nvte_quantize(input.data(), quantized.data(), 0); + cudaDeviceSynchronize(); + + const size_t copy_size_quantized = sizeof(IntermediateType) * rows * cols; + if (rowwise) { + cudaMemcpy(quantized_cpu.get(), quantized.rowwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + if (colwise) { + cudaMemcpy(quantized_cpu.get(), quantized.columnwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + + nvte_dequantize(quantized.data(), output.data(), 0); + cudaDeviceSynchronize(); + + cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(intermed_type); + compareResults("Quantize-Dequantize", input, output_cpu.get(), true, atol, rtol); +} + +// Dequantize along both dimensions (row- and columnwise) +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t scales_stride_rowwise = blocks_X_rowwise; + const size_t scales_stride_colwise = blocks_X_colwise; + 
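+  // Note: the scale tensors are padded up to the scale_tensor_alignment_* multiples computed
+  // above, so the reference scale buffers below are allocated with the padded block counts,
+  // while the reference dequantization only visits the unpadded blocks and indexes the scales
+  // with the padded row strides (scales_stride_rowwise / scales_stride_colwise).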
const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output("output", { rows, cols }, otype); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_num_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_num_colwise); + + constexpr bool rowwise = true; + constexpr bool colwise = true; + fill_tensor_data(input, ref_scales_rowwise.get(), ref_scales_colwise.get(), + rowwise, colwise, rows, cols, blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol); +} + +std::vector> tensor_dims = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + // {2048, 12288}, + // {65536, 128}, + // {16384, 1632}, + // {16384, 6144}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + // {32, 32}, +}; + +} // namespace + +class DequantizeMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + bool>> {}; + +TEST_P(DequantizeMXFP8TestSuite, TestDequantizeMXFP8) +{ + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto tensor_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const bool quantize_then_dequantize = std::get<4>(GetParam()); + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + + // Skip tests for dequantization along both dimensions + if (rowwise && colwise) { + GTEST_SKIP(); + } + + // Skip cases with invalid alignment + if (rowwise && tensor_size.second % 32 != 0) { + GTEST_SKIP(); + } + if (colwise && tensor_size.first % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + if (quantize_then_dequantize) { + // Mind the order of the Output/Input template parameters + performTest_quantize_then_dequantize( + tensor_size.first, tensor_size.second, rowwise, colwise); + } else { + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1(tensor_size.first, tensor_size.second, + rowwise, colwise); + } else { + performTest_x2(tensor_size.first, tensor_size.second, + block_size.first, block_size.second); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeMXFP8TestSuite, + ::testing::Combine( + 
::testing::ValuesIn(tensor_dims), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(false)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + (std::get<4>(info.param) ? "QD" : "D"); + return name; + } +); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index e9f420e5b1..f07138caca 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -69,7 +69,7 @@ void performTest() { const size_t num_tensors = tensor_dims.size(); // Buffers for Transformer Engine implementation - std::vector input_list, output_c_list, output_t_list; + std::vector input_list, output_list; // Buffers for reference implementation std::vector> ref_input_list; @@ -81,25 +81,23 @@ void performTest() { for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_c_list.emplace_back(Tensor({ height, width }, otype)); - output_t_list.emplace_back(Tensor({ width, height }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), + { height, width }, otype, true, true)); auto& input = input_list.back(); - auto& output_c = output_c_list.back(); - auto& output_t = output_t_list.back(); + auto& output = output_list.back(); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); ref_input_list.emplace_back(height*width); ref_output_c_list.emplace_back(height*width); ref_output_t_list.emplace_back(width*height); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); - ref_scale_list[tensor_id] = output_c.scale(); + ref_scale_list[tensor_id] = output.scale(); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; } @@ -115,8 +113,7 @@ void performTest() { }; nvte_multi_cast_transpose(num_tensors, make_nvte_vector(input_list).data(), - make_nvte_vector(output_c_list).data(), - make_nvte_vector(output_t_list).data(), + make_nvte_vector(output_list).data(), 0); // Reference implementation @@ -136,23 +133,23 @@ void performTest() { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", - output_c_list[tensor_id].amax(), + output_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); compareResults("scale_inv", - output_c_list[tensor_id].scale_inv(), - 1.f / output_c_list[tensor_id].scale(), + output_list[tensor_id].rowwise_scale_inv(), + 1.f / output_list[tensor_id].scale(), atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", - output_c_list[tensor_id], + output_list[tensor_id], ref_output_c_list[tensor_id].data(), - atol, rtol); + true, 
atol, rtol); compareResults("output_t", - output_t_list[tensor_id], + output_list[tensor_id], ref_output_t_list[tensor_id].data(), - atol, rtol); + false, atol, rtol); } } diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu index 23c824e857..b8475fe561 100644 --- a/tests/cpp/operator/test_multi_padding.cu +++ b/tests/cpp/operator/test_multi_padding.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,8 @@ void performTest() { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; const size_t padded_height = (height + align - 1) / align * align; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_list.emplace_back(Tensor({ padded_height, width }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype)); auto& input = input_list.back(); auto& output = output_list.back(); @@ -95,8 +96,8 @@ void performTest() { ref_input_list.emplace_back(height*width); ref_output_list.emplace_back(padded_height*width); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; @@ -134,6 +135,7 @@ void performTest() { compareResults("output", output_list[tensor_id], ref_output_list[tensor_id].data(), + true, atol, rtol); } } diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 58152864eb..0004c2ce74 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -176,6 +175,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; return; } + + if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { + GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; + } + using WeightType = InputType; DType itype = TypeInfo::dtype; DType wtype = TypeInfo::dtype; @@ -187,16 +191,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, return; } - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor dz("dz", { N, H }, wtype); + Tensor dx("dx", { N, H }, itype); + Tensor dgamma("dgamma", { H }, wtype); + Tensor dbeta("dbeta", { H }, wtype); Tensor workspace_fwd, workspace_bwd; fillUniform(&input); @@ -226,7 +230,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = 
Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -236,7 +240,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -246,7 +250,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -255,7 +259,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), workspace_bwd.data(), @@ -272,23 +276,24 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, mu.to_cpu(); rsigma.to_cpu(); float ref_amax; - compute_ref_stats(norm_type, input.cpu_dptr(), ref_mu.get(), + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), ref_rsigma.get(), N, H, epsilon); float ref_scale = isFp8Type(otype) ? 
z.scale() : 1.f; - compute_ref_output(norm_type, input.cpu_dptr(), - gamma.cpu_dptr(), - beta.cpu_dptr(), + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), ref_output.get(), - mu.cpu_dptr(), - rsigma.cpu_dptr(), + mu.rowwise_cpu_dptr(), + rsigma.rowwise_cpu_dptr(), N, H, &ref_amax, ref_scale, zero_centered_gamma, use_cudnn); - compute_ref_backward(norm_type, dz.cpu_dptr(), input.cpu_dptr(), - mu.cpu_dptr(), rsigma.cpu_dptr(), - gamma.cpu_dptr(), + compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), N, H, zero_centered_gamma, use_cudnn); @@ -301,25 +306,25 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); rtol_stats = 5e-5; - compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); auto [atol, rtol] = getTolerances(otype); if (otype == DType::kFloat32) { atol = 5e-7; } - compareResults("output", z, ref_output.get(), atol, rtol); + compareResults("output", z, ref_output.get(), true, atol, rtol); double atol_bwd = 5e-4; double rtol_bwd = 5e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); - compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); + compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd); + compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd); + compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd); } std::vector> test_cases = { @@ -357,24 +362,24 @@ TEST_P(NormTestSuite, TestNorm) { } INSTANTIATE_TEST_SUITE_P( - OperatorTest, - NormTestSuite, - ::testing::Combine( - ::testing::Values(false), //TODO: enabling tests for cudnn backend - ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo& info) { + OperatorTest, + NormTestSuite, + ::testing::Combine( + ::testing::Values(true, false), + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { auto backend = std::get<0>(info.param) == false ? 
"Te" : "Cudnn"; -std::string name = - backend + - normToString.at(std::get<1>(info.param)) + "_" + - test::typeName(std::get<2>(info.param)) + "X" + - test::typeName(std::get<3>(info.param)) + "X" + - std::to_string(std::get<4>(info.param).first) + "X" + - std::to_string(std::get<4>(info.param).second) + "X" + - std::to_string(std::get<5>(info.param)); - return name; - }); + std::string name = + backend + + normToString.at(std::get<1>(info.param)) + "_" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + std::to_string(std::get<4>(info.param).first) + "X" + + std::to_string(std::get<4>(info.param).second) + "X" + + std::to_string(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu new file mode 100644 index 0000000000..d1bdb6203b --- /dev/null +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -0,0 +1,337 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +using fp8e8m0 = byte; + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RMSNorm"} +}; + +template +void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, + size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ + + const size_t block_size_Y = scaling_mode_x; // mind the mapping Y <-- x + const size_t block_size_X = scaling_mode_y; // and X <-- y + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; + const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X; + + #pragma omp parallel for proc_bind(spread) schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X; + + // TODO: padded SFs i.e. 
(4,128) + const float scale_inv = exp2f(static_cast(scale_ptr[mx_scale_idx]) - FP32_EXPONENT_BIAS); + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elem = static_cast(input_ptr[idx]); + output_ptr[idx] = static_cast(elem * scale_inv); + } + } + } + } + } +} + +template +void dequantize_2x(Tensor& input, Tensor& output, bool is_training) +{ + input.to_cpu(); + auto scaling_mode = input.scaling_mode(); + assert(input.rowwise_shape().ndim == 2); + assert(input.columnwise_shape().ndim == 2); + + dequantize_1x_kernel(input.rowwise_cpu_dptr(), + input.rowwise_cpu_scale_inv_ptr(), + output.rowwise_cpu_dptr(), + input.rowwise_shape().data[0], input.rowwise_shape().data[1], + 1, 32); + if (is_training) + dequantize_1x_kernel(input.columnwise_cpu_dptr(), + input.columnwise_cpu_scale_inv_ptr(), + output.columnwise_cpu_dptr(), + input.columnwise_shape().data[0], input.columnwise_shape().data[1], + 32, 1); +} + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + compute_t m; + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + OutputType* output, + const bool zero_centered_gamma){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = static_cast(gamma[j]); + if (zero_centered_gamma) { + g += 1.0; + } + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = tmp; + } + } +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor workspace; + + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == NormType::LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + workspace = Tensor("workspace", 
workspace.rowwise_shape(), workspace.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + } + + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); + + dequantize_2x(z, dequantized_output, is_training); + + // Reference implementations + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_output = std::make_unique(N * H); + + + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + // use the GPU stats to tighten the tolerances + float *ref_mu_ptr, *ref_rsigma_ptr; + if (is_training){ + mu.to_cpu(); + rsigma.to_cpu(); + ref_mu_ptr = mu.rowwise_cpu_dptr(); + ref_rsigma_ptr = rsigma.rowwise_cpu_dptr(); + } else { + ref_mu_ptr = ref_mu.get(); + ref_rsigma_ptr = ref_rsigma.get(); + } + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), + ref_mu_ptr, + ref_rsigma_ptr, + N, H, + ref_output.get(), + zero_centered_gamma); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + if (is_training){ + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); + } + + float atol, rtol; + if (otype == DType::kFloat8E5M2){ + atol = 1.25e-1; + rtol = 1.25e-1; + } else if (otype == DType::kFloat8E4M3){ + if (itype == DType::kBFloat16){ + atol = 7e-2; + rtol = 7e-2; + } else { + atol = 6.25e-2; + rtol = 6.25e-2; + } + } + compareResults("output_rowwise", dequantized_output, ref_output.get(), true, atol, rtol, false); + if (is_training) + compareResults("output_colwise", dequantized_output, ref_output.get(), false, atol, rtol, false); +} + +std::vector> test_cases = { + {32, 32}, + {768, 2304}, + {2048, 12288}, +}; + +std::vector norms = { + NormType::LayerNorm, + NormType::RMSNorm +}; + +} // namespace + +class MxNormTestSuite : public ::testing::TestWithParam< std::tuple, + bool, bool>> {}; + +TEST_P(MxNormTestSuite, TestMxNorm) { + using namespace transformer_engine; + using namespace test; + + const NormType norm_type = std::get<0>(GetParam()); + const DType input_type = std::get<1>(GetParam()); + const DType output_type = std::get<2>(GetParam()); + const auto size = std::get<3>(GetParam()); + const bool zero_centered_gamma = std::get<4>(GetParam()); + const bool is_training = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxNormTestSuite, + ::testing::Combine( + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + 
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(true, false), + ::testing::Values(true, false)), + [](const testing::TestParamInfo& info) { + std::string name = normToString.at(std::get<0>(info.param)) + "_" + + test::typeName(std::get<1>(info.param)) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + std::to_string(std::get<3>(info.param).first) + "X" + + std::to_string(std::get<3>(info.param).second) + "X" + + std::to_string(std::get<4>(info.param)) + "out" + + std::to_string(int(std::get<5>(info.param)) + 1) + "x"; + return name; + }); diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index 76f049360a..3c12cef865 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -58,18 +58,18 @@ void performTestQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); setRandomScale(&output); - nvte_fp8_quantize(input.data(), output.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref_q(input.cpu_dptr(), ref_output.get(), + compute_ref_q(input.rowwise_cpu_dptr(), ref_output.get(), N, &ref_amax, output.scale()); cudaDeviceSynchronize(); @@ -79,7 +79,7 @@ void performTestQ(const size_t N) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); auto [atol, rtol] = getTolerances(otype); - compareResults("output_q", output, ref_output.get(), atol, rtol); + compareResults("output_q", output, ref_output.get(), true, atol, rtol); } template @@ -89,24 +89,24 @@ void performTestDQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); - nvte_fp8_dequantize(input.data(), output.data(), 0); + nvte_dequantize(input.data(), output.data(), 0); - compute_ref_dq(input.cpu_dptr(), ref_output.get(), - N, input.scale_inv()); + compute_ref_dq(input.rowwise_cpu_dptr(), ref_output.get(), + N, input.rowwise_scale_inv()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); auto [atol, rtol] = getTolerances(otype); - compareResults("output_dq", output, ref_output.get(), atol, rtol); + compareResults("output_dq", output, ref_output.get(), true, atol, rtol); } std::vector qdq_test_cases = {2048* 12288, diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu new file mode 100644 index 0000000000..f6e0da057a --- /dev/null +++ b/tests/cpp/operator/test_swizzle.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; + +constexpr int MAT_TILE_DIM_M = 128; +constexpr int MAT_TILE_DIM_K = 128; + +template +void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_output_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int out_index = tile_output_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr(row_scaling) + h_output[out_index] = h_input[k + m * K]; + else + h_output[out_index] = h_input[k * M + m]; + } + } +} + +void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? 
      std::vector<size_t>{M, K} : std::vector<size_t>{K, M};
+
+  const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] / SF_MODE_Y};
+
+  std::vector<int> scaling_mode = {SF_MODE_X, SF_MODE_Y, 0};
+  Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
+  Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
+
+  fillUniform(&input);
+
+  std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(scale_shape[0] * scale_shape[1]);
+
+  nvte_swizzle_scaling_factors(input.data(), output.data(), 0);
+
+  if (rowwise)
+    compute_ref_swizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0], scale_shape[1]);
+  else
+    compute_ref_swizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[1], scale_shape[0]);
+
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  output.to_cpu();
+  if (rowwise) {
+    compareResults("output_swizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
+  } else {
+    compareResults("output_swizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
+  }
+}
+
+class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};
+
+
+TEST_P(SwizzleTestSuite, TestSwizzle) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  const auto num_tiles = std::get<0>(GetParam());
+  const auto scaling_mode = std::get<1>(GetParam());
+  const auto transa = std::get<2>(GetParam());
+
+  performTestSwizzle1D(num_tiles.first, num_tiles.second,
+                       scaling_mode.first, scaling_mode.second,
+                       transa);
+}
+
+namespace {
+
+std::vector<std::pair<int, int>> num_tiles = {
+  {1, 1},
+  {1, 132},
+  {132, 1},
+  {65, 256},
+  {65, 257},
+  {65, 258},
+  {65, 259},
+};
+
+std::vector<std::pair<bool, bool>> scaling_mode = {
+  {true, false},
+  {false, true}
+};
+
+std::vector<bool> transa = {true, false};
+
+}  // namespace
+
+INSTANTIATE_TEST_SUITE_P(
+  OperatorTest,
+  SwizzleTestSuite,
+  ::testing::Combine(
+    ::testing::ValuesIn(num_tiles),
+    ::testing::ValuesIn(scaling_mode),
+    ::testing::ValuesIn(transa)
+  ),
+  [](const testing::TestParamInfo<SwizzleTestSuite::ParamType>& info) {
+    std::string name = "ntiles" +
+                       std::to_string(std::get<0>(info.param).first) + "X" +
+                       std::to_string(std::get<0>(info.param).second) + "smode" +
+                       std::to_string(std::get<1>(info.param).first) + "X" +
+                       std::to_string(std::get<1>(info.param).second) + "trans" +
+                       std::to_string(std::get<2>(info.param));
+    return name;
+  });
diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu
index 0852ddf7c3..00dd241c92 100644
--- a/tests/cpp/operator/test_transpose.cu
+++ b/tests/cpp/operator/test_transpose.cu
@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
   DType dtype = TypeInfo::dtype;
-  Tensor input({ N, H }, dtype);
-  Tensor output({ H, N }, dtype);
+  Tensor input("input", { N, H }, dtype);
+  Tensor output("output", { H, N }, dtype);
   std::unique_ptr ref_output = std::make_unique(N * H);
@@ -46,13 +46,13 @@ void performTest(const size_t N, const size_t H) {
   nvte_transpose(input.data(), output.data(), 0);
-  compute_ref(input.cpu_dptr(), ref_output.get(), N, H);
+  compute_ref(input.rowwise_cpu_dptr(), ref_output.get(), N, H);
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
   ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
   auto [atol, rtol] = getTolerances(dtype);
-  compareResults("output", output, ref_output.get(), atol, rtol);
+  compareResults("output", output,
ref_output.get(), true, atol, rtol); } std::vector> test_cases = {{2048, 12288}, diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 84cc11673b..ec4a9bdbb7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,14 +10,24 @@ #include #include #include +#include +#include +#include #include +#include #include #include "util/logging.h" namespace test { +size_t create_seed_from_tensor_name(const std::string& tensor_name) { + auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + + "/" + tensor_name; + return std::hash{}(full_name); +} + std::vector all_fp_types = {DType::kFloat32, DType::kFloat16, DType::kBFloat16, @@ -50,102 +60,379 @@ const std::string &typeName(DType type) { {DType::kFloat16, "float16"}, {DType::kBFloat16, "bfloat16"}, {DType::kFloat8E4M3, "float8e4m3"}, - {DType::kFloat8E5M2, "float8e5m2"}}; + {DType::kFloat8E5M2, "float8e5m2"}, + {DType::kFloat8E8M0, "float8e8m0"}}; return name_map.at(type); } -size_t product(const NVTEShape &shape) { +const std::string& caseName(InputsFillCase type) { + static const std::unordered_map name_map = { + {InputsFillCase::uniform, "uniform"}, + {InputsFillCase::zeros, "zeros"}, + {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"}, + {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"}, + {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}}; + return name_map.at(type); +} + +size_t product(const NVTEShape &shape, size_t begin, size_t end) { size_t ret = 1; - for (size_t i = 0; i < shape.ndim; ++i) { + NVTE_CHECK(end <= shape.ndim); + for (size_t i = begin; i < end; ++i) { ret *= shape.data[i]; } return ret; } +size_t product(const NVTEShape &shape) { + return product(shape, 0, shape.ndim); +} +size_t product(const std::vector shape, size_t begin, size_t end) { + size_t ret = 1; + NVTE_CHECK(end <= shape.size()); + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} -Tensor::Tensor(const NVTEShape &shape, const DType type) { - size_t s = typeToSize(type); - size_t total_size = product(shape) * s; - void *dptr = nullptr; - cpu_data_ = nullptr; - amax_cpu_data_ = nullptr; - scale_cpu_data_ = nullptr; - scale_inv_cpu_data_ = nullptr; - float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr; - if (total_size != 0) { - cudaMalloc((void**)&dptr, total_size); // NOLINT(*) - cudaMemset(dptr, 0, total_size); - cpu_data_ = std::make_unique(total_size); - for (size_t i = 0; i < total_size; ++i) { - cpu_data_[i] = 0; - } +size_t product(const std::vector& shape) { + return product(shape, 0, shape.size()); +} + +size_t DIVUP(const size_t &x, const size_t &y){ + return (((x) + ((y)-1)) / (y)); +} + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +struct scale_inv_meta { + std::vector shape; + DType type; + size_t type_size; +}; + +NVTEShape convertShape(const std::vector& shape) { + return {shape.data(), shape.size()}; +} + +std::pair get_scales(const NVTEShape& shape, + const NVTEScalingMode scaling_mode) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + scale_inv_meta ret; + ret.shape = {1}; + ret.type = DType::kFloat32; + ret.type_size = sizeof(float); + return {ret, ret}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, 
ret_colwise; + + auto block_alignment = std::vector{128ul,4ul}; + { + auto alignment = block_alignment[0]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(1)), + alignment) * alignment; + alignment = block_alignment[1]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(32)), + alignment) * alignment; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto alignment = block_alignment[1]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(32)), + alignment) * alignment; + alignment = block_alignment[0]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(1)), + alignment) * alignment; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; } - if (isFp8Type(type)) { + ret_rowwise.type = DType::kFloat8E8M0; + ret_colwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size = sizeof(uint8_t); + ret_colwise.type_size = sizeof(uint8_t); + + return {ret_rowwise, ret_colwise}; + } + + NVTE_ERROR("Invalid scaling mode!"); +} + +Tensor::Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise, const bool columnwise, + const NVTEScalingMode &scaling_mode) { + name_ = name; + const size_t seed = create_seed_from_tensor_name(name); + gen_.seed(seed); + rowwise_ = rowwise; + columnwise_ = columnwise; + size_t s = typeToSize(type); + size_t total_size = product(shape) * s; + void *dptr_rowwise = nullptr; + void *dptr_columnwise = nullptr; + cpu_data_rowwise_ = nullptr; + cpu_data_columnwise_ = nullptr; + amax_cpu_data_ = nullptr; + scale_cpu_data_ = nullptr; + rowwise_scale_inv_cpu_data_ = nullptr; + columnwise_scale_inv_cpu_data_ = nullptr; + float *amax = nullptr, *scale = nullptr; + float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr; + if (columnwise) { + NVTE_CHECK(shape.ndim >= 2); + } + std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), + shape.data[shape.ndim - 1]}; + NVTEShape normalized_shape = convertShape(normalized_shape_v); + + std::vector columnwise_shape_vec; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + // Transpose when tensor scaling + columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; i < shape.ndim - 1; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } else { + // Same shape for MX + for (size_t i = 0; i < shape.ndim; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } + const NVTEShape columnwise_shape{columnwise_shape_vec.data(), columnwise_shape_vec.size()}; + + tensor_ = TensorWrapper(scaling_mode); + + if (total_size != 0) { + if (rowwise) { + cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) + cudaMemset(dptr_rowwise, 0, total_size); + cpu_data_rowwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_rowwise_.get(), total_size, 0); + } + if (columnwise) { + cudaMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*) + cudaMemset(dptr_columnwise, 0, total_size); + cpu_data_columnwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_columnwise_.get(), total_size, 0); + } + } + tensor_.set_rowwise_data(dptr_rowwise, type, shape); + tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); + + if (isFp8Type(type)) { + if (is_tensor_scaling(scaling_mode)) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) cudaMemset(scale, 0, sizeof(float)); - cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*) - cudaMemset(scale_inv, 0, sizeof(float)); - amax_cpu_data_ = 
std::make_shared(); - *amax_cpu_data_ = 0; - scale_cpu_data_ = std::make_shared(); - *scale_cpu_data_ = 0; - scale_inv_cpu_data_ = std::make_shared(); - *scale_inv_cpu_data_ = 0; + amax_cpu_data_ = std::make_shared(0); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) + if (rowwise) { + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + if (columnwise) { + tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + } else { + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, + tensor_.scaling_mode()); + auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_shape = rowwise_scale_meta.shape; + auto columnwise_scale_shape = colwise_scale_meta.shape; + if (rowwise) { + cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); + rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + } + if (columnwise) { + cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) + cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); + columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + } } - tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv); + } } void Tensor::to_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost); + if (rowwise_) { + cudaMemcpy(cpu_data_rowwise_.get(), + tensor_.get_rowwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + cudaMemcpy(cpu_data_columnwise_.get(), + tensor_.get_columnwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } if (isFp8Type(dtype())) { - cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float), - cudaMemcpyDeviceToHost); + if (is_tensor_scaling(tensor_.scaling_mode())) { + cudaMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + cudaMemcpyDeviceToHost); + cudaMemcpy(scale_cpu_data_.get(), + tensor_.scale(), + sizeof(float), + cudaMemcpyDeviceToHost); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), + 
tensor_.get_rowwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), + tensor_.get_columnwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } } } void Tensor::from_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice); + if (rowwise_) { + cudaMemcpy(tensor_.get_rowwise_data().data_ptr, + cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); + } + if (columnwise_) { + cudaMemcpy(tensor_.get_columnwise_data().data_ptr, + cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); + } if (isFp8Type(dtype())) { - cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + if (is_tensor_scaling(tensor_.scaling_mode())) { + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, + rowwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, + columnwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } } } void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - *scale_cpu_data_ = scale; - from_cpu(); + if (is_tensor_scaling(tensor_.scaling_mode())) { + *scale_cpu_data_ = scale; + from_cpu(); + } } } void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype())) { - NVTE_CHECK(scale_inv_cpu_data_); - *scale_inv_cpu_data_ = scale_inv; + if (rowwise_) { + NVTE_CHECK(rowwise_scale_inv_cpu_data_); + } + if (columnwise_) { + NVTE_CHECK(columnwise_scale_inv_cpu_data_); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + if (rowwise_) { + auto num_scales = product(rowwise_scale_meta.shape); + if (num_scales == 1){ + rowwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } + if (columnwise_) { + auto num_scales = product(colwise_scale_meta.shape); + if (num_scales == 1){ + columnwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } from_cpu(); } } void Tensor::shareFP8Meta(const Tensor &other) { if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { - tensor_ = TensorWrapper(dptr(), shape(), dtype(), - other.tensor_.amax(), - other.tensor_.scale(), - other.tensor_.scale_inv()); + auto new_tensor = 
TensorWrapper(other.tensor_.scaling_mode()); + auto my_rowwise_data = tensor_.get_rowwise_data(); + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, + static_cast(my_rowwise_data.dtype), + my_rowwise_data.shape); + auto my_columnwise_data = tensor_.get_columnwise_data(); + new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, + static_cast(my_columnwise_data.dtype), + my_columnwise_data.shape); + auto other_amax = other.tensor_.get_amax(); + new_tensor.set_amax(other_amax.data_ptr, + static_cast(other_amax.dtype), + other_amax.shape); + auto other_scale = other.tensor_.get_scale(); + new_tensor.set_scale(other_scale.data_ptr, + static_cast(other_scale.dtype), + other_scale.shape); + auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); + new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, + static_cast(other_row_scale_inv.dtype), + other_row_scale_inv.shape); + auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv(); + new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr, + static_cast(other_col_scale_inv.dtype), + other_col_scale_inv.shape); + tensor_ = std::move(new_tensor); to_cpu(); } } @@ -177,12 +464,14 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { return ret; } -void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol, double rtol) { - test.to_cpu(); - const size_t N = product(test.shape()); +void compareResults_sequential(const std::string &name, const Tensor &test, + const void *ref, const bool rowwise, + double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, - const T *test_data = test.cpu_dptr(); + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); for (size_t i = 0; i < N; ++i) { double t = static_cast(test_data[i]); @@ -200,14 +489,84 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref const double cast_mean_m = static_cast(static_cast(mean_m)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } - ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl - << "Mismatch at place " << to_string(unravel(i, test.shape())) + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) << " (" << std::to_string(i) << "): " << t << " vs " << r; + } + ); +} + +template +static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, + const size_t N, const double atol, const double rtol) { + int first_mismatch_idx = N; + + bool is_mismatch_found = false; + #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ + reduction(min: first_mismatch_idx) proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + if (is_mismatch_found) { // early escape of the omp thread + continue; + } + + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion && i < first_mismatch_idx) { + first_mismatch_idx = i; + is_mismatch_found = true; + } + } + return first_mismatch_idx; +} + +void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); + const T *ref_data = reinterpret_cast(ref); + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); + if (i != N) { + const double t = static_cast(test_data[i]); + const double r = static_cast(ref_data[i]); + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(true) << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl
+                       << "Mismatch at place " << to_string(unravel(i, shape))
+                       << " (" << std::to_string(i) << "): " << t << " vs " << r;
+    }
+  );
+}
+
+void compareResults(const std::string &name, const Tensor &test, const void *ref,
+                    const bool rowwise, double atol, double rtol, bool if_on_gpus) {
+  constexpr bool sequential = false;
+  if constexpr (sequential) {
+    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus);
+  } else {
+    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus);
+  }
+}
+
 void compareResults(const std::string &name, const float test, const float ref,
                     double atol, double rtol) {
   double t = static_cast<double>(test);
@@ -218,6 +577,51 @@ void compareResults(const std::string &name, const float test, const float ref,
 }
+
+
+void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                    size_t N, float mismatch_rate_tol) {
+  size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
+  size_t n_mismatches = 0;
+  std::vector<size_t> mismatch_indices;
+  for (int i = 0; i < N; i++) {
+    bool mismatch = test[i] != ref[i];
+    if (mismatch) {
+      n_mismatches++;
+      mismatch_indices.push_back(i);
+    }
+    if (n_mismatches > max_mismatches) {
+      std::cout << "Error in " << name << std::endl;
+      for (auto &index : mismatch_indices)
+        std::cout << "Mismatch at (" << index << "): " << static_cast<int>(test[index]) << " vs "
+                  << static_cast<int>(ref[index]) << std::endl;
+      GTEST_FAIL() << n_mismatches << " mismatch(es), which is more than the mismatch tolerance.";
+    }
+  }
+}
+
+void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride)
+{
+  for (int i = 0; i < row_blocks; ++i) {
+    for (int j = 0; j < col_blocks; ++j) {
+      const int idx = i * stride + j;
+      ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
+                                          << "Mismatch: " << static_cast<int>(test[idx]) << " vs "
+                                          << static_cast<int>(ref[idx]) << " at index " << idx;
+    }
+  }
+}
+
+void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                                  const size_t N)
+{
+  for (int i = 0; i < N; i++) {
+    ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
+                                    << "Mismatch: " << static_cast<int>(test[i]) << " vs "
+                                    << static_cast<int>(ref[i]) << " at index " << i;
+  }
+}
+
 std::pair<double, double> getTolerances(const DType type) {
   switch(type) {
     case DType::kFloat32:
@@ -228,6 +632,7 @@ std::pair<double, double> getTolerances(const DType type) {
       return {1e-5, 1e-2};
     case DType::kFloat8E4M3:
     case DType::kFloat8E5M2:
+    case DType::kFloat8E8M0:
       return {1e-2, 1e-2};
     default:
       NVTE_CHECK("Invalid type!");
@@ -235,29 +640,158 @@ std::pair<double, double> getTolerances(const DType type) {
   return {0, 0};
 }
+template <typename T>
+void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
+  #pragma omp parallel proc_bind(spread)
+  {
+    std::mt19937 gen_local = *gen;
+    gen_local.discard(omp_get_thread_num() * 599);
+    std::uniform_real_distribution<> dis(-2.0, 1.0);
+    #pragma omp for schedule(static)
+    for (size_t i = 0; i < size; ++i) {
+      data[i] = static_cast<T>(dis(gen_local));
+    }
+  }
+  gen->discard(size);
+}
+
 void fillUniform(Tensor *t) {
-  const size_t size = product(t->shape());
-  static std::mt19937 gen(12345);
+  if (t->rowwise()) {
+    const size_t size = product(t->rowwise_shape());
+    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
+      {
+        T *data = t->rowwise_cpu_dptr<T>();
+        generate_data_uniformly(data, size, &(t->gen()));
+      }
+    );
+  } else {
+    const size_t size = product(t->columnwise_shape());
+
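+    // Columnwise-only tensors reach this branch: only the columnwise host buffer
+    // exists, so that copy is the one being filled.
+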
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { + T *data = t->columnwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); + } + ); + } std::uniform_real_distribution<> dis(-2.0, 1.0); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { - T *data = t->cpu_dptr(); + t->set_scale_inv(dis(t->gen())); + t->from_cpu(); +} + +template +void fillCase_special(Tensor *t) { + const size_t size = product(t->rowwise_shape()); + const size_t rows = t->rowwise_shape().data[0]; + const size_t cols = t->rowwise_shape().data[1]; + + if constexpr (Case == InputsFillCase::zeros) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); for (size_t i = 0; i < size; ++i) { - data[i] = T(dis(gen)); + data[i] = static_cast(0); } - }); - t->set_scale_inv(dis(gen)); + }); + } else { + double minAbs = -2.0; + double maxAbs = 1.0; + if constexpr (Case != InputsFillCase::uniform) { + minAbs = Quantized_Limits::ranges[Case]; + maxAbs = Quantized_Limits::ranges[Case + 1]; + } + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } + }); + } + t->set_scale_inv(1.0); t->from_cpu(); } +template +void fillCase(Tensor *t, const InputsFillCase fill_case) { + switch (fill_case) { + case InputsFillCase::uniform: + fillCase_special(t); break; + case InputsFillCase::zeros: + fillCase_special(t); break; + case InputsFillCase::zero_to_minNorm: + fillCase_special(t); break; + case InputsFillCase::minNorm_to_maxNorm: + fillCase_special(t); break; + case InputsFillCase::maxNorm_to_inf: + fillCase_special(t); break; + } +} + +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t) { - static std::mt19937 gen(12345); std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale = dis(gen); + const float scale = dis(t->gen()); t->set_scale(scale); } +void setRandomScaleInv(Tensor *t) { + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float scale_inv = dis(t->gen()); + t->set_scale_inv(scale_inv); +} + bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; +} + +int32_t getDeviceComputeCapability() +{ + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; +} + +size_t first_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + if (shape.size() == 1) return 1; + return product(shape, 0, shape.size() - 1); +} + +size_t last_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + return shape[shape.size() - 1]; +} + +std::array get_scale_tensor_dims(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) { + const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + + const size_t alignment_Y = is_rowwise + ? 
scale_tensor_alignment_Y_rowwise + : scale_tensor_alignment_Y_colwise; + const size_t alignment_X = is_rowwise + ? scale_tensor_alignment_X_rowwise + : scale_tensor_alignment_X_colwise; + + const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); + + const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y); + const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X); + return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 4598a7b021..dc515ccb8e 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -6,9 +6,10 @@ #pragma once -#include #include #include +#include +#include #include #include @@ -52,6 +53,7 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using fp8e8m0 = uint8_t; template struct TypeInfo{ @@ -62,7 +64,8 @@ struct TypeInfo{ fp16, bf16, fp8e4m3, - fp8e5m2>; + fp8e5m2, + fp8e8m0>; template struct Helper { @@ -94,10 +97,19 @@ struct TypeInfo{ class Tensor { public: - Tensor(const NVTEShape &shape, const DType type); - - Tensor(const std::vector &shape, const DType type) : - Tensor(NVTEShape{shape.data(), shape.size()}, type) {} + Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + + Tensor(const std::string& name, + const std::vector &shape, + const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} @@ -108,30 +120,82 @@ class Tensor { Tensor& operator=(Tensor &&other) = default; ~Tensor() { - if (tensor_.dptr() != nullptr) { - cudaFree(tensor_.dptr()); + void *data_ptr = tensor_.dptr(); + void *scale_inv = tensor_.scale_inv(); + void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; + void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; + if (columnwise_data_ptr == data_ptr) { + columnwise_data_ptr = nullptr; + } + if (columnwise_scale_inv == scale_inv) { + columnwise_scale_inv = nullptr; + } + if (data_ptr != nullptr) { + cudaFree(data_ptr); + } + if (scale_inv != nullptr) { + cudaFree(scale_inv); + } + if (columnwise_data_ptr != nullptr){ + cudaFree(columnwise_data_ptr); + } + if (columnwise_scale_inv != nullptr){ + cudaFree(columnwise_scale_inv); } } + NVTETensor data() const noexcept { return tensor_.data(); } - const NVTEShape shape() const noexcept { - return tensor_.shape(); + NVTEShape rowwise_shape() const noexcept { + return tensor_.get_rowwise_data().shape; + } + + NVTEShape columnwise_shape() const noexcept { + return tensor_.get_columnwise_data().shape; + } + + NVTEShape rowwise_scale_inv_shape() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().shape; + } + + NVTEShape columnwise_scale_inv_shape() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().shape; + } + + NVTEScalingMode scaling_mode() const noexcept { + return tensor_.scaling_mode(); } DType dtype() const noexcept { return tensor_.dtype(); } - void *dptr() const noexcept { - return 
tensor_.dptr(); + void *rowwise_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_data().data_ptr; + } + + void *columnwise_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_data().data_ptr; + } + + template + T *rowwise_cpu_dptr() const { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return reinterpret_cast(cpu_data_rowwise_.get()); } template - T *cpu_dptr() const { + T *columnwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); - return reinterpret_cast(cpu_data_.get()); + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return reinterpret_cast(cpu_data_columnwise_.get()); } float amax() const { @@ -145,6 +209,7 @@ class Tensor { float scale() const { if(scale_cpu_data_) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -152,52 +217,246 @@ class Tensor { } } - float scale_inv() const { - if(scale_inv_cpu_data_) { - to_cpu(); - return *scale_inv_cpu_data_; + template + T *rowwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); + } + + template + T *columnwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); + } + + float rowwise_scale_inv(){ + if(rowwise_scale_inv_cpu_data_) { + float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; + return scale_inv; } else { return 1; } } + bool rowwise() const { + return rowwise_; + } + + bool columnwise() const { + return columnwise_; + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); void set_scale_inv(float scale_inv); void shareFP8Meta(const Tensor &other); + std::mt19937& gen() { return gen_; } + private: TensorWrapper tensor_; - std::unique_ptr cpu_data_; + std::unique_ptr cpu_data_rowwise_; + std::unique_ptr cpu_data_columnwise_; std::shared_ptr amax_cpu_data_; std::shared_ptr scale_cpu_data_; - std::shared_ptr scale_inv_cpu_data_; + std::unique_ptr rowwise_scale_inv_cpu_data_; + std::unique_ptr columnwise_scale_inv_cpu_data_; + bool rowwise_; + bool columnwise_; + std::string name_; + std::mt19937 gen_; +}; + +constexpr uint32_t FP32_EXPONENT_BIAS = 127; +constexpr uint32_t FP32_MANTISSA_BITS = 23; + +// [128,4] rowwise and [4,128] colwise alignment requirement +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + +inline size_t divide_round_up(const size_t N, const size_t M) { + return (N - 1 + M) / M; +} + +inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { + return divide_round_up(N, M) * M; +} + +template +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0; + static constexpr double maxSubnorm = 1.0; + static constexpr double minNorm = 1.0; + static 
constexpr double maxNorm = 1.0; + static constexpr double artifInf = 1.0; + static constexpr int maxBiasedExponent = 1; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); + static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double maxNorm = 448.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 8; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); + static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double maxNorm = 57344.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 15; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); + static constexpr double maxSubnorm = std::numeric_limits::min() + - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized + static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); + static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) + static constexpr double artifInf = std::numeric_limits::infinity(); + static constexpr int maxBiasedExponentAsFP32 = 255; + static constexpr int maxUnbiasedExponentAsFP32 = 128; +}; + +template +struct Quantized_Limits { + static constexpr double ranges[] = { + 0.0, + Numeric_Traits::minNorm, + Numeric_Traits::maxNorm, + Numeric_Traits::artifInf + }; + static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } + static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } + static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } + static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } + static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } + static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } +}; + +// Input data filling cases +// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats +// with nearest to even rounding per OFP8 specification +enum InputsFillCase { + zero_to_minNorm = 0, // [0, min_normal) + minNorm_to_maxNorm = 1, // [min_normal, max_normal) + maxNorm_to_inf = 2, // [max_normal, inf) + zeros = 3, // {0} + uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) }; +inline fp8e8m0 float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. 
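+  //
+  // E8M0 stores only an 8-bit biased power-of-two exponent (no sign bit, no mantissa),
+  // so this helper takes the FP32 exponent and rounds up whenever mantissa bits remain,
+  // with special handling at the extremes. For example, 1.0f (biased exponent 127,
+  // zero mantissa) encodes as 127 (0x7F), while 3.0f rounds up to the next power of
+  // two and encodes as 129, i.e. 2^2 = 4.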
+ if (std::isnan(val)) { + return 0xFF; + } + if (std::isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +} + +inline float exp2f_rcp(fp8e8m0 biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + +inline float identity(const float x) { return x; } +inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } +inline float dgelu(const float x) { + const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); + return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + + 0.5f * (1 + tanh_out); +} +inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } +inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } +inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } +inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } +inline float relu(const float x) { return fmaxf(0, x); } +inline float drelu(const float x) { return x > 0 ? 1 : 0; } +inline float silu(const float x) { return x * sigmoid(x); } +inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } +inline float srelu(const float x) { return x > 0 ? x * x : 0; } +inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } + size_t typeToSize(DType type); size_t product(const NVTEShape &shape); +size_t product(const std::vector &shape); + +size_t first_dimension(const std::vector &shape); +size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol = 1e-5, double rtol = 1e-8); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol = 0.); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t N); + +std::array get_scale_tensor_dims(const size_t rows, const size_t cols, + const size_t block_size_rows, const size_t block_size_cols); std::pair getTolerances(const DType type); void fillUniform(Tensor *t); + +template +void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t); +void setRandomScaleInv(Tensor *t); constexpr int THREADS_PER_WARP = 32; const std::string &typeName(DType type); +const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +int32_t getDeviceComputeCapability(); +constexpr int32_t blackwellComputeCapability = 100; + } // namespace test #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) 
\ @@ -254,3 +513,47 @@ bool isFp8Type(DType type); default: \ NVTE_ERROR("Invalid type."); \ } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E4M3: \ + { \ + using type = fp8e4m3; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E5M2: \ + { \ + using type = fp8e5m2; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: \ + { \ + using type = float; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat16: \ + { \ + using type = fp16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kBFloat16: \ + { \ + using type = bf16; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index ffa05f0d66..7540687089 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -8,8 +8,9 @@ add_executable(test_util ../test_common.cu) -target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_compile_options(test_util PRIVATE -O2) +find_package(OpenMP REQUIRED) +target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) +target_compile_options(test_util PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_util) +gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 920f9dc62e..d1558710c7 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -27,9 +27,6 @@ def enable_fused_attn_after_hopper(): """ if get_device_compute_capability(0) >= 90: os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" yield if "NVTE_FUSED_ATTN" in os.environ: del os.environ["NVTE_FUSED_ATTN"] - if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: - del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index e6ad8ce20c..a67335236d 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -4,14 +4,19 @@ """Test transformer_engine.jax.flax.TransformerLayer""" import os from functools import partial -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional import flax import jax import jax.numpy as jnp import pytest -from utils import assert_allclose, assert_tree_like_allclose, sync_params_values +from utils import ( + assert_allclose, + assert_tree_like_allclose, + dtype_tols, + sync_params_values, +) from utils import DecoderLayer as RefDecoderLayer from utils import EncoderLayer as RefEncoderLayer @@ -250,7 +255,13 @@ def _sync_params(self, ref, target): target = sync_params_values(target, ref, self.transformations) return ref, target - def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_forward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test only the forward""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) @@ -264,9 +275,16 @@ def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer) test_out = self._loss_fn(inputs, test_masks, 
test_params, test_others, test_layer) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) - def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_backward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test forward and backward through value_and_grad()""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) @@ -302,11 +320,12 @@ def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): inputs, test_masks, test_params, test_others, test_layer ) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) + assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols) _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads) - assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol) + assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, **tols) class EncoderRunner(BaseRunner): @@ -418,12 +437,12 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_format", FP8_FORMATS) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 9cb02bc555..554def2c3f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1387,18 +1387,26 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): def dtype_tols( dtype: Union[DType, TEDType, np.dtype], reference_value: float = 1.0, + rtol: Optional[float] = None, + atol: Optional[float] = None, ) -> Dict[str, float]: """Expected numerical tolerance for a data type. Args: dtype: data type. reference_value: reference value (default: 1). 
+ rtol: override for relative tolerance estimate + atol: override for absolute tolerance estimate Returns: Dictionary with "rtol" and "atol" as keys """ + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + # Convert to JAX dtype if needed if isinstance(dtype, TEDType): dtype = { @@ -1416,7 +1424,11 @@ def dtype_tols( # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): - return dict(rtol=0, atol=0) + if rtol is None: + rtol = 0.0 + if atol is None: + atol = 0.0 + return {"rtol": rtol, "atol": atol} # Estimate floating-point error finfo = jnp.finfo(dtype) @@ -1429,10 +1441,11 @@ def dtype_tols( spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min) ulp = max(spacing_high.item(), spacing_low.item()) - return dict( - rtol=eps_relaxed, - atol=max(ulp, eps_relaxed), - ) + if rtol is None: + rtol = eps_relaxed + if atol is None: + atol = max(ulp, eps_relaxed) + return {"rtol": rtol, "atol": atol} def sync_params_values(dst, src, transformations, sep="/"): diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py deleted file mode 100644 index f262f1a1d4..0000000000 --- a/tests/paddle/dist_launcher.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Helper functions to launch distributed tests""" - -import copy -import os -from pathlib import Path -import subprocess -import time -import unittest - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core -from paddle.distributed.utils.launch_utils import ( - TrainerProc, - find_free_ports, - get_cluster, - watch_local_trainers, -) - -__all__ = ["TestDistributed"] - - -def get_cluster_from_args(selected_gpus): - """Get node information from selected GPUs""" - cluster_node_ips = "127.0.0.1" - node_ip = "127.0.0.1" - - node_ips = [x.strip() for x in cluster_node_ips.split(",")] - - node_ips.index(node_ip) - - free_ports = None - - free_ports = find_free_ports(len(selected_gpus)) - if free_ports is not None: - free_ports = list(free_ports) - - trainer_endpoints = [] - for ip in node_ips: - trainer_endpoints.append([f"{ip}:{port}" for port in free_ports]) - return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) - - -def get_gpus(selected_gpus): - """Get selected GPU string""" - selected_gpus = [x.strip() for x in selected_gpus.split(",")] - return selected_gpus - - -def start_local_trainers( - cluster, - pod, - training_script, - training_script_args, - allocator_strategy="auto_growth", -): - """Launch trainers""" - current_env = copy.copy(os.environ.copy()) - # paddle broadcast ncclUniqueId use socket, and - # proxy maybe make trainers unreachable, so delete them. - # if we set them to "", grpc will log error message "bad uri" - # so just delete them. 
- current_env.pop("http_proxy", None) - current_env.pop("https_proxy", None) - - procs = [] - for t in pod.trainers: - proc_env = { - "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]), - "PADDLE_TRAINER_ID": f"{t.rank}", - "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", - "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}", - "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), - "PYTHONPATH": str(Path(__file__).resolve().parent), - } - - proc_env["FLAGS_allocator_strategy"] = allocator_strategy - if allocator_strategy == "auto_growth": - proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1" - - current_env.update(proc_env) - - print(f"trainer proc env:{current_env}") - - if os.getenv("WITH_COVERAGE", "OFF") == "ON": - cmd = "python -m coverage run --branch -p " + training_script - else: - cmd = "python -u " + training_script - - print(f"start trainer proc:{cmd} env:{proc_env}") - - fn = None - - proc = subprocess.Popen( - cmd.split(" ") + training_script_args, env=current_env - ) # pylint: disable=consider-using-with - - tp = TrainerProc() - tp.proc = proc - tp.rank = t.rank - tp.log_fn = fn - tp.cmd = cmd - - procs.append(tp) - - return procs - - -class TestDistributed(unittest.TestCase): - """Base class for distributed test""" - - @staticmethod - def run_2gpu( - target_file_name, - allocator_strategy="auto_growth", - ): - """Run target file in subprocesses""" - if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0: - return - - selected_gpus = get_gpus("0,1") - cluster = None - pod = None - - cluster, pod = get_cluster_from_args(selected_gpus) - - procs = start_local_trainers( - cluster, - pod, - allocator_strategy=allocator_strategy, - training_script=target_file_name, - training_script_args=[], - ) - - while True: - alive = watch_local_trainers(procs, cluster.trainers_endpoints()) - - if not alive: - print(f"Local procs complete, POD info:{pod}") - break - time.sleep(3) diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py deleted file mode 100644 index 3e0a6d2bac..0000000000 --- a/tests/paddle/parallel_tests/amax_reduction.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -def assert_allclose_across_ranks(tensor, group=None): - """Assert tensor is identical in all ranks""" - gathered_list = [] - paddle.distributed.all_gather(gathered_list, tensor, group=group) - assert len(gathered_list) > 1 - for gathered_tensor in gathered_list: - assert_allclose(tensor, gathered_tensor) - - -class TestAmaxReduction(unittest.TestCase): - """Tests Amax reduction""" - - def setUp(self): - self.data_parallel_size = 2 - self.init_dist_env() - self.global_dtype = "bfloat16" - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": self.data_parallel_size, - "mp_degree": 1, - "pp_degree": 1, - } - fleet.init(is_collective=True, strategy=strategy) - - def test_amax_reduction(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer1 = te.Linear(16, 16) - layer2 = te.Linear(16, 16) - model = paddle.nn.Sequential(layer1, layer2) - model = fleet.distributed_model(model) - - rank_id = paddle.distributed.get_rank() - set_random_seed(rank_id) - - optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters()) - optimizer = fleet.distributed_optimizer(optimizer) - - def train_one_step(layer, inp, optimizer): - inp = paddle.to_tensor(inp) - inp.stop_gradient = False - out = layer(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([16, 16], self.global_dtype) - with te.fp8_autocast(enabled=True): - train_one_step(model, inp, optimizer) - - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py deleted file mode 100644 index c0ffa288ee..0000000000 --- a/tests/paddle/parallel_tests/attention_tp.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestAttentionTp(unittest.TestCase): - """Tests MultiHeadAttention layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = ( - self.hidden_size, - self.num_heads, - ) - common_kwargs = { - "layernorm_epsilon": self.eps, - "attention_dropout": 0.0, - "attn_mask_type": self.mask_type, - "attention_type": "self", - "tp_group": self.tp_group, - "input_layernorm": True, - } - - layer_tp = te.MultiHeadAttention( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = 
paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: - total_weight = weight_src - elif partition_mode == "column": - total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True) - copy_weight(layer_tp, layer_single, "row", ["proj", "weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestAttentionTpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with model parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestAttentionSp(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestAttentionSpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 1e-1 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == 
"__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py deleted file mode 100644 index 21d08a8ef3..0000000000 --- a/tests/paddle/parallel_tests/group_sharding.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for group sharding""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( - DygraphShardingOptimizer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TestGroupSharding(unittest.TestCase): - """Tests group sharding""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def set_attr(self): - """Set test configs""" - self.sharding_degree = 2 - self.global_dtype = "float32" - self.rtol = 1e-5 - self.atol = 1e-5 - self.batch_size = 16 - self.in_channels = 16 - self.out_channels = 32 - self.fp8 = False - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": self.sharding_degree, - } - self.strategy = strategy - fleet.init(is_collective=True, strategy=strategy) - - def _get_model_and_optimizer(self, model, stage): - if stage == 1: - optimizer = DygraphShardingOptimizer( - paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()), - fleet.get_hybrid_communicate_group(), - ) - model = fleet.distributed_model(model) - optimizer = fleet.distributed_optimizer(optimizer) - elif stage in [2, 3]: - optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) - group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() - - class ShardingLevel: # pylint: disable=too-few-public-methods, - """Paddle sharding options""" - - kStage1 = "os" - kStage2 = "os_g" - kStage3 = "p_g_os" - - level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 - model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( - model=model, - optimizer=optimizer, - level=level, - group=group, - segment_size=256, - ) - else: - raise ValueError(f"Stage {stage} not supported") - return model, optimizer - - def test_group_sharding_stage1(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to 
hold 4 optimizer state entries." - - def test_group_sharding_stage2(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - # Check gradients are split to different trainers - if rank_id == 0: - assert model.bias.grad is None and model.weight.grad is not None - elif rank_id == 1: - assert model.weight.grad is None and model.bias.grad is not None - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to hold 4 optimizer state entries." - - def test_group_sharding_stage3(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - for name, value in optimizer_te.state_dict().items(): - if name.endswith("w_0_moment1_0"): - assert ( - value.numel() == self.in_channels * self.out_channels // self.sharding_degree - ), "Expect optimizer state to be sharded across trainers." - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py deleted file mode 100644 index 96070a03c5..0000000000 --- a/tests/paddle/parallel_tests/layernorm_linear_tp.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
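The `_train_one_step` helper in the LayerNormLinear test below (and in the other tensor-parallel tests) feeds each rank a slice of the input and afterwards all-gathers the input gradient so it can be checked against the unsplit single-GPU reference. A small sketch of just that slicing and regathering logic, assuming `rank` and `world_size` describe the model-parallel group; the helper names are illustrative:

    import paddle
    import paddle.distributed as dist

    def shard_input(inp, rank, world_size, mode):
        # "row" shards the batch/sequence dimension (used with sequence parallelism);
        # "column" shards the feature dimension (used with row-parallel layers).
        if mode == "row":
            n = inp.shape[0] // world_size
            return inp[n * rank : n * (rank + 1), :]
        if mode == "column":
            n = inp.shape[1] // world_size
            return inp[:, n * rank : n * (rank + 1)]
        return inp

    def unshard_grad(local_grad, tp_group, mode):
        # Reassemble the full input gradient for comparison with the reference run.
        parts = []
        dist.all_gather(parts, local_grad, group=tp_group)
        return paddle.concat(parts, axis=0 if mode == "row" else 1)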
-"""Unittest for LayerNormLinear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormLinearTp(unittest.TestCase): - """Tests LayerNormLinear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - total_out = mp_ops._c_concat(out, group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel LayerNormLinear""" - set_random_seed(1024) - layer_te = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, 
parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormLinearSp(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py deleted file mode 100644 index 9ec09c7e7a..0000000000 --- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
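In the LayerNormMLP test below, FC1 is column-parallel and FC2 is row-parallel, so with a 2-way model-parallel group the local parameters shrink along different axes. The shape checks here restate the test's own assertions, assuming a 2-way Fleet model-parallel group has already been initialized as in `init_dist_env`:

    import transformer_engine.paddle as te

    hidden, ffn, mp = 32, 64, 2   # sizes from the test's default config
    mlp = te.LayerNormMLP(hidden_size=hidden, ffn_hidden_size=ffn, set_parallel_mode=True)

    assert mlp.fc1_weight.shape == [ffn // mp, hidden]   # column parallel: output dim sharded
    assert mlp.fc1_bias.shape == [ffn // mp]
    assert mlp.fc2_weight.shape == [hidden, ffn // mp]   # row parallel: input dim sharded
    assert mlp.fc2_bias.shape == [hidden]                # row-parallel bias is replicated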
-"""Unittest for LayerNormMLP layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormMLPTp(unittest.TestCase): - """Tests LayerNormMLP layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - # Need to concat on the first dim, while _c_concat concats on the last dim - total_out = mp_ops._c_concat(out.T, group=self.tp_group).T - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_parallel_layer(self): - """Tests parallel LayerNormMLP""" - set_random_seed(1024) - layer_te = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=False, - backend="paddle", - ) - - def _get_total_weight(local_weight, tp_group, axis): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - # Get total weight - total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0) - total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1) - 
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) - layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) - - assert_shape( - layer_te.fc1_weight, - [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size], - ) - assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) - assert_shape( - layer_te.fc2_weight, - [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size], - ) - assert_shape(layer_te.fc2_bias, [self.hidden_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormMLPSp(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py deleted file mode 100644 index 68271e52e7..0000000000 --- a/tests/paddle/parallel_tests/linear_pp.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
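The pipeline-parallel test that follows subclasses `te.Linear` so that `is_first_microbatch=True` is passed only on the first micro-batch of each gradient-accumulation window, letting TE reuse the cached FP8 weight cast for the remaining micro-batches. A trimmed sketch of that wrapper (the class name here is illustrative):

    import paddle
    import transformer_engine.paddle as te

    class MicrobatchAwareLinear(te.Linear):
        """Pass is_first_microbatch so FP8 weight casts are reused within a step."""

        def __init__(self, *args, accumulate_steps=1, **kwargs):
            self.accumulate_steps = accumulate_steps
            self._micro_batch_id = 0
            super().__init__(*args, **kwargs)

        def forward(self, *args, **kwargs):
            # Only the first micro-batch of each accumulation window recomputes the
            # FP8 weight cast; the remaining micro-batches reuse the cached copy.
            kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
            if paddle.is_grad_enabled() and self.training:
                self._micro_batch_id += 1
            return super().forward(*args, **kwargs)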
-"""Unittest for Linear layer in pipeline parallel""" - -import unittest - -import numpy as np - -import paddle -from paddle.distributed import fleet - -from paddle.distributed.fleet.meta_parallel import ( - LayerDesc, - PipelineLayer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TELinear(te.Linear): - """To pass is_first_microbatch""" - - def __init__(self, *args, **kwargs): - assert "accumulate_steps" in kwargs - self.accumulate_steps = kwargs["accumulate_steps"] - del kwargs["accumulate_steps"] - self._micro_batch_id = 0 - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): - kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0 - if paddle.is_grad_enabled() and self.training: - self._micro_batch_id += 1 - return super().forward(*args, **kwargs) - - -class TEPipelineModel(PipelineLayer): - """Model for pipeline parallel test""" - - def __init__( - self, - in_features, - hidden_features, - weight_attrs, - use_te=True, - use_fp8=False, - accumulate_steps=1, - **kwargs, - ): - self.in_features = in_features - self.hidden_features = hidden_features - self.fp8 = use_fp8 - hcg = fleet.get_hybrid_communicate_group() - self.dp_group = hcg.get_data_parallel_group() - - Linear = TELinear if use_te else paddle.nn.Linear - extra_kwargs = {} - if use_te: - extra_kwargs["accumulate_steps"] = accumulate_steps - - model_desc = [ - LayerDesc( - Linear, - self.in_features, - self.hidden_features, - weight_attr=weight_attrs[0], - **extra_kwargs, - ), - LayerDesc( - Linear, - self.hidden_features, - self.in_features, - weight_attr=weight_attrs[1], - **extra_kwargs, - ), - ] - super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) - - def forward(self, *args, **kwargs): - with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group): - return super().forward(*args, **kwargs) - - -class StandaloneModel(paddle.nn.Layer): - """Model for pipeline parallel test""" - - def __init__(self, in_features, hidden_features, weight_attrs): - super().__init__() - self.in_features = in_features - self.hidden_features = hidden_features - Linear = paddle.nn.Linear - self.layer = paddle.nn.Sequential( - Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), - Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), - ) - self.loss = paddle.nn.CrossEntropyLoss() - - def forward(self, inp): - out = self.layer(inp[0]) - loss = self.loss(out, inp[1]) - return loss - - -class TestLinearPipelineParallel(unittest.TestCase): - """Tests Linear layer with pipeline parallel""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.pipeline_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": self.pipeline_parallel_size, - } - self.accumulate_steps = self.batch_size // self.micro_batch_size - strategy.pipeline_configs = { - "accumulate_steps": self.accumulate_steps, - "micro_batch_size": self.micro_batch_size, - } - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol = 
1e-5 - self.atol = 1e-5 - self.iter = 10 - self.fp8 = False - - def test_pipeline_train(self): - """Test pipeline parallel training""" - set_random_seed(1024) - np.random.seed(1024) - - weight1_np = np.random.normal(size=[self.in_features, self.hidden_features]) - weight2_np = np.random.normal(size=[self.hidden_features, self.in_features]) - weight_attrs = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)), - ] - weight_attrs_transposed = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)), - ] - - pipe_model = TEPipelineModel( - self.in_features, - self.hidden_features, - weight_attrs_transposed, - use_te=True, - use_fp8=self.fp8, - seg_method="layer:Linear", - num_stages=self.pipeline_parallel_size, - accumulate_steps=self.accumulate_steps, - ) - - # Check if model is split across ranks as expected - for name, sublayer in pipe_model.named_sublayers(): - if name in ("_loss_fn", "shared_layers"): - continue - if self.rank == 0: - assert tuple(sublayer.weight.shape) == weight1_np.T.shape, ( - f"Shape does not match, expect: {weight1_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - elif self.rank == 1: - assert tuple(sublayer.weight.shape) == weight2_np.T.shape, ( - f"Shape does not match, expect: {weight2_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - - standalone_model = StandaloneModel( - self.in_features, - self.hidden_features, - weight_attrs, - ) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) - optimizer_pd = paddle.optimizer.SGD( - learning_rate=0.1, parameters=standalone_model.parameters() - ) - - pipe_model = fleet.distributed_model(pipe_model) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - def train_one_step(layer, inp, optimizer): - loss = layer(inp) - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for i in range(self.iter): - inp = paddle.to_tensor( - np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype - ) - label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) - loss_te = pipe_model.train_batch([inp, label], optimizer_te) - loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) - print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}") - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - -class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol = 5e-2 - self.atol = 5e-2 - self.iter = 10 - self.fp8 = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py deleted file mode 100644 index 1a42d6c621..0000000000 --- a/tests/paddle/parallel_tests/linear_tp.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
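All of these parallel tests share the same driver: five iterations of forward under `te.fp8_autocast`, a mean-reduction loss, backward, and an SGD step, with outputs and gradients compared against the reference after every iteration. A minimal single-layer version of that loop (sizes are arbitrary; FP8 is left disabled so the sketch does not require FP8-capable hardware):

    import paddle
    import transformer_engine.paddle as te

    layer = te.Linear(32, 64)
    optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer.parameters())

    for _ in range(5):
        inp = paddle.uniform([16, 32], "float32")
        with te.fp8_autocast(enabled=False):   # the tests flip this to True for FP8 runs
            loss = layer(inp).mean()
            loss.backward()
        optimizer.step()
        optimizer.clear_grad()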
-"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLinearTp(unittest.TestCase): - """Tests Linear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - total_out = mp_ops._c_concat(out, group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = 
fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - def test_row_parallel_layer(self): - """Tests row parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="row", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=1) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size] - ) - assert_shape(layer_te.bias, [self.out_features]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="column", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLinearTpFP8(TestLinearTp): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = True - self.sequence_parallel = False - - -class TestLinearSp(TestLinearTp): - """Tests Linear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLinearSpFP8(TestLinearTp): - """Tests Linear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py deleted file mode 100644 index 5fc3e7ddf3..0000000000 --- a/tests/paddle/parallel_tests/transformer_tp.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. -"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestTransformerTp(unittest.TestCase): - """Tests Transformer layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = [ - self.hidden_size, - self.ffn_hidden_size, - self.num_heads, - ] - common_kwargs = { - "layernorm_epsilon": self.eps, - "hidden_dropout": 0.0, - "attention_dropout": 0.0, - "self_attn_mask_type": self.mask_type, - "layer_type": self.layer_type, - } - layer_tp = te.TransformerLayer( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = 
paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: - total_weight = weight_src - elif partition_mode == "column": - total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"]) - copy_weight( - layer_tp, - layer_single, - "column", - ["self_attention", "layernorm_qkv", "weight"], - interleave=True, - ) - copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"]) - copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"]) - copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestTransformerTpFp8(TestTransformerTp): - """Tests Transformer layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestTransformerSp(TestTransformerTp): - """Tests Transformer layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = 
False - self.sequence_parallel = True - - -class TestTransformerSpFp8(TestTransformerSp): - """Tests Transformer layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py deleted file mode 100644 index e753f750c5..0000000000 --- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder recompute""" - -import sys -import paddle -import transformer_engine.paddle as te - - -class Net(paddle.nn.Layer): - """Network use for recompute testing""" - - def __init__(self, layers): - super().__init__() - self.layers = layers - - def forward(self, inp, mask, enable_recompute, use_reentrant): - for layer in self.layers: - if enable_recompute: - out = te.recompute(layer, inp, mask, use_reentrant=use_reentrant) - else: - out = layer(inp, mask) - return out - - -def main(): - """Main function""" - paddle.seed(10) - batch_size = 16 - hidden_size = 4096 - num_heads = 32 - ffn_hidden_size = 16384 - q_seqlen = 512 - kv_seqlen = 512 - num_layers = 4 - enable_recompute = int(sys.argv[1]) - use_reentrant = int(sys.argv[2]) - - layers = paddle.nn.LayerList( - [ - te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layer_type="encoder", - ) - for _ in range(num_layers) - ] - ) - model = Net(layers) - - optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters()) - - for _ in range(10): - inp = paddle.uniform([batch_size, q_seqlen, hidden_size]) - inp.stop_gradient = False - mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool") - with te.fp8_autocast(enabled=True): - out = model(inp, mask, enable_recompute, use_reentrant) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - - print("Loss: ", float(loss)) - print("Peak memory: ", paddle.device.cuda.max_memory_allocated(0)) - - -if __name__ == "__main__": - main() diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py deleted file mode 100644 index 1c317584ed..0000000000 --- a/tests/paddle/test_install.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test basic installation of Paddle extensions""" - - -def test_import(): - """ - Test if Paddle extension can be imported normally - """ - import transformer_engine.paddle # pylint: disable=unused-import diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py deleted file mode 100644 index fbd6c61ad7..0000000000 --- a/tests/paddle/test_layers.py +++ /dev/null @@ -1,1663 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
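The first case in the layer-level test file below exercises checkpointing: the model is run once with `calibrating=True` to populate its amax history, its state dict is saved and restored into a fresh layer, and both layers are then compared under the same autocast setting. A condensed sketch of that flow:

    import paddle
    import transformer_engine.paddle as te
    from transformer_engine.paddle.fp8 import fp8_autocast

    model = te.Linear(16, 32)
    restored = te.Linear(16, 32)
    inp = paddle.uniform(shape=(16, 16), dtype="float32")

    with fp8_autocast(enabled=False, calibrating=True):   # populate amax history
        _ = model(inp)
    paddle.save(model.state_dict(), "model.pdparams")
    restored.set_state_dict(paddle.load("model.pdparams"))

    with fp8_autocast(enabled=False):                     # enabled=True on FP8-capable GPUs
        ref = model(inp)
        out = restored(inp)
    assert paddle.allclose(ref, out).item()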
-"""Test TE Paddle Layer-level APIs""" - -import os -from utils import assert_allclose, is_fused_attention_supported - -import paddle -import pytest - -from transformer_engine.common.recipe import DelayedScaling -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast - -is_fp8_supported, reason = is_fp8_available() -LINEAR_CASES = [(16, 16, 32), (32, 32, 64)] -NORM_CASES = [(16, 32), (256, 1024)] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - paddle.seed(10) - yield - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_fp8", [True, False]) -def test_checkpoint(use_fp8): - """Test checkpoint save / load""" - bs = 16 - in_features = 16 - out_features = 32 - file_name = "model.pdparams" - input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32") - model = te.Linear(in_features, out_features) - model_loaded = te.Linear(in_features, out_features) - # Populate amax_history - with fp8_autocast(enabled=False, calibrating=True): - _ = model(input_tensor) - # Save model - paddle.save(model.state_dict(), file_name) - # Get ref output - with fp8_autocast(enabled=use_fp8): - out_ref = model(input_tensor) - # Load model - model_loaded.set_state_dict(paddle.load(file_name)) - if os.path.exists(file_name): - os.remove(file_name) - # Get actual output - with fp8_autocast(enabled=use_fp8): - out = model_loaded(input_tensor) - - assert_allclose(out, out_ref) - - -def calc_output_and_grad(layer, x, dy): - """ - Calculate forward and backward pass - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - y = layer(inp) - y.backward(dy) - - return y, inp.grad if not inp.stop_gradient else None - - -@staticmethod -def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False): - """ - Calculate forward and backward pass for layernorm - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - outputs = layer(inp) - ln_out = None - if return_ln_out: - y, ln_out = outputs - else: - y = outputs - y.backward(dy) - - return y, ln_out, inp.grad if not inp.stop_gradient else None - - -class TestLinear: - """ - Tests for Linear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_bf16( - bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype - ): - """ - Test BF16 Linear - """ - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False) - layer_pd = te.Linear( - in_features, out_features, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - 
if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - activation_dtype, - ): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.5 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear( - in_features=in_features, - out_features=out_features, - bias_attr=None if has_bias else False, - ) - layer_pd = te.Linear( - in_features=in_features, - out_features=out_features, - bias_attr=None if has_bias else False, - backend="paddle", - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.1 - - recipe = DelayedScaling() - - paddle.set_default_dtype(activation_dtype) - layer_cached = te.Linear( - in_features=in_features, - 
out_features=out_features, - ) - layer_normal = te.Linear( - in_features=in_features, - out_features=out_features, - ) - layer_cached.weight.copy_(layer_normal.weight, True) - layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs,hidden_size", NORM_CASES) -@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) -@pytest.mark.parametrize("no_dgrad", [True, False]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) -def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype): - """ - Test BF16 LayerNorm - """ - eps = 1e-3 - rtol = 1e-2 - atol = 1e-2 - - x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - x.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False) - layer_pd = te.LayerNorm( - hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out) - out, grad_input = calc_output_and_grad(layer_te, x, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - -class TestLayerNormLinear: - """ - Tests for LayerNormLinear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_bf16( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test BF16 LayerNormLinear Layer - """ - 
paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - 
- recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_linear_fp8_microbatch( - bs, in_features, out_features, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - eps = 1e-3 - rtol = 0.5 - atol = 0.5 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_normal = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_cached.ln_weight.copy_(layer_normal.ln_weight, True) - layer_cached.ln_bias.copy_(layer_normal.ln_bias, True) - layer_cached.weight.copy_(layer_normal.weight, True) - layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, 
is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - - -class TestLayerNormMLP: - """ - Test LayerNormMLP Layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_bf16( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Tests for TestLayerNormMLP layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - 
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_fp8( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - 
layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_mlp_fp8_microbatch( - bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - - layer_normal = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - layer_normal.ln_weight.copy_(layer_cached.ln_weight, True) - layer_normal.ln_bias.copy_(layer_cached.ln_bias, True) - layer_normal.fc1_weight.copy_(layer_cached.fc1_weight, True) - layer_normal.fc2_weight.copy_(layer_cached.fc2_weight, True) - layer_normal.fc1_bias.copy_(layer_cached.fc1_bias, True) - layer_normal.fc2_bias.copy_(layer_cached.fc2_bias, True) - - # Calibration to make sure weight scale is the same - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(input_tensor) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(input_tensor) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - 
assert_allclose( - layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("attn_type", ["self", "cross"]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("deterministic", [True, False]) -def test_dot_product_attention( - bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic -): - """ - Test DotProductAttention Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 1e-4 - atol = 2e-2 - head_size = hidden_size // num_heads - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - attn_q_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_k_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_v_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - - q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32") - kv_actual_seqlen = ( - paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32") - if attn_type == "cross" - else q_actual_seqlen - ) - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - head_size = hidden_size // num_heads - - if deterministic: - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" - - layer_te = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="transformer_engine", - ) - layer_pd = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="paddle", - ) - - def calc_attn_output_and_grad(layer, q, k, v, mask, dout): - _q = paddle.to_tensor(q, stop_gradient=False) - _k = paddle.to_tensor(k, stop_gradient=False) - _v = paddle.to_tensor(v, stop_gradient=False) - - out = layer(_q, _k, _v, mask) - out.backward(dout) - return out, _q.grad, _k.grad, _v.grad - - out, q_grad, k_grad, v_grad = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( - layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - valid_out_ref = paddle.full_like(out_ref, 0) - for i in range(0, bs): - valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :] - - valid_q_grad_ref = 
paddle.full_like(q_grad_ref, 0) - valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) - valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) - for i in range(0, bs): - valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[ - i, 0 : q_actual_seqlen[i], :, : - ] - valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - - assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) - assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) - assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) - assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) - if deterministic: - out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - assert_allclose(out, out2, rtol=1e-12, atol=1e-12) - assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12) - os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_encoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - normalization, -): - """ - Test Transformer Encoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 5e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - 
bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - 
layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - out = layer(_encoder_input, mask) - out.backward(dout) - return out, _encoder_input.grad - - out_ref, grad_input_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, grad_out - ) - out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("recompute_core_attention", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_decoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - recompute_core_attention, - normalization, -): - """ - Test Transformer Decoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 6e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // 
num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype( - math_dtype - ) - encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype( - math_dtype - ) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - - # rounding to avoid numerical issues - encoder_input = paddle.round(encoder_input * 1000) / 1000 - encoder_output = paddle.round(encoder_output * 1000) / 1000 - grad_out = paddle.round(grad_out * 1000) / 1000 - - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - self attn - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - 
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # MultiHeadAttention params - cross attn - layer_pd.inter_attention.layernorm_query.ln_weight.copy_( - layer_te.inter_attention.layernorm_query.ln_weight, True - ) - layer_pd.inter_attention.layernorm_query.weight.copy_( - layer_te.inter_attention.layernorm_query.weight.T, True - ) - layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.inter_attention.layernorm_query.ln_bias.copy_( - layer_te.inter_attention.layernorm_query.ln_bias, True - ) - layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.inter_attention.layernorm_query.bias.copy_( - layer_te.inter_attention.layernorm_query.bias, True - ) - layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - - layer_pd.inter_attention.key_value.weight.copy_( - layer_te.inter_attention.key_value.weight.T, True - ) - layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) - layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad - layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True) - layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True) - layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias - layer_te.inter_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - 
layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad( - layer, - encoder_input, - mask, - encoder_output, - enc_dec_attn_mask, - dout, - recompute_core_attention=False, - ): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) - out = layer( - _encoder_input, - mask, - _encoder_output, - enc_dec_attn_mask, - recompute_core_attention=recompute_core_attention, - ) - out.backward(dout) - return out, _encoder_input.grad, _encoder_output.grad - - out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out - ) - out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad( - layer_te, - encoder_input, - attn_mask, - encoder_output, - attn_mask, - grad_out, - recompute_core_attention=recompute_core_attention, - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.weight.grad, - layer_pd.inter_attention.layernorm_query.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.5, - atol=0.6, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.bias.grad, - layer_pd.inter_attention.layernorm_query.bias.grad, - rtol=rtol, - atol=atol, - ) - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("bs", [8]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]]) -@pytest.mark.parametrize("mask_type", ["causal"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16"]) -@pytest.mark.parametrize("num_microbatch", [8]) -def test_transformer_encoder_layer_microbatch( - bs, - hidden_size, - num_heads, - ffn_hidden_size, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - num_microbatch, -): - """ - Test Transformer Encoder Layer with FP8 weight caching - """ - 
paddle.set_default_dtype(math_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - layer_cached = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - layer_normal = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - - layer_normal.self_attention.layernorm_qkv.ln_weight.copy_( - layer_cached.self_attention.layernorm_qkv.ln_weight, True - ) - layer_normal.self_attention.layernorm_qkv.ln_bias.copy_( - layer_cached.self_attention.layernorm_qkv.ln_bias, True - ) - layer_normal.self_attention.layernorm_qkv.weight.copy_( - layer_cached.self_attention.layernorm_qkv.weight, True - ) - layer_normal.self_attention.layernorm_qkv.bias.copy_( - layer_cached.self_attention.layernorm_qkv.bias, True - ) - - layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True) - layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True) - - # LayerNorm MLP params - layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True) - layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True) - layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True) - layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True) - layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True) - layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True) - - recipe = DelayedScaling() - - def generate_input(): - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - return encoder_input, attn_mask, grad_out - - # Calibration to make sure weight scale is the same - encoder_input, mask, _ = generate_input() - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(encoder_input, mask) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(encoder_input, mask) - - for iteration in range(num_microbatch): - encoder_input, mask, grad_out = generate_input() - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = 
layer_normal(encoder_input, mask) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.self_attention.layernorm_qkv.weight.grad, - layer_normal.self_attention.layernorm_qkv.weight.grad, - rtol=rtol, - atol=atol, - ) diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py deleted file mode 100644 index c896a7871c..0000000000 --- a/tests/paddle/test_master_grad.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder main_grad""" - -import numpy as np -import pytest - -import paddle -from paddle.distributed.fleet.utils import mix_precision_utils - -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available - -is_fp8_supported, reason = is_fp8_available() - - -def create_optimizer(model, use_pure_bf16, use_main_grad): - """Create optimizer""" - if use_main_grad: - assert use_pure_bf16 - model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.0001, - multi_precision=use_pure_bf16, - ) - if use_main_grad: - optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) - - return optimizer - - -class Net(paddle.nn.Layer): - """Network use for main_grad testing""" - - def __init__(self, fuse_wgrad_accumulation): - super().__init__() - self.layer = te.TransformerLayer( - 4096, - 16384, - 32, - layer_type="encoder", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ) - - def forward(self, inp): - out = self.layer(inp) - return out - - -def train(enable_master_grad, fuse_wgrad_accumulation=False): - """Train function""" - paddle.seed(10) - - accumulate_steps = 4 - - if fuse_wgrad_accumulation: - assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad" - - model = Net(fuse_wgrad_accumulation) - - optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad) - - loss_list = [] - for step_id in range(16): - inp = paddle.uniform([2, 1024, 4096], dtype="float32") - inp.stop_gradient = False - with te.fp8_autocast(enabled=True): - out = model(inp) - loss = out.mean() - loss_list.append(loss) - loss.backward() - - # gradient accumulation - if (step_id + 1) % accumulate_steps == 0: - optimizer.step() - optimizer.clear_grad() - - return loss_list - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -def test_master_grad(): - """Test main_grad""" - paddle.set_default_dtype("float32") - loss1 = train(enable_master_grad=False) - loss2 = train(enable_master_grad=True) - loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True) - - np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py deleted file mode 100644 index d9b1fa5cd1..0000000000 --- a/tests/paddle/test_operators.py +++ /dev/null @@ -1,1201 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Test TE operators""" - -import struct - -import numpy as np -import paddle -import paddle.nn.functional as F -import pytest - -from utils import ( - assert_allclose, - create_fp8_meta, - get_fused_attention_backend, - is_fused_attention_supported, -) - -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.paddle.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - gemm, - fp8_gemm, - transpose, - cast_transpose, - cast_transpose_bgrad, - te_gelu, - gelu_fp8, - swiglu, - swiglu_fp8, - swiglu_pd, - dswiglu, - dgelu_cast_transpose_bgrad_fp8, - layernorm_fwd_fp8, - layernorm_fwd, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - scaled_softmax_forward, - scaled_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, -) -from transformer_engine.paddle.fp8 import is_fp8_available -from transformer_engine.paddle.constants import FP8FwdTensors -from transformer_engine.common.recipe import DelayedScaling - -GEMM_CASES = [ - (256, 256, 512), - (32, 32, 32), - (16384, 1024, 2816), - (16384, 2816, 1024), - (16384, 1024, 1024), -] -is_fp8_supported, reason = is_fp8_available() - -SELF_ATTN_CASES = [(2, 512, 12, 64)] -CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)] -FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)] -ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - np.random.seed(10) - paddle.seed(11) - yield - - -@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_quantize_dequantize(fp8_dtype, inplace): - """ - Test cast_to_fp8 and cast_from_fp8 - """ - a = paddle.rand(shape=(32, 32), dtype="float32") - # Init fp8_meta - fp8_meta = create_fp8_meta() - a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8) - b = cast_from_fp8( - a_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_OUTPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a, b, rtol=5e-2, atol=5e-2) - - -def copy_bits_from_float_to_uint16(f): - """ - Copy bits - """ - return struct.unpack("> 16 - - -def convert_float_to_uint16(float_list): - """ - convert float to uint16 - """ - new_output = [] - for x in np.nditer(float_list): - new_output.append(np.uint16(copy_bits_from_float_to_uint16(x))) - new_output = np.reshape(new_output, float_list.shape).view(np.uint16) - - return new_output - - -class TestTranspose: - """ - Test transpose operators - """ - - @staticmethod - def test_transpose_bf16(): - """ - Test BF16 transpose - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") - a_transposed = transpose(a, otype=tex.DType.kBFloat16) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_transpose_fp8(fp8_dtype): - """ - Test FP8 transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, 
otype=fp8_dtype) - a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype) - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - @pytest.mark.parametrize("inplace", [True, False]) - def test_cast_transpose(fp8_dtype, inplace): - """ - Test cast_transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8_casted, a_fp8_transposed = None, None - if inplace: - a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8) - a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8) - a_fp8_casted, a_fp8_transposed = cast_transpose( - a, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype, - cast_out=a_fp8_casted, - transpose_out=a_fp8_transposed, - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_cast_transpose_bgrad(fp8_dtype): - """ - Test cast_transpose_bgrad - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad( - a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - assert_allclose(bgrad, a.sum(axis=0)) - - -class TestActivation: - """ - Test activation operators - """ - - @staticmethod - def test_gelu_bf16(): - """ - Test BF16 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - gelu_out = te_gelu(a, otype=tex.DType.kBFloat16) - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_fp8(fp8_dtype): - """ - Test FP8 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - gelu_out = cast_from_fp8( - gelu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_bwd_fp8(fp8_dtype): - """ - Test FP8 GELU Backward - """ - # y 
= GELU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - x.stop_gradient = False - y = paddle.nn.GELU()(x) - y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - fp8_meta = create_fp8_meta() - x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8( - y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - x_grad = cast_from_fp8( - x_grad_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - x_grad_t = cast_from_fp8( - x_grad_t_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01) - assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bf16(): - """ - Test BF16 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - swiglu_out = swiglu(a, otype=tex.DType.kBFloat16) - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_swiglu_fp8(fp8_dtype): - """ - Test FP8 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - swiglu_out = cast_from_fp8( - swiglu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bwd(): - """ - Test SwiGLU Backward - """ - # y = SwiGLU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - x.stop_gradient = False - y = swiglu_pd(x) - y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - - -class TestGemm: - """ - Tests for gemm(cuBLASLt) operator - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16(m, n, k): - """ - Test "TN" BF16 GEMM - """ - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - # CublasLt inside tex.te_gemm assumes inputs are column major. - # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the - # transpose of X. - # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T, - # which is equivalent to a@b^T = C in row major. 
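A minimal NumPy sketch of the layout identity the comment above relies on (illustrative only; shapes and names are hypothetical and this snippet is not exercised by the test itself): a column-major "TN" GEMM that produces C^T = B·A^T leaves in memory exactly the row-major buffer of C = A·B^T, so no explicit transpose is needed after the call.

```python
import numpy as np

# Illustrative only: a is (m, k), b is (n, k), mirroring the test inputs.
m, n, k = 4, 5, 3
rng = np.random.default_rng(0)
a = rng.random((m, k), dtype=np.float32)
b = rng.random((n, k), dtype=np.float32)

ref = a @ b.T    # row-major C = A @ B^T, the reference the test compares against
c_t = b @ a.T    # what the column-major "TN" GEMM computes, i.e. C^T

# The column-major storage of C^T is element-for-element the row-major storage
# of C, which is why the operands are swapped when calling the GEMM.
assert np.allclose(c_t.T, ref)
assert np.allclose(c_t.ravel(order="F"), ref.ravel(order="C"))
```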
- actual_out, _, _ = gemm( - b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False - ) - - assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5) - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16_inplace(m, n, k): - """ - Test "TN" BF16 GEMM, with accumulate=True - """ - min_val = -16 - max_val = 16 - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16") - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = c + paddle.matmul(a, b.T) - - actual_out = paddle.clone(c) - _, _, _ = gemm( - b, - a, - paddle.bfloat16, - workspace, - False, - None, - False, - True, - "TN", - actual_out, - None, - False, - ) - - assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_fp8_randint(m, n, k): - """ - Test "TN" FP8 GEMM - """ - min_val = -4 - max_val = 4 - fp8_dtype = tex.DType.kFloat8E4M3 - out_dtype = paddle.float32 - fp8_meta = create_fp8_meta(num_gemms=1) - - a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32") - - a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32") - b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype) - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - actual_out, _ = fp8_gemm( - b_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, - a_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - out_dtype, - workspace, - ) - - assert_allclose(actual_out, ref_out) - - -class TestLayerNorm: - """ - Test layernorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma, beta): - """ - Calculate reference using paddle layer_norm op - """ - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - mean = paddle.mean(x, axis=-1) - var = paddle.var(x, axis=-1) - inv_var = paddle.sqrt(1.0 / var) - return y, mean, inv_var - - @staticmethod - def calc_bwd_ref(x, eps, gamma, beta, dy): - """ - Calculate reference using paddle layer_norm op - """ - x.stop_gradient = False - gamma.stop_gradient = False - beta.stop_gradient = False - - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad, beta.grad - - def test_layernorm_fwd(self): - """ - Test BF16 LayerNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - - y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta) - - assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4) - assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3) - assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2) - - @staticmethod - def test_layernorm_fwd_fp8(): - """ - Test FP8 LayerNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 
- N, H = (16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - beta = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32) - - y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(mu, mu_ref) - assert_allclose(rsigma, rsigma_ref) - - def test_layernorm_bwd(self): - """ - Test BF16 LayerNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy) - - _, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5) - - -class TestRMSNorm: - """ - Test rmsnorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma): - """ - Calculate rmsnorm reference using paddle op - """ - - norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps) - y = x * norm * gamma - - return y - - def calc_bwd_ref(self, x, eps, gamma, dy): - """ - Calculate rmsnorm bwd reference using paddle op - """ - x.stop_gradient = False - gamma.stop_gradient = False - - y = self.calc_fwd_ref(x, eps, gamma) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad - - def test_rmsnorm_fwd(self): - """ - Test BF16 RMSNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - - y_ref = self.calc_fwd_ref(x, eps, gamma) - - assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2) - - @staticmethod - def test_rmsnorm_fwd_fp8(): - """ - Test FP8 RMSNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 - N, H = (16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32) - - y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(rsigma, rsigma_ref) - - def test_rmsnorm_bwd(self): - """ - Test BF16 RMSNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy) - - _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2) - assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2) - - -class 
TestFusedAttn: - """ - Test fused attention operators - """ - - def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False): - """ - set test input - """ - - def _random(shape): - if self.dtype == "bfloat16": - data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32") - return convert_float_to_uint16(data) - return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype) - - self.batch_size = b - self.q_seqlen = s_q - self.kv_seqlen = s_kv - self.num_heads = h - self.head_size = d - self.dropout_prob = 0.0 - self.scaling_factor = 1.0 / np.sqrt(d) - self.q_shape = (b, s_q, h, d) - self.kv_shape = (b, s_kv, h, d) - self.fuse_qkv_shape = (b, s_q, 3, h, d) - self.fuse_kv_shape = (b, s_kv, 2, h, d) - self.bias_shape = (1, h, s_q, s_kv) - self.attn_mode = attn_mode - self.dtype = dtype - self.is_causal_masking = is_causal_masking - - self.q = _random(self.q_shape) - if self.attn_mode == "self_attn": - assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen" - self.kv = self.q - else: - self.kv = _random(self.kv_shape) - - self.q_actual_seqlen = None - if self.is_causal_masking: - self.q_actual_seqlen = np.full( - self.batch_size, - self.q_seqlen, - dtype=np.int32, - ) - else: - self.q_actual_seqlen = np.random.randint( - low=20, - high=self.q_seqlen, - size=(self.batch_size,), - dtype=np.int32, - ) - self.kv_actual_seqlen = self.q_actual_seqlen - - self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen) - self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0) - self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen) - self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0) - self.attn_mask = np.ones( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), - dtype=np.int32, - ) - if self.is_causal_masking: - assert attn_mode == "self_attn", "only support causal masking for self attention" - for i in range(0, self.batch_size): - for j in range(self.q_actual_seqlen[i]): - self.attn_mask[i, :, j, : j + 1] = 0 - else: - for i in range(0, self.batch_size): - self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0 - - dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size)) - self.dout = paddle.to_tensor(dout, dtype=self.dtype) - - def _get_reference_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - - q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - - qk_out = paddle.matmul( - x=q_out * self.scaling_factor, - y=k_out, - transpose_x=False, - transpose_y=True, - ) - - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool") - attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype) - attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out) - attn_mask_out = paddle.cast(attn_mask_out, "float32") - softmax_out = F.softmax(attn_mask_out) - softmax_out = paddle.cast(softmax_out, self.dtype) - - if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train", - ) - qkv_out = paddle.matmul(dropout_out, v_out) - else: - qkv_out = paddle.matmul(softmax_out, v_out) - - 
out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] - - paddle.autograd.backward( - [out], - [self.dout], - retain_graph=True, - ) - return out, q_tensor.grad, k_tensor.grad, v_tensor.grad - - def _get_fused_attention_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - if self.attn_mode == "self_attn": - qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] - qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False) - else: - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] - kv_tensor = paddle.to_tensor(kv, stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None - if self.attn_mode == "self_attn": - out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - is_training=True, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dqkv, _ = fused_attn_bwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - q_grad = dqkv[:, :, 0, :, :] - k_grad = dqkv[:, :, 1, :, :] - v_grad = dqkv[:, :, 2, :, :] - else: # attn_mode == 'cross_attn' - out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - dq, dkv, _ = fused_attn_bwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - q_grad = dq - k_grad = dkv[:, :, 0, :, :] - v_grad = dkv[:, :, 1, :, :] - - return out, q_grad, k_grad, v_grad - - def _get_fused_attention_with_separate_qkv(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, 
stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bshd_bshd_bshd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, rng_state = fused_attn_fwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dq, dk, dv, _ = fused_attn_bwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - - return out, dq, dk, dv - - @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True, False]) - def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test self attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): - """ - test cross attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s_q, - kv_seqlen=s_kv, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bs2hd", - bias_type="no_bias", - mask_type="padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn") - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = 
self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True]) - def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test flash attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [False, True]) - def test_fused_attn_with_separate_qkv_forward_backward( - self, b, s, h, d, dtype, is_causal_masking - ): - """ - test flash attention forward + backward with separate qkv inputs - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - -class TestSoftmax: - """ - Test softmax operators - """ - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_softmax_fwd_bwd(dtype): - """test scaled softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - - y_ref = F.softmax(scale * x) - y = scaled_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_masked_softmax_fwd_bwd(dtype): - """test scaled masked softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False 
- dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S)) - mask_flipped = x[0, 0] <= 0.3 - mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_masked_softmax_forward(x, mask, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): - """test scaled upper triang masked softmax""" - B, S = (16, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, S, S), dtype=dtype) - - mask = paddle.ones((S, S), dtype="int32") - col_beg, col_end = 1, S - for row in range(0, S): - mask[row, col_beg:col_end] = 0 - col_beg += 1 - - mask_ref = (mask.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_upper_triang_masked_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) - - -@pytest.mark.parametrize("update_weight_scale_inv", [True, False]) -def test_amax_and_scale_update(update_weight_scale_inv): - """Test update_scale""" - num_gemm = 6 - history_len = 1024 - recipe = DelayedScaling() - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_max = recipe.fp8_format.value.max_fwd - non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2)) - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) - rolled_history_ref[0] = 0.0 - amax_tensor = paddle.max(amax_history_tensor, axis=0) - scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32") - - def calc_ref(amax, scale, fp8_max, margin=0): - """Calculate reference scale""" - sf = (fp8_max / amax) / (2**margin) - sf = paddle.where(amax > 0.0, sf, scale) - sf = paddle.where(paddle.isfinite(amax), sf, scale) - return sf - - scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0) - if update_weight_scale_inv: - scale_inv_ref = 1.0 / scale_ref - else: - scale_inv_ref = paddle.zeros_like(scale_tensor) - scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref) - - # Placeholder - scale_actual = paddle.zeros_like(scale_tensor) - scale_inv_actual = paddle.zeros_like(scale_tensor) - - if update_weight_scale_inv: - non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=amax_history_tensor, - _scale=scale_actual, - _scale_inv=scale_inv_actual, - non_weight_mask=non_weight_mask, - fp8_dtype=int(fp8_dtype), - margin=0.0, - amax_compute="max", - ) - - assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) - assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) - assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7) - - -def test_update_latest_history(): - """Test update_latest_history""" - num_gemm = 6 - history_len = 1024 - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - amax = paddle.rand(shape=[num_gemm], dtype="float32") - - tex.update_latest_amax_history_inplace(_history=amax_history_tensor, 
amax=amax) - - assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7) diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py deleted file mode 100644 index 82f970b2c8..0000000000 --- a/tests/paddle/test_parallel.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Parallel""" - -from pathlib import Path -import unittest - -from dist_launcher import TestDistributed -from utils import is_devices_enough - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -gpu_has_fp8, reason = is_fp8_available() - - -class TestParallelLinear(TestDistributed): - """Test Linear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_linear_tp(self): - """Tests linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py")) - - -class TestParallelLayerNormLinear(TestDistributed): - """Test LayerNormLinear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_linear_tp(self): - """Tests layernorm_linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py")) - - -class TestParallelLayerNormMLP(TestDistributed): - """Test LayerNormMLP in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_mlp_tp(self): - """Tests layernorm_mlp with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py")) - - -class TestAmaxReduction(TestDistributed): - """Test amax reduction in dp mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_amax_reduction(self): - """Tests amax reduction""" - self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py")) - - -class TestPipelineParallel(TestDistributed): - """Test pipeline parallel""" - - @unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_pipeline_parallel(self): - """Tests pipeline parallel""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py")) - - -class TestGroupSharding(TestDistributed): - """Test group sharding""" - - @unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_group_sharding(self): - """Tests group sharding""" - self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py")) - - -class TestParallelAttention(TestDistributed): - """Test MultiHeadAttention Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_attention_tp(self): - """Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py")) - - -class TestParallelTransformerLayer(TestDistributed): - """Test Transformer Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs") - @unittest.skipIf(not 
gpu_has_fp8, reason) - def test_transformer_tp(self): - """Tests Transformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py")) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py deleted file mode 100644 index 59079b0d1d..0000000000 --- a/tests/paddle/test_recompute.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Recompute""" - -from pathlib import Path -import re -import subprocess - -import numpy as np -import pytest - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -is_fp8_supported, reason = is_fp8_available() - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_reentrant", [False, True]) -def test_transformer_encoder_recompute(use_reentrant): - """ - Test TransformerLayer encoder recompute - """ - rtol = 1e-5 - atol = 1e-5 - - def launch_subprocess_and_check_output(enable_recompute): - """Launch training in subprocess and check output""" - try: - cmd = [ - "python", - str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"), - str(int(enable_recompute)), - str(int(use_reentrant)), - ] - result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) - - print(result) - - loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result) - memory_match = re.search(r"Peak memory:\s+(\d+)", result) - - loss_value = float(loss_match.group(1)) - memory_value = int(memory_match.group(1)) - - return loss_value, memory_value - - except subprocess.CalledProcessError as e: - raise ValueError(f"Subprocess failed with error: {e}") from e - - loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True) - loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False) - - assert peak_memory_recompute < peak_memory_ref - np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol) diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py deleted file mode 100644 index b0a8d0d80b..0000000000 --- a/tests/paddle/utils.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Utils for testing""" - -import random -from typing import Union - -import numpy as np -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker - -import transformer_engine # pylint: disable=unused-import -from transformer_engine.paddle.constants import ( - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, -) -from transformer_engine.paddle.fp8 import FP8TensorMeta -from transformer_engine import ( - transformer_engine_paddle as tex, -) # pylint: disable=wrong-import-order - - -def create_fp8_meta(num_gemms=1, amax_history_len=10): - """ - Create and initialize FP8TensorMeta - """ - fp8_meta = FP8TensorMeta(is_forward=True) - fp8_meta.prepare(num_gemms, amax_history_len) - return fp8_meta - - -def assert_allclose( - actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True -): - """Compare two input paddle tensors""" - if isinstance(actual, paddle.Tensor): - actual = paddle.cast(actual, "float32") - if isinstance(desired, paddle.Tensor): - desired = paddle.cast(desired, "float32") - if len(actual.shape) == 0: - actual = actual.item() - desired = desired.item() - else: - actual = actual.numpy() - desired = desired.numpy() - np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose) - - -def assert_shape(inp, expected_shape): - """Assert the shape of input tensor equals to expected shape""" - assert ( - inp.shape == expected_shape - ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}" - - -def is_devices_enough(required): - """If the number of device is enough""" - return paddle.device.cuda.device_count() >= required - - -def set_random_seed(seed): - """Set random seed for reproducability.""" - fleet.meta_parallel.model_parallel_random_seed(seed) - - hcg = fleet.get_hybrid_communicate_group() - if paddle.distributed.get_world_size() > 1: - # obtain rank message of hybrid parallel - - mp_rank = hcg.get_model_parallel_rank() - mp_size = hcg.get_model_parallel_world_size() - - pp_rank = hcg.get_stage_id() - pp_size = hcg.get_pipe_parallel_world_size() - - dp_rank = hcg.get_data_parallel_rank() - dp_size = hcg.get_data_parallel_world_size() - - sharding_rank = hcg.get_sharding_parallel_rank() - else: - mp_rank, mp_size = 0, 1 - pp_rank, pp_size = 0, 1 - dp_rank, dp_size = 0, 1 - sharding_rank, _ = 0, 1 - - random.seed(seed + 100 * pp_rank) - np.random.seed(seed + 100 * pp_rank) - - seed_offset = seed + 1024 + paddle.distributed.get_world_size() - global_seed = ( - seed_offset - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - seed_offset += paddle.distributed.get_world_size() - local_seed = ( - seed_offset - + mp_rank - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - tracker = get_rng_state_tracker() - # tracker.reset() - if "global_seed" not in tracker.states_: - tracker.add("global_seed", global_seed) - if "local_seed" not in tracker.states_: - tracker.add("local_seed", local_seed) - - paddle.seed(global_seed) - - -def get_fused_attention_backend( - num_heads: int, - num_gqa_groups: int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> tex.NVTE_Fused_Attn_Backend: - """Get cuDNN fused attention backend for attention config""" - if isinstance(dtype, str): - 
dtype = dict( - float32=paddle.float32, - bfloat16=paddle.bfloat16, - float16=paddle.float16, - )[dtype] - return tex.get_fused_attn_backend( - TE_DType[dtype], - TE_DType[dtype], - tex.get_nvte_qkv_layout(qkv_layout), - AttnBiasType[bias_type], - AttnMaskType[mask_type], - dropout, - num_heads, - num_gqa_groups, - q_seqlen, - kv_seqlen, - head_size, - ) - - -def is_fused_attention_supported( - num_heads: int, - num_gqa_groups: int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> bool: - """Check if cuDNN fused attention is supported for attention config""" - backend = get_fused_attention_backend( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=dtype, - dropout=dropout, - qkv_layout=qkv_layout, - bias_type=bias_type, - mask_type=mask_type, - ) - return backend != FusedAttnBackend["No_Backend"] - - -def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None: - """Register allreduce hooks for sequence parallel tensors""" - - def is_sequence_parallel_parameter(parameter): - """If input tensor is marked as sequence parallel tensor""" - out = getattr(parameter, "sequence_parallel", False) - return out - - def create_allreduce_gradient_hook(param, accumulation_steps): - """Create allreduce gradient hook""" - hcg = fleet.get_hybrid_communicate_group() - pg = hcg.get_model_parallel_group().process_group - step = [0] - - @paddle.autograd.no_grad() - def __impl__(): - step[0] += 1 - if (step[0] % accumulation_steps) == 0: - if hasattr(param, "main_grad"): - pg.allreduce(param.main_grad).wait() - else: - pg.allreduce(param.grad).wait() - - return __impl__ - - if accumulation_steps <= 0 or not paddle.distributed.is_initialized(): - return - - hcg = fleet.get_hybrid_communicate_group() - mp_group = hcg.get_model_parallel_group() - if mp_group.nranks <= 1: - return - - params = [] - for p in model.parameters(): - if is_sequence_parallel_parameter(p): - params.append(p) - - for p in params: - hook = create_allreduce_gradient_hook(p, accumulation_steps) - p._register_backward_hook(hook) diff --git a/tests/pytorch/custom_ort_ops/.gitignore b/tests/pytorch/custom_ort_ops/.gitignore deleted file mode 100644 index d491fb774c..0000000000 --- a/tests/pytorch/custom_ort_ops/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -build -onnxruntime -libcustom_ort_ops.so diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt deleted file mode 100644 index d3e95bd4bc..0000000000 --- a/tests/pytorch/custom_ort_ops/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -cmake_minimum_required(VERSION 3.21) -project(custom_ort_ops LANGUAGES CXX) - -# Dependencies -find_package(CUDAToolkit REQUIRED) -set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include) -if(NOT EXISTS "${ONNX_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find ONNX Runtime headers. 
" - "Please clone https://github.com/microsoft/onnxruntime " - "into TransformerEngine/tests/pytorch/onnx.") -endif() -include_directories(${ONNX_INCLUDE_DIR}) - -# Configure library -add_library(custom_ort_ops SHARED custom_op_library.cc) -target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart) -target_include_directories(custom_ort_ops PUBLIC - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(custom_ort_ops PRIVATE - ${ONNX_INCLUDE_DIR}/onnxruntime - ${ONNX_INCLUDE_DIR}/onnxruntime/core/session) - -# Install library -install(TARGETS custom_ort_ops DESTINATION .) diff --git a/tests/pytorch/custom_ort_ops/README.md b/tests/pytorch/custom_ort_ops/README.md deleted file mode 100644 index ca392805be..0000000000 --- a/tests/pytorch/custom_ort_ops/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Custom ONNX Runtime operators for Transformer Engine tests - -This directory contains code that builds custom ONNX operators for use -in Transformer Engine tests. It includes basic, non-performant -implementations of the FP8 quantization and dequantization operators -that are used when exporting Transformer Engine models to ONNX. - -For more information, see [the ONNX Runtime reference for custom -operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html). -Much of the code has been adapted from [an ONNX Runtime -test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). - -## Usage - -* Build the custom operators: -```bash -$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh -``` -* Run the ONNX export tests with pytest: -```bash -$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py -``` \ No newline at end of file diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh deleted file mode 100644 index 01502ba6fb..0000000000 --- a/tests/pytorch/custom_ort_ops/build.sh +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -ex - -: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))} -cd ${CUSTOM_ORT_OPS_PATH} - -# Download ONNX Runtime source -git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true - -# Configure and build with CMake -mkdir -p build -cmake -S . -B build -DCMAKE_INSTALL_PREFIX=. -cmake --build build --verbose -cmake --install build --verbose diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc deleted file mode 100755 index c7b94ff700..0000000000 --- a/tests/pytorch/custom_ort_ops/custom_op_library.cc +++ /dev/null @@ -1,102 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "custom_op_library.h" - -#define ORT_API_MANUAL_INIT -#include "onnxruntime_c_api.h" -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/session/onnxruntime_lite_custom_op.h" -#include - -namespace { - -template -void Quantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = input.Data(); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = reinterpret_cast(output.Allocate(input.Shape())); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = static_cast(raw_input[i]); - raw_output[i] = static_cast(x / rs); - } -} - -template -void Dequantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = reinterpret_cast(input.Data()); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = output.Allocate(input.Shape()); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = rs * static_cast(raw_input[i]); - raw_output[i] = static_cast(x); - } -} - -static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { - static std::vector ort_custom_op_domain_container; - static std::mutex ort_custom_op_domain_mutex; - std::lock_guard lock(ort_custom_op_domain_mutex); - ort_custom_op_domain_container.push_back(std::move(domain)); -} - -} // namespace - -OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::Global::api_ = api->GetApi(ORT_API_VERSION); - - // Namespace for custom ops - static const char* c_OpDomain = "trt"; - - // Construct custom ops - static const std::unique_ptr c_Quantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear", - "CPUExecutionProvider", - Quantize) - }; - static const std::unique_ptr c_Dequantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear", - "CPUExecutionProvider", - Dequantize<__nv_fp8_e4m3, float, float>) - }; - - // Register custom ops - OrtStatus* result = nullptr; - ORT_TRY { - Ort::CustomOpDomain domain{c_OpDomain}; - domain.Add(c_Quantize.get()); - domain.Add(c_Dequantize.get()); - Ort::UnownedSessionOptions session_options(options); - session_options.Add(domain); - AddOrtCustomOpDomainToContainer(std::move(domain)); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - Ort::Status status{e}; - result = status.release(); - }); - } - return result; -} diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 4f170e3f84..9e11e07e11 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -19,8 +19,8 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.common.recipe import Format -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes warnings.filterwarnings("ignore", category=DeprecationWarning) 
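The import hunk above replaces the old scaling-factor helpers (`Format`, `_default_sf_compute`) with `Float8Quantizer`, which the updated test uses for all FP8 casts. A minimal sketch of that quantize/dequantize round trip, assuming the constructor, call, and `dequantize` signatures used later in this file and the same tolerances the test checks with (a CUDA device and FP8-capable GPU are assumed):

```python
import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

x = torch.randn(16, 16, device="cuda")
scale = torch.ones((), dtype=torch.float32, device="cuda")   # delayed-scaling scale
amax = torch.zeros((), dtype=torch.float32, device="cuda")   # amax recorded by the cast
quantizer = Float8Quantizer(scale, amax, tex.DType.kFloat8E4M3)

x_fp8 = quantizer(x)                            # Float8Tensor holding FP8 data + scale_inv
x_hat = x_fp8.dequantize(dtype=torch.float32)   # back to full precision

# Same tolerances the overlap test uses for its FP8 cast checks.
assert torch.allclose(x, x_hat, rtol=0.125, atol=0.0675)
```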
warnings.filterwarnings("ignore", category=FutureWarning) @@ -47,14 +47,14 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) - parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) @@ -288,33 +288,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == tex.CommOverlapType.RS: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS - elif opts.p2p: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - ) - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ) - elif opts.comm_type == tex.CommOverlapType.AG: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - ) - else: - raise TypeError("Invalid comm+GEMM overlap type!") - # Initialize userbuffers with (M, N) buffer # M = sequence * batch # N = hidden size @@ -322,11 +295,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) buffer_dtype = torch.bfloat16 - if ( - opts.fp8 - and not opts.bulk_overlap - and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) - ): + if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG: buffer_dtype = torch.uint8 ub_obj = ( tex.CommOverlapP2P( @@ -421,6 +390,10 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) + # Allocate cuBLAS workspace + workspace_size = 3 * get_cublas_workspace_size_bytes() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) @@ -467,120 +440,123 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) + inp_quantizer = None + ker_quantizer = None + out_quantizer = None + bulk_inp_quantizer = None + inp2_quantizer = None + ker2_quantizer = None + out2_quantizer = None if opts.fp8: - fp8_formats = { - tex.DType.kFloat8E4M3: Format.E4M3, - tex.DType.kFloat8E5M2: Format.E5M2, - } - # Structure 
to maintain amax and scale/scale_inv information for the kernel and input - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_meta = tex.FP8TensorMeta() num_gemms = 6 if ub_obj2 is not None else 3 - fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda") - fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda") - fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_dtype = tex.DType.kFloat8E4M3 + fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda") # Compute initial amaxes and scales inp_amax = torch.max(torch.abs(inp_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax) + fp8_amaxes[0].copy_(inp_amax) ker_amax = torch.max(torch.abs(ker_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) + fp8_amaxes[1].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) + fp8_amaxes[2].copy_(ref_amax) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + fp8_amaxes[5].copy_(bulk_amax) elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) + fp8_amaxes[3].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax) + fp8_amaxes[4].copy_(ker2_amax) ref2_amax = torch.max(torch.abs(ref2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax) - fp8_meta.scale = _default_sf_compute( - fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1 - ) - fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale) + fp8_amaxes[5].copy_(ref2_amax) - # Cast input to Float8Tensor - inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype) + inp_quantizer = Float8Quantizer(fp8_scales[0].clone(), fp8_amaxes[0].clone(), fp8_dtype) + ker_quantizer = Float8Quantizer(fp8_scales[1].clone(), fp8_amaxes[1].clone(), fp8_dtype) + if opts.fp8_output: + out_quantizer = Float8Quantizer(fp8_scales[2].clone(), fp8_amaxes[2].clone(), fp8_dtype) - # Cast kernel to Float8Tensor - kernel_t_fp8 = tex.cast_to_fp8( - kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype - ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: - bulk_inp_fp8 = tex.cast_to_fp8( - bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype + bulk_inp_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype ) elif ub_obj2 is not None: - kernel2_t_fp8 = tex.cast_to_fp8( - kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype + inp2_quantizer = Float8Quantizer( + fp8_scales[3].clone(), fp8_amaxes[3].clone(), fp8_dtype + ) + ker2_quantizer = Float8Quantizer( + fp8_scales[4].clone(), fp8_amaxes[4].clone(), fp8_dtype ) + if opts.fp8_output: + out2_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype + ) + + # Cast input to Float8Tensor + inp_fp8 = inp_quantizer(inp) + + # Cast kernel to Float8Tensor + kernel_t_fp8 = ker_quantizer(kernel_t) + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: + bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp) + elif ub_obj2 is not None: + kernel2_t_fp8 = ker2_quantizer(kernel2_t) # Make sure the inputs 
are cast correctly if opts.check_numerics: torch.allclose( inp.to(dtype=torch.float32), - inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) torch.allclose( kernel_t.to(dtype=torch.float32), - kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + kernel_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), - bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + bulk_inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), - kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + kernel2_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) - # Set Fp8 scales for userbuffers - if opts.comm_type == tex.CommOverlapType.AG: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) - if ub_obj2 is not None: - ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - elif opts.bulk_overlap: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - else: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) - # Set up comm/compute buffers - ubuf_out2 = None + rs_out = None rs_out2 = None if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 1) + ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True) gemm_inp = inp else: - ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1) - gemm_inp = ub_obj.get_ubuf_output(1) - ubuf_out = None - rs_out = None + ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True) + gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) if ub_obj2 is not None: - ubuf_out2 = ub_obj2.get_ubuf_output(1) + if opts.fp8 and opts.fp8_output: + ub_obj2.set_buffer_params(out_quantizer) rs_out2 = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) - ubuf_out = None - else: - ubuf_out = ub_obj.get_ubuf_output(1) + ub_obj.copy_into_buffer( + bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False + ) + if opts.fp8: + ub_obj.set_buffer_params(bulk_inp_quantizer) + elif opts.fp8 and opts.fp8_output: + ub_obj.set_buffer_params(out_quantizer) gemm_inp = inp_fp8 if opts.fp8 else inp rs_out = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" @@ -588,88 +564,47 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use def _fp8_gemm(): - return tex.fp8_gemm( + return tex.general_gemm( kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - D_dtype=fp8_dtype if opts.fp8_output else None, - 
fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) def _fp8_gemm2(gemm1_out): gemm2_inp = tex.gelu( - ( - tex.cast_from_fp8( - gemm1_out, - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else gemm1_out - ), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, + (gemm1_out.dequantize() if opts.fp8_output else gemm1_out), + inp2_quantizer, ) - return tex.fp8_gemm( + return tex.general_gemm( kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, gemm2_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out2_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic_rs_p2p - else tex.CommOverlapAlgo.ATOMIC_GEMM_RS - ), ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ub_type=tex.CommOverlapType.AG, + extra_output=rs_out2, ) def _gemm(): - return tex.gemm( + return tex.general_gemm( kernel_t, gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, + workspace, + out_dtype=torch.bfloat16, + use_split_accumulator=te.module.base._2X_ACC_FPROP, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) # Trigger GEMM @@ -746,10 +681,10 @@ def _gemm(): output_info = "" if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered - test_out = ub_obj.get_ubuf_output(1) + test_out = ub_obj.get_buffer(bulk_inp_quantizer, False) else: # Bulk overlap RS output needs to be gathered - out_local = ub_obj.get_ubuf_output(0) + out_local = ub_obj.get_buffer(bulk_inp_quantizer, True) output_info += f"rs_output: {list(out_local.shape)} | " test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] @@ -775,17 +710,7 @@ def _gemm(): test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - output = ( - tex.cast_from_fp8( - all_outputs[0], - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else all_outputs[0] - ) + output = all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0] test_out = torch.transpose( te.distributed.gather_along_first_dim( torch.transpose(output, 0, 1), tp_group @@ -798,25 +723,6 @@ def _gemm(): output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] - if opts.fp8: - dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" - + f"scale = {fp8_meta.scale[:3].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" - ) - dist_print(fp8_meta_info, 
src=0, group=tp_group) - if ub_obj2 is not None: - dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" - + f"scale = {fp8_meta.scale[3:].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - ref_out = ref2_g if ub_obj2 is not None else ref_g test_nonzeros = torch.count_nonzero(test_out) ref_nonzeros = torch.count_nonzero(ref_out) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 5a67bd616a..d4a01386ee 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,6 +9,7 @@ import socket import argparse import warnings +import pprint import torch import torch.distributed as dist @@ -39,6 +40,8 @@ def _te_layer_argtype(name): def _get_layer_args(config, tp_group, tp_size, reference=False): hidden_size = config.num_heads * config.head_dim + ffn_hidden_size = 4 * hidden_size + qkv_size = 3 * hidden_size input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { @@ -47,46 +50,41 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": True, + "ub_overlap_ag": not reference, + "ub_overlap_rs": not reference, } - kwargs["ub_overlap_ag"] = not reference - if config.layer_type is te.Linear: + if config.layer_type in [te.Linear, te.LayerNormLinear]: if config.linear_parallel_mode == "row": - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["ub_overlap_rs"] = not reference + input_shape[-1] = ffn_hidden_size // tp_size + args = [ffn_hidden_size, hidden_size] + kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2" elif config.linear_parallel_mode == "column": input_shape[0] = config.seq_length // tp_size - args.append(3 * hidden_size) - kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference + args.append(qkv_size) + kwargs["ub_name"] = "qkv" + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference kwargs["parallel_mode"] = config.linear_parallel_mode - kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(ffn_hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + kwargs["set_parallel_mode"] = True kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference - kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference - if config.layer_type is te.LayerNormLinear: - args.append(3 * hidden_size) - kwargs["parallel_mode"] = "column" - kwargs["ub_name"] = "qkv" - else: - kwargs["set_parallel_mode"] = True - 
kwargs["ub_overlap_rs"] = not reference - if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: - args.append(4 * hidden_size) - kwargs["seq_length"] = config.seq_length - if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - args.append(config.num_heads) - kwargs["attention_dropout"] = 0.0 - kwargs["fuse_qkv_params"] = True - if config.layer_type is te.MultiheadAttention: - kwargs["input_layernorm"] = True - else: - kwargs["ub_tp_comm_overlap"] = not reference - kwargs["hidden_dropout"] = 0.0 + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference return args, kwargs, input_shape @@ -97,12 +95,12 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( @@ -144,7 +142,7 @@ def _parse_args(argv=None, namespace=None): "--overlap-rs-dgrad", action="store_true", default=False, - help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.", + help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM.", ) parser.add_argument( "--debug", @@ -175,7 +173,7 @@ def _compare_tensors(name, test, ref, rtol, atol): ) return 1, numerics_info - diff = torch.abs(test - ref).flatten() + diff = torch.abs(test.flatten() - ref.flatten()) m = torch.argmax(diff) abs_err = diff[m].item() rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) @@ -254,8 +252,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ub_cfgs = None if opts.overlap_rs_dgrad: ub_cfgs = { - "proj_dgrad": {"method": "ring_exchange"}, "qkv_dgrad": {"method": "ring_exchange"}, + "fc1_dgrad": {"method": "ring_exchange"}, } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], @@ -271,6 +269,10 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): with te.fp8_model_init(enabled=opts.fp8_init): test_model = opts.layer_type(*args, **kwargs) dist_print("Initialized test model...", debug=True) + if WORLD_RANK == 0: + pprint.pprint(kwargs) + sys.stdout.write("\n") + dist.barrier() # Initialize the reference model and copy all parameters ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) @@ -305,8 +307,8 @@ def run_fwd_bwd(model, x): out, *_ = y else: out = y - loss = out.sum() - loss.backward() + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() @@ -342,29 +344,27 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - num_grads_failed[0] = 1 
+ numerics_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." ) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") - if not bool(num_grads_failed.item()): + if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[i] = int(grad_failed) - return_code = torch.max(numerics_failed) - dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) - else: - return_code = num_grads_failed + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()) and not opts.debug: + break te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -374,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return return_code.item() + return numerics_failed[0].item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 64f36051c6..2d301e3151 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -4,9 +4,10 @@ # # See LICENSE for license information. -import sys -import os import argparse +import datetime +import os +import sys from functools import wraps import transformer_engine.pytorch as te @@ -14,7 +15,12 @@ from torch import nn import torch.distributed as dist -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, + DelayedScaling, + Format, + Recipe, +) from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -23,15 +29,27 @@ WORLD_RANK, WORLD_SIZE = None, None NCCL_WORLD = None LOSS_FN = nn.MSELoss() -FP8 = False +QUANTIZATION = None + + +# Disable TF32 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False -# Fp8 recipe setup -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "fp8": + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + if QUANTIZATION == "mxfp8": + return MXFP8BlockScaling() + return te.fp8.get_default_fp8_recipe() def main(argv=None, namespace=None): - global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -44,6 +62,7 @@ def main(argv=None, namespace=None): "backend": "nccl", "rank": WORLD_RANK, "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), } dist_init_kwargs["init_method"] = "env://" dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") @@ -57,9 +76,17 @@ def main(argv=None, namespace=None): parser = argparse.ArgumentParser() parser.add_argument("-l", 
"--layer-type", type=str) - parser.add_argument("--fp8", action="store_true", default=False) + parser.add_argument("--quantization", type=str, default=None) args = parser.parse_args(argv, namespace) + # Quantization scheme + QUANTIZATION = args.quantization + if QUANTIZATION in ("fp8", "mxfp8"): + global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE + SEQ_LEN = 32 + BATCH_SIZE = 32 + HIDDEN_SIZE = 128 + test_dict = [ test_linear, test_layernorm, @@ -68,8 +95,6 @@ def main(argv=None, namespace=None): test_transformer_layer, ] - FP8 = args.fp8 - for test in test_dict: test() dist.destroy_process_group() @@ -124,11 +149,10 @@ def dist_print(msg, src=None, end="\n", error=False): stream = sys.stderr if error else sys.stdout if WORLD_RANK == (0 if src is None else src): stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") - dist.barrier() def _get_tolerances(dtype): - if FP8: + if QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} if dtype == torch.float16: @@ -153,8 +177,7 @@ def _check_outputs(output_single_node, output_distributed): dist_print(output_info, src=WORLD_RANK, error=output_failed) numerics_failed[0] = int(output_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _match_param_sizes(dist_param, single_param): @@ -213,13 +236,12 @@ def _check_gradients(model_distributed, model_single, main_grad_check=False): ) if grad_failed: - dist_print(i) - dist_print(name) + dist_print(i, src=WORLD_RANK) + dist_print(name, src=WORLD_RANK) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _copy_params(model_distributed, model_single): @@ -243,9 +265,18 @@ def _apply_models( model_single_node, model_distributed, input_single_node, input_distributed, **kwargs ): _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe): + input_single_node.requires_grad_() + input_distributed.requires_grad_() + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + ): output_single_node = model_single_node(input_single_node, **kwargs) - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe, fp8_group=NCCL_WORLD): + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + fp8_group=NCCL_WORLD, + ): output_distributed = model_distributed(input_distributed, **kwargs) return output_single_node, output_distributed @@ -544,9 +575,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg """ # Set parameter data type params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 # Create models model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs) @@ -636,9 +665,7 @@ def test_layernorm_mlp(): @run_distributed_test() def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs): params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 model_single_node 
= te.TransformerLayer( HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index c285da7fbd..52420efca5 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -16,23 +16,22 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -RNG_SEED: int = 1234 -SEQ_LENGTH: int = 512 +RNG_SEED: int = 42 +SEQ_LENGTH: int = 1024 BATCH_SIZE: int = 2 -NUM_HEADS: int = 12 -HEAD_DIM: int = 64 - -# NOTE: te.Linear is intentionally omitted here and manually added later for testing both -# row and column parallel layouts. +NUM_HEADS: int = 16 +HEAD_DIM: int = 48 TE_LAYERS = [ + te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, te.TransformerLayer, ] +MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS]) TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS: int = min(torch.cuda.device_count(), 4) +NUM_PROCS: int = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] if tex.ubuf_built_with_mpi(): LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"] @@ -48,7 +47,7 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -64,19 +63,15 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg if bulk: test_cmd.append("--bulk-overlap") else: - if fp8_in: + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_out: - test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") - if aggregate: - test_cmd.append("--aggregate") if atomic: - if torch.cuda.get_device_properties(0).major < 9: - pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip("Atomic GEMM requires device compute capability 9.x (Hopper).") test_cmd.append("--atomic") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) @@ -88,7 +83,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -99,15 +94,16 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] - if layer_type == te.Linear.__name__: + if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]: test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") + if overlap_rs_dgrad: + test_cmd.append("--overlap-rs-dgrad") + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_init: - test_cmd.append("--fp8-init") os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" @@ -128,88 +124,39 @@ def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): @pytest.mark.parametrize( - "fp8,aggregate", - [ - (False, False), - (False, True), - (True, 
False), - (True, True), - ], - ids=[ - " BF16 IN - RING-EXCHANGE ", - " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", - " FP8 IN - RING-EXCHANGE ", - " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", - ], + "fp8", + (False, True), + ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "], ) -def test_split_all_gather_overlaps(fp8, aggregate): +def test_split_all_gather_overlaps(fp8): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + _run_gemm_with_overlap("AG", False, True, False, fp8) @pytest.mark.parametrize( - "fp8_in,fp8_out,p2p", + "fp8,p2p", [ - (False, False, False), - (False, False, True), - (True, False, False), - (True, False, True), - (True, True, False), - (True, True, True), + (False, False), + (False, True), + (True, False), + (True, True), ], ids=[ - " BF16 IN - BF16 OUT - PIPELINE ", - " BF16 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - BF16 OUT - PIPELINE ", - " FP8 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - FP8 OUT - PIPELINE ", - " FP8 IN - FP8 OUT - RING-EXCHANGE ", + " BF16 - PIPELINE ", + " BF16 - RING-EXCHANGE ", + " FP8 - PIPELINE ", + " FP8 - RING-EXCHANGE ", ], ) -def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): +def test_split_reduce_scatter_overlaps(fp8, p2p): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) - - -@pytest.mark.parametrize( - "ag_type,rs_type,p2p,fp8_out", - [ - (0, 0, False, False), - (0, 1, False, False), - (0, 1, False, True), - (0, 2, False, False), - (0, 2, False, True), - (0, 0, True, False), - (0, 0, True, True), - (1, 0, True, False), - (1, 0, True, True), - ], - ids=[ - " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - ], -) -def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): - """ - Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with - direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. 
- """ - os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) - os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) - _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + _run_gemm_with_overlap("RS", False, p2p, False, fp8) @pytest.mark.parametrize( @@ -223,12 +170,12 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): ("RS", True, 8), ], ids=[ - "ALL-GATHER - BF16 - 1 connections", + "ALL-GATHER - BF16 - 1 connections", "REDUCE-SCATTER - BF16 - 1 connections", - "REDUCE-SCATTER - FP8 - 1 connections", - "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", "REDUCE-SCATTER - BF16 - 8 connections", - "REDUCE-SCATTER - FP8 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], ) def test_bulk_overlaps(comm_type, fp8, connections): @@ -242,38 +189,48 @@ def test_bulk_overlaps(comm_type, fp8, connections): " 9.0 (HOPPER ARCH)." ) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" else: - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_gemm_with_overlap(comm_type, True, False, False, fp8) +@pytest.mark.parametrize("fp8", (False, True), ids=[" BF16 ", " FP8 "]) @pytest.mark.parametrize( - "layer_type,linear_parallel_mode", - ( - [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] - + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) - ), - ids=( - [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] - + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] - ), -) -@pytest.mark.parametrize( - "fp8,fp8_init", + "layer_type,linear_parallel_mode,overlap_rs_dgrad", [ - (False, False), - (True, False), - (True, True), - ], + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + + list( + zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]), + ) + ), ids=[ - " BF16 GEMM - BF16 PARAMS ", - " FP8 GEMM - BF16 PARAMS ", - " FP8 GEMM - FP8 PARAMS ", + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + + [ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), + ) ], ) -def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): +def test_layers_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): """ Test Transformer Engine layers with comm+GEMM overlap. 
""" - _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 598859b826..c8ef7687fa 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -5,27 +5,38 @@ from __future__ import annotations import argparse +from collections.abc import Iterable import functools import itertools import os import pathlib import subprocess import sys +from typing import Optional import pytest import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex -# Check if FP8 is supported + +# Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +quantization_list: list[Optional[str]] = [None] +if fp8_available: + quantization_list.append("fp8") +if mxfp8_available: + quantization_list.append("mxfp8") @functools.cache @@ -66,22 +77,18 @@ def make_reference_and_test_tensors( in Transformer Engine operations. """ - - # Random data ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - - # Make copy of tensor + test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(ref) - else: - test = ref.to(device=test_device, dtype=test_dtype) - if test.data_ptr() == ref.data_ptr(): - test = test.clone() - - # Make sure reference and test tensors represent exact same values + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) + elif test.data_ptr() == ref.data_ptr(): + test = test.clone() ref.copy_(test) - - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test @@ -120,6 +127,21 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: raise ValueError(f"Unsupported dtype ({dtype})") +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def _test_all_reduce( *, local_size: int = 17, @@ -293,17 +315,16 @@ def _test_reduce_scatter( def _test_basic_linear( *, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, 
device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -313,10 +334,13 @@ def _test_basic_linear( # Tensor dimensions local_out_features, local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -326,21 +350,28 @@ def _test_basic_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -391,7 +422,8 @@ def _test_basic_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -404,7 +436,7 @@ def _test_basic_linear( with torch.no_grad(): op.weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -412,10 +444,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -429,17 +459,16 @@ def _test_basic_linear( def _test_linear( *, bias: bool = True, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -449,10 +478,13 @@ def _test_linear( # Tensor dimensions local_out_features, 
local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -462,14 +494,19 @@ def _test_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -485,9 +522,11 @@ def _test_linear( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -552,7 +591,8 @@ def _test_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -571,7 +611,7 @@ def _test_linear( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -579,12 +619,8 @@ def _test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -603,8 +639,8 @@ def _test_fp8_scale_update( amax_history_len: int = 31, amax_compute_algo: str = "max", margin: float = 2, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", tensor_parallel_mode: str = "column", @@ -715,20 +751,12 @@ def ref_amax_and_scale( y_test.backward(dy_test) # Check results - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] - x_amax_test = x_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - w_amax_test = w_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - dy_amax_test = dy_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - x_scale_test = x_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - w_scale_test = w_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - dy_scale_test = 
dy_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - torch.testing.assert_close(x_amax_test, x_amax_ref) - torch.testing.assert_close(w_amax_test, w_amax_ref) - torch.testing.assert_close(dy_amax_test, dy_amax_ref) + x_quantizer = op.get_quantizer("forward", 0) + w_quantizer = op.get_quantizer("forward", 1) + dy_quantizer = op.get_quantizer("backward", 0) + x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) torch.testing.assert_close(x_scale_test, x_scale_ref) torch.testing.assert_close(w_scale_test, w_scale_ref) torch.testing.assert_close(dy_scale_test, dy_scale_ref) @@ -755,38 +783,32 @@ def run_parallel_tests() -> None: # Basic linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), (False, True), ): if rank == 0: print(f"Running _test_basic_linear with {config=}") - fp8, tensor_parallel_mode, sequence_parallel = config + quantization, tensor_parallel_mode, sequence_parallel = config _test_basic_linear( - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, sequence_parallel=sequence_parallel, ) # Linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), ): if rank == 0: print(f"Running _test_linear with {config=}") - fp8, tensor_parallel_mode = config + quantization, tensor_parallel_mode = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, ) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1a6191f06c..7be9cd01ae 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -27,29 +27,31 @@ pytest.skip("Distributed training needs at least 2 GPUs.") fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] -def _run_test(fp8): +def _run_test(quantization): test_path = TEST_ROOT / "run_numerics.py" test_cmd = LAUNCH_CMD + [str(test_path)] - if fp8: - test_cmd += ["--fp8"] + if quantization is not None: + test_cmd += ["--quantization", quantization] - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) - if result.returncode != 0 or "NUMERICAL CHECK FAILED" in result.stderr.decode(): - raise AssertionError(result.stderr.decode()) + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 all_boolean = [True, False] -@pytest.mark.parametrize("fp8", all_boolean) -def test_distributed(fp8): - if fp8 and not fp8_available: +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8"]) +def test_distributed(quantization): + if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8) + if quantization == "mxfp8" and not mxfp8_available: + 
pytest.skip(reason_for_no_mxfp8) + _run_test(quantization) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 02a85f0ac4..4298d17c9c 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -12,7 +12,7 @@ def get_torch_version(): - """Get pytorch version from __version__""" + """Get PyTorch version from __version__""" def get_torch_version_str(): import torch @@ -22,25 +22,14 @@ def get_torch_version_str(): return PkgVersion(get_torch_version_str()) -if torch.cuda.device_count() < 4: - pytest.skip("FSDP2 test requires at least 4 GPUs.") - -if torch.cuda.device_count() % 2 != 0: - pytest.skip("Number of device should be divided by 2.") - -if not get_torch_version() >= PkgVersion("2.4"): - pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.") - fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = torch.cuda.device_count() -LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] def _run_test(fp_init, sharding_dims): - test_path = TEST_ROOT / "run_fsdp2_model.py" - test_cmd = LAUNCH_CMD + [str(test_path)] + test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" + test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] if fp_init: test_cmd += ["--fp8-init"] @@ -50,18 +39,30 @@ def _run_test(fp_init, sharding_dims): test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] else: assert False - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) - if result.returncode != 0: - raise AssertionError(result.stderr.decode()) + result = subprocess.run(test_cmd, env=os.environ, check=True) -all_boolean = [True, False] -sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]] +@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") +@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") +@pytest.mark.skipif(not get_torch_version() >= PkgVersion("2.4"), reason="Requires PyTorch 2.4.0+") +@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) +@pytest.mark.parametrize("fp8_init", (False, True)) +def test_distributed(fp8_init, sharding_dims): + # Skip invalid configurations + if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs") -@pytest.mark.parametrize("sharding_dims", sharding_dims) -@pytest.mark.parametrize("fp8_init", all_boolean) -def test_distributed(fp8_init, sharding_dims): if fp8_init and not fp8_available: pytest.skip(reason_for_no_fp8) + _run_test(fp8_init, sharding_dims) + + +def test_dummy() -> None: + """Dummy test + + pytest returns exit code 5 if all tests are skipped. 
+ + """ + pass diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 1fae9e99f2..4a1fd17be7 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -176,6 +176,11 @@ def run_dpa_with_cp( k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) # create flash attention bias if config.attn_bias_type not in ["no_bias", "alibi"]: @@ -206,7 +211,7 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: out.backward(dout) @@ -276,7 +281,7 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index d546118ffb..ff45d1e38f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -20,6 +20,7 @@ MultiheadAttention, RotaryPositionEmbedding, get_attention_backend, + _flash_attn_is_installed, _flash_attn_2_3_plus, _flash_attn_3_is_installed, check_set_window_size, @@ -48,6 +49,12 @@ from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex from transformer_engine_torch import NVTE_Fused_Attn_Backend +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() @@ -257,11 +264,17 @@ def test_dot_product_attention( pad_between_seqs=pad_between_seqs, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] + if ( + pad_between_seqs + and _flash_attn_is_installed + and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ) + and (config.window_size[0] == -1 or _flash_attn_2_3_plus) ): flash_attn_supported = True @@ -1365,13 +1378,18 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": 
ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), + "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1420,8 +1438,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1447,7 +1471,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1499,7 +1523,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: fp8_mha=fp8_mha, ) - with fp8_model_init(enabled=fp8_mha): + with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: PE = RotaryPositionEmbedding(dim=config.head_dim_qk) @@ -1523,12 +1547,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: mha = mha.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = 
torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1565,6 +1603,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, rotary_pos_emb=rotary_pos_emb, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) if is_training: out.backward(out_grad) @@ -1594,13 +1634,29 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): config = model_configs_fp8_vs_f16[model] + # TODO(cyang): think of another way to verify dropout results + # test cuDNN FP8 dropout + # 1. we modify the config here to not affect mha_fp8_vs_f16 tests + # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an + # indirect verification method, we create Q/K/V as all 1s and check if O is all 1s + # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell + # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type: + # if get_device_compute_capability() >= (10, 0): + # config.dropout_p = 0.1 + + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1617,17 +1673,19 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): dtype, config, True, qkv_layout, is_training ) - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training - ) + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training + ) atol = 5e-1 rtol = 5e-2 - rmse_tol = 0.1 + rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if _flash_attn_3_is_installed and not is_training and "padding" not in config.attn_mask_type: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1637,27 +1695,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rtol, rmse_tol, ) 
- _error( - fused_attn_fwd_fp8, - fused_attn_fwd_f16, - "fused_attn_fwd_fp8", - "fused_attn_fwd_f16", - atol, - rtol, - rmse_tol, - ) - if is_training: - for i, _ in enumerate(fused_attn_bwd_f16): - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - ) + if config.dropout_p != 0.0: + # test cuDNN FP8 dropout + assert torch.all( + fused_attn_fwd_fp8 == 1 + ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." + else: + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): @@ -1696,12 +1760,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: dpa = dpa.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1730,7 +1808,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") tensor_shape = [dim_to_num[j] for j in layout.split("_")] - tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") + if config.dropout_p == 0.0: + tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") + else: + # test cuDNN FP8 dropout + tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda") tensor_count = 1 split_dim = 0 for dim, l in enumerate(layout.split("_")): @@ -1766,7 +1848,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - is_first_microbatch=True, ) if is_training: out.backward(out_grad) @@ -1819,7 +1900,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 - rmse_tol = 0.1 + rmse_tol = 0.13 _error( fused_attn_fwd_fp8, unfused_attn_fwd_f16, @@ -1973,7 +2054,9 @@ def forward( workspace: torch.Tensor, is_training: bool, mask_type: str, + quantizers: list[Quantizer], ) -> torch.Tensor: + qkv_dtype = inp.dtype assert inp.dim() == 2 in_features = 
qkv_weight.shape[-1] @@ -1981,83 +2064,53 @@ def forward( d = in_features // h b = cu_seqlens.numel() - 1 - fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] + dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3] - inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused( - inp, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - - qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused( - qkv_weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - ) + inp_fp8 = input_quantizer(inp) - M = None - ZInv = None - philox_unpacked = None + qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight) - qkv, _ = ext.fp8_gemm( + qkv, *_ = ext.general_gemm( qkv_weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, inp_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - torch.uint8, workspace, bias=qkv_bias, - use_bias=True, - out_index=META_QKV, - fp8_meta_tensor=fp8_meta["scaling_fwd"], + out_dtype=qkv_weight_fp8.dtype, + quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, - D_dtype=fp8_dtype_forward, ) qkv = qkv.view(-1, 3, h, d) - qkv_fp16 = ( - ext.cast_from_fp8( - qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16 - ) - .view(b, max_s, 3, h, d) - .contiguous() - ) + qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") if cudnn_frontend_version == 1: qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA - out, aux_ctx_tensors, *rest = fused_attn_fwd( + q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :] + k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :] + v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :] + q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape) + k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape) + v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape) + + out, aux_ctx_tensors = fused_attn_fwd( is_training, max_s, max_s, cu_seqlens, cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], - fp8_dtype_forward, + q, + k, + v, + qkv_dtype, FusedAttnBackend["FP8"], - None, - None, - None, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset 
attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, @@ -2065,20 +2118,18 @@ def forward( attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, + o_quantizer=o_quantizer, + s_quantizer=s_quantizer, ) - M, ZInv, philox_unpacked = aux_ctx_tensors - - ctx.save_for_backward( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].scale_inv, + tensors_to_save, tensor_objects = prepare_for_saving( + q, k, v, inp_fp8, qkv_weight_fp8, workspace, out ) + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.qkv_dtype = qkv_dtype ctx.fp8_meta = fp8_meta ctx.cu_seqlens = cu_seqlens ctx.p_dropout = p_dropout @@ -2089,58 +2140,46 @@ def forward( ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = s_quantizer + out = out.view(-1, in_features) # (bs)(hd) - out_fp16 = ext.cast_from_fp8( - out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16 - ) + out_fp16 = out.dequantize() torch.save(out_fp16, "out.pt") # (bs)(hd) return out_fp16 @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): - ( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, - fwd_scales, - fwd_scale_inverses, - ) = ctx.saved_tensors - fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + saved_tensors = ctx.saved_tensors + (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved( + ctx.tensor_objects, saved_tensors + ) - proj_dgrad = ext.cast_to_fp8( - grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) # (bs)(hd) + proj_dgrad = ctx.dO_quantizer(grad_output) + fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, ctx.max_s, ctx.cu_seqlens, ctx.cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], + q, + k, + v, out, proj_dgrad.view_as(out), - fp8_dtype_forward, + ctx.qkv_dtype, fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, None, - fwd_scale_inverses[META_QKV], # d_scale_qkv, - fwd_scale_inverses[META_S], # d_scale_s, - fwd_scale_inverses[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, @@ -2149,58 +2188,42 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) dim = 2 if cudnn_frontend_version == 1 else 1 - dqkv = torch.Tensor().to(device=dq.device, 
dtype=dq.dtype) - dqkv_shape = list(dq.shape) + dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype) + dqkv_shape = list(dq._data.shape) dqkv_shape.insert(dim, 3) - dqkv_stride = list(dq.stride()) + dqkv_stride = list(dq._data.stride()) dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3)) - dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd + dqkv.set_( + dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride + ) # bs3hd dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size) - dqkv_c_fp16 = ext.cast_from_fp8( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - tex.DType.kFloat16, - ) + dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape) + dqkv_c_fp16 = dqkv_c.dequantize() torch.save(dqkv_c_fp16, "dqkv.pt") - qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.dtype, - ) + qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer) + dqkv_c._transpose = None + dqkv_c._create_transpose() # QKV DGRAD - qkv_dgrad, _ = ext.fp8_gemm( - qkv_weight_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, + qkv_dgrad, *_ = ext.general_gemm( + qkv_weight_fp8, dqkv_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_DGRAD, + layout="NN", ) + # QKV WGRAD - qkv_wgrad, _ = ext.fp8_gemm( - inp_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dqkv_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, + qkv_wgrad, *_ = ext.general_gemm( + inp_fp8, + dqkv, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_WGRAD, + layout="NT", ) return ( @@ -2258,7 +2281,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, None, num_gemms=3) as inp: + with self.prepare_forward(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, @@ -2272,5 +2295,6 @@ def forward( self.workspace, self.training, self.mask_type, + self.quantizers, ) return out diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py new file mode 100644 index 0000000000..61b4a2553c --- /dev/null +++ b/tests/pytorch/test_cpu_offloading.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
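The `_custom_mha_fp8` changes above replace the old `fp8_meta` scale/amax bookkeeping with per-tensor quantizer objects. As a minimal sketch of that pattern (not part of the patch), assuming `module` is a Transformer Engine module whose `quantizers` dict has been populated by `prepare_forward()`, as in the `TestFusedAttn.forward()` hunk above; the helper name `quantize_roundtrip` is illustrative only:

import torch
import transformer_engine_torch as tex

def quantize_roundtrip(module, inp: torch.Tensor) -> torch.Tensor:
    # A per-tensor quantizer replaces indexing into fp8_meta["scaling_fwd"].
    input_quantizer = module.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
    inp_fp8 = input_quantizer(inp)  # quantize (stands in for ext.cast_to_fp8 / fp8_cast_transpose_fused)
    return inp_fp8.dequantize()     # back to the nominal dtype (stands in for ext.cast_from_fp8)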
+ +import pytest +import torch +from contextlib import nullcontext + +import transformer_engine.pytorch as te + +SIZE = 4096 + +models = { + "linear": te.Linear, + "layernorm_mlp": te.LayerNormMLP, + "layernorm_linear": te.LayerNormLinear, +} + + +def _get_input(): + return torch.empty((1, SIZE, SIZE)).cuda() # input size - 1 * 2048 * 2048 * 4b = 16MB + + +def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): + torch.cuda.empty_cache() + model = model_cls(SIZE, SIZE, 1) + + input = _get_input() + if cpu_offload: + offload_context, sync_function = te.get_cpu_offload_context(enabled=True) + else: + offload_context = nullcontext() + sync_function = lambda x: x + + with te.fp8_autocast(enabled=fp8), offload_context: + out = model(input) + out = sync_function(out) + input.data = torch.Tensor() # delete data from input + out.data = torch.Tensor() # delete data from out + del input + del out + torch.cuda.empty_cache() + allocated_memory_mb = torch.cuda.memory_allocated() / 1024**2 + del model + return allocated_memory_mb + + +@pytest.mark.parametrize("fp8", [False, True]) +@pytest.mark.parametrize("model_key", models.keys()) +def test_cpu_offload(fp8, model_key) -> None: + model_cls = models[model_key] + without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) + torch.cuda.empty_cache() + with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) + + assert without_offloading > 30 + assert with_offloading < 10 diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index d92884eaa2..dcdfa771c8 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -22,10 +22,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Record initial RNG state. @@ -49,6 +51,11 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +fp8_recipes = [ + recipe.DelayedScaling(), + recipe.MXFP8BlockScaling(), +] + # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher @@ -152,6 +159,7 @@ def _test_cuda_graphs( fp8: bool, fp8_params: bool, fp8_weight_caching: bool, + fp8_recipe: recipe.Recipe, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() @@ -162,7 +170,7 @@ def _test_cuda_graphs( fp8_weight_caching = False # Create modules. - with fp8_model_init(enabled=fp8_params): + with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe): if module == "transformer": modules = [ TransformerLayer( @@ -244,6 +252,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) elif graph_mode == "individual": # Graph individual modules. 
@@ -254,6 +263,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) for module in modules ] @@ -270,7 +280,7 @@ def _test_cuda_graphs( for grad_accumulation_step in range(2): input_ = generate_data(model_config, dtype) grad_output = generate_data(model_config, dtype, requires_grad=False) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 @@ -285,6 +295,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables( *, module: str, @@ -293,6 +304,7 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8: bool, fp8_params: bool, + fp8_recipe: recipe.Recipe, fp8_weight_caching: bool = False, ) -> None: @@ -303,6 +315,8 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -314,6 +328,7 @@ def test_make_graphed_callables( fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) @@ -339,16 +354,19 @@ def test_make_graphed_callables( _test_make_graphed_callables_with_fp8_weight_caching_modules, ) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, fp8_params: bool, + fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, dtype=torch.float32, fp8=True, fp8_params=fp8_params, + fp8_recipe=fp8_recipe, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 96b4ab4967..56b01f1dbc 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -11,8 +11,8 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor import transformer_engine_torch as tex # PyTorch tensor dtypes @@ -42,6 +42,20 @@ def _to_list(x: Union[Iterable, Any]) -> List: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +def to_float8( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + quantizer = Float8Quantizer( + scale=torch.full([1], scale, dtype=torch.float32, device="cuda"), + amax=torch.empty([1], dtype=torch.float32, device="cuda"), + fp8_dtype=fp8_dtype, + ) + return quantizer(tensor.cuda()) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -62,10 +76,11 @@ def test_constructor( """Call constructor and perform sanity checks""" dims = _to_list(dims) tensor = Float8Tensor( + shape=dims, + dtype=dtype, data=torch.zeros(dims, 
device="cuda", dtype=torch.uint8), fp8_dtype=fp8_dtype, fp8_scale_inv=torch.full([1], scale_inv), - dtype=dtype, ) assert list(tensor.size()) == dims, "Incorrect dims" assert tensor.dtype == dtype, "Incorrect nominal dtype" @@ -84,11 +99,7 @@ def _test_quantize_dequantize( x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 # Cast to FP8 and back - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -115,62 +126,6 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) - def test_fp8_meta( - self, - dtype: torch.dtype = torch.float32, - dims: DimsType = 23, - ) -> None: - """Construct Float8Tensor using FP8 metadata and perform basic checks""" - - # Get FP8 metadata from linear module - fp8_dtype = tex.DType.kFloat8E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): - module = te.Linear(32, 32) - _ = module(torch.zeros([8, 32], device="cuda")) - fp8_meta = module.fp8_meta - fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - - # Initialize random data - dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - - # Make Float8Tensor - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_meta=fp8_meta, - fp8_meta_index=fp8_meta_index, - ) - x_ref = x_fp8.dequantize() - assert list(x_fp8.size()) == dims, "Incorrect dims" - assert x_fp8.dtype == dtype, "Incorrect nominal dtype" - assert x_fp8.is_cuda, "Incorrect device" - assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" - - # Change FP8 metadata scale - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 - fp8_meta[fp8_meta_key].scale_inv.fill_(123) - - # Check results - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - with pytest.raises(AssertionError): - # Make sure we are not trivially passing the test - torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) - - # Check if scaling factor is updated after in-place ops - x_fp8 += 0 - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 - fp8_meta[fp8_meta_key].scale_inv.fill_(321) - assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - y = x_fp8.detach() - y += 0 - assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - def test_basic_ops( self, dims: DimsType = 23, @@ -184,16 +139,8 @@ def test_basic_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -227,16 +174,8 @@ def test_inplace_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - 
x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -260,56 +199,6 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) - def test_transpose( - self, - dims: DimsType, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: float = 0.5, - dtype: torch.dtype = torch.float32, - ) -> None: - """Test transpose""" - - # Initialize random data - dims = _to_list(dims) - x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - x = x_fp8.dequantize() - - # Perform transpose - x_fp8_t = x_fp8.transpose_2d() - x_t = x.transpose(0, 1) - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) - - # Check results - tols = dict(rtol=0, atol=0) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - - # Make sure we are not trivially passing the test - with pytest.raises(AssertionError): - torch.testing.assert_close(x_fp8_t, x, **tols) - - # Caching test - assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." - x_fp8 += 0.5 - x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - - # Inplace update test - x_fp8 += 0.5 - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." 
- x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - def test_serialization( self, dims: DimsType = [2, 3, 5], @@ -321,11 +210,7 @@ def test_serialization( # Initialize random data dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() # Serialize tensor @@ -357,7 +242,7 @@ def test_set_data(self): # Initialize Float8Tensor x0 = torch.zeros(4, dtype=torch.float32) - x = Float8Tensor.to_float8(x0) + x = to_float8(x0) assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() assert x.dtype == torch.float32 @@ -382,7 +267,7 @@ def test_set_data(self): assert x.device == y.device # Set data to Float8Tensor - x0 = Float8Tensor.to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) + x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) x.data = x0 assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 96acb699ad..507fd3f350 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -11,6 +11,7 @@ from torch import nn from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.utils import is_bf16_compatible @@ -446,7 +447,7 @@ def test_bf16_model_weight_cast(self): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 - with fp8_model_init(enabled=True): + with fp8_model_init(enabled=True, recipe=DelayedScaling()): model = MultiheadAttention( hidden_size=1024, num_attention_heads=16, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e2f712cce8..97d48e2aa3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4,7 +4,9 @@ from __future__ import annotations +from collections.abc import Iterable import math +from typing import Optional import pytest import torch @@ -12,7 +14,6 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor @@ -21,11 +22,14 @@ ForwardLinearBiasActivation, ForwardLinearBiasAdd, ) +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -36,6 +40,38 @@ _devices: list[torch.device] = 
[torch.device("cpu"), torch.device("cuda")] +def maybe_skip_quantization( + quantization: Optional[str], + *, + dims: Optional[Iterable[int] | int] = None, + device: Optional[torch.device | str] = None, +) -> None: + + # Don't skip if there is no quantization + if quantization is None: + return + + # Check if quantization scheme is supported + if quantization == "fp8" and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + if dims is not None: + if not isinstance(dims, Iterable): + dims = (dims,) + if quantization == "fp8": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + elif quantization == "mxfp8": + if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: + pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + + # Check if device is supported + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") + + def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """Estimated numerical error for a datatype @@ -89,7 +125,12 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test, with_transpose_cache=True) + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) @@ -98,6 +139,21 @@ def make_reference_and_test_tensors( return ref, test +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + class TestSequential: """Tests for sequential container""" @@ -239,7 +295,7 @@ def test_fp8_scale_update( ) # Construct model - with te.fp8_model_init(): + with te.fp8_model_init(recipe=recipe): model = te_ops.basic.BasicLinear( size, size, @@ -299,35 +355,30 @@ def test_fp8_scale_update( w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin) dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin) - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - w_scale = model.get_fp8_meta("param")[forward_key].scale - x_scale = model.get_fp8_meta("input")[forward_key].scale - dy_scale = model.get_fp8_meta("grad_output")[backward_key].scale + w_scale = model.get_quantizer("forward", 1).scale + x_scale = model.get_quantizer("forward", 0).scale + dy_scale = model.get_quantizer("backward", 0).scale torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref)) torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref)) torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref)) 
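Before the dtype-cast tests below, here is a minimal end-to-end sketch (not part of the patch) of the recipe plumbing that `make_recipe` and the parametrized tests above exercise. It assumes a CUDA device with FP8 support (and MXFP8-capable hardware if `MXFP8BlockScaling` is selected); the variable names are illustrative:

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()  # or recipe.MXFP8BlockScaling() where supported

# Quantized weights are requested at construction time via the recipe...
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    linear = te_ops.Linear(32, 32, bias=False, device="cuda", dtype=torch.bfloat16)

# ...and quantized compute is enabled per forward pass with the same recipe.
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)
y.backward(torch.ones_like(y))

Dimensions of 32 keep both the FP8 (multiple-of-16) and MXFP8 (multiple-of-32) GEMM constraints satisfied, mirroring the size bumps made throughout the tests above.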
@pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_dtype_cast( self, *, - size: int = 16, + size: int = 32, init_dtype: torch.dtype, final_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool, + quantization: Optional[str], ) -> None: """Check dtype cast functions""" # Skip invalid configurations - if fp8_weight: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data dtype = torch.float32 @@ -339,11 +390,11 @@ def test_dtype_cast( (size, size), test_dtype=dtype, test_device=device, - test_is_fp8=fp8_weight, + test_is_fp8=with_quantization, ) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) with torch.no_grad(): op.weight.copy_(w_test) @@ -358,7 +409,7 @@ def test_dtype_cast( op.bfloat16() # Check weights - assert isinstance(op.weight, Float8Tensor) == fp8_weight + assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) @@ -378,29 +429,27 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_pyt_autocast( self, *, - size: int = 16, + size: int = 32, model_dtype: torch.dtype, autocast_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool = False, - fp8_compute: bool, + quantization: Optional[str], + quantized_weights: bool = False, ) -> None: """Test with PyTorch autocast""" device = torch.device(device) # Skip invalid configurations - if fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) # Check forward and backward pass @@ -410,7 +459,7 @@ def test_pyt_autocast( device=device, requires_grad=True, ) - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with torch.autocast(device_type=device.type, dtype=autocast_dtype): y = op(x) y.backward(torch.zeros_like(y)) @@ -419,11 +468,11 @@ def test_pyt_autocast( assert op.weight.grad.dtype == model_dtype # Check forward and backward pass (swapped context order) - if fp8_compute: + if quantized_compute: x.grad = None op.weight.grad = None with torch.autocast(device_type=device.type, dtype=autocast_dtype): - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y = op(x) y.backward(torch.zeros_like(y)) assert y.dtype 
== autocast_dtype @@ -505,19 +554,14 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize( - "memory_format", - (torch.contiguous_format, torch.channels_last), - ) @pytest.mark.parametrize("fp8", (False, True)) def test_reshape( self, *, shapes: tuple[Iterable[int], Iterable[int]], dtype: torch.dtype, - device: torch.device, - memory_format: torch.memory_format, + device: torch.device = "cuda", + memory_format: torch.memory_format = torch.contiguous_format, fp8: bool, ) -> None: in_shape, out_shape = shapes @@ -634,19 +678,23 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) - def test_cast_float8( + def test_quantize( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda", + quantization: str, cast_forward: bool, cast_backward: bool, ) -> None: - """FP8 cast""" + """Quantize""" + + # Skip invalid configurations + maybe_skip_quantization(quantization) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -656,7 +704,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - x_test = x_test.from_float8().requires_grad_() + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, @@ -664,7 +712,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - dy_test = dy_test.from_float8() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -672,16 +720,14 @@ def test_cast_float8( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) + recipe = make_recipe(quantization) with te.fp8_autocast(fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert is_float8_tensor(y_test) == cast_forward - assert is_float8_tensor(x_test.grad) == cast_backward + assert isinstance(y_test, QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -697,12 +743,13 @@ def _test_basic_linear( in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, - fp8_grad_output: bool = False, - fp8_grad_input: bool = False, + quantization: Optional[str] = None, + quantized_compute: bool = False, + quantized_input: bool = False, + quantized_weight: bool = False, + quantized_output: bool = False, + quantized_grad_output: bool = False, + quantized_grad_input: bool = False, accumulate_into_main_grad: bool = False, ) -> None: """Helper function for tests with GEMM""" @@ -713,50 +760,50 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_compute or fp8_input or fp8_weight or fp8_output or fp8_grad_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": 
- pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization == "fp8" and quantized_output and not quantized_compute: pytest.skip("FP8 output is only supported with FP8 GEMMs") - if fp8_grad_input and not fp8_compute: + if quantization == "fp8" and quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + if quantization == "mxfp8" and quantized_output: + pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") + if quantization == "mxfp8" and quantized_grad_input: + pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=(quantized_compute or quantized_input), ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=(quantized_compute or quantized_grad_output), requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -769,14 +816,11 @@ def _test_basic_linear( del w_test op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) forward = te_ops.Sequential( - te_ops.Quantize(forward=fp8_input, backward=fp8_grad_input), + te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input), op, - te_ops.Quantize(forward=fp8_output, backward=fp8_grad_output), - ) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, + te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output), ) - with te.fp8_autocast(enabled=fp8_compute, fp8_recipe=recipe): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -784,10 +828,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute or fp8_output or fp8_grad_input: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute or quantized_output or quantized_grad_input: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -813,10 +855,10 @@ def _test_basic_linear( ) torch.testing.assert_close(dw_test, w_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) - 
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -824,7 +866,7 @@ def test_basic_linear( weight_shape: tuple[int, int], in_shape: Iterable[int], dtype: torch.dtype, - fp8_compute: bool, + quantization: Optional[str], accumulate_into_main_grad: bool, ) -> None: """GEMM""" @@ -832,52 +874,55 @@ def test_basic_linear( weight_shape=weight_shape, in_shape=in_shape, dtype=dtype, - fp8_compute=fp8_compute, + quantization=quantization, + quantized_compute=quantization is not None, accumulate_into_main_grad=accumulate_into_main_grad, ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_input", (False, True)) - def test_basic_linear_fp8( + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_input", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("quantized_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_input", (False, True)) + def test_basic_linear_quantized( self, *, - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, - fp8_output: bool, - fp8_grad_output: bool, - fp8_grad_input: bool, + quantization: str, + quantized_compute: bool, + quantized_input: bool, + quantized_weight: bool, + quantized_output: bool, + quantized_grad_output: bool, + quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" self._test_basic_linear( dtype=torch.bfloat16, - fp8_compute=fp8_compute, - fp8_input=fp8_input, - fp8_weight=fp8_weight, - fp8_output=fp8_output, - fp8_grad_output=fp8_grad_output, - fp8_grad_input=fp8_grad_input, + quantization=quantization, + quantized_compute=quantized_compute, + quantized_input=quantized_input, + quantized_weight=quantized_weight, + quantized_output=quantized_output, + quantized_grad_output=quantized_grad_output, + quantized_grad_input=quantized_grad_input, ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """GEMM + bias""" @@ -887,31 +932,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid 
configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -932,7 +971,8 @@ def test_linear( y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.Linear( in_features, out_features, @@ -946,7 +986,7 @@ def test_linear( op.bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -954,10 +994,8 @@ def test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -970,12 +1008,11 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", ((7, 2), (32,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_layer_norm( self, *, @@ -985,8 +1022,7 @@ def test_layer_norm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -994,18 +1030,13 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = 
make_reference_and_test_tensors( weight_shape, @@ -1047,17 +1078,19 @@ def test_layer_norm( op.bias.copy_(b_test) del w_test del b_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1145,12 +1178,11 @@ def test_layer_norm_autocast( torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype)) torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype)) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", ((19,), (64,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_rmsnorm( self, *, @@ -1160,8 +1192,7 @@ def test_rmsnorm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -1169,18 +1200,13 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = make_reference_and_test_tensors( weight_shape, @@ -1214,17 +1240,19 @@ def test_rmsnorm( with torch.no_grad(): op.weight.copy_(w_test) del w_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1363,10 +1391,9 @@ def test_make_extra_output( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) - @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_activation( self, *, @@ -1374,8 +1401,7 @@ def test_activation( out_shape: Iterable[int], dtype: torch.dtype, device: 
torch.device = "cuda", - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Activation functions""" @@ -1385,19 +1411,19 @@ def test_activation( in_shape[-1] *= 2 # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, + test_is_fp8=quantized_compute, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, @@ -1425,6 +1451,7 @@ def test_activation( y_ref.backward(dy_ref) # Implementation with fusible operation + recipe = make_recipe(quantization) make_op = dict( gelu=te_ops.GELU, relu=te_ops.ReLU, @@ -1434,16 +1461,18 @@ def test_activation( )[activation] forward = te_ops.Sequential( make_op(), - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) + if activation == "relu": + tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1452,16 +1481,18 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_input", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( self, *, - out_shape: Iterable[int] = (16, 16), + out_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device = "cuda", - fp8_output: bool, - fp8_grad_input: bool, + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, ): # Tensor dimensions @@ -1469,19 +1500,10 @@ def test_swiglu( in_shape[-1] *= 2 # Skip invalid configurations - fp8 = fp8_output or fp8_grad_input - if fp8: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - - # FP8 recipe - fp8_recipe = None - if fp8_grad_input: - fp8_recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1502,18 +1524,19 @@ def test_swiglu( y_ref.backward(dy_ref) # Implementation with fusible operation + recipe = make_recipe(quantization) forward = te_ops.Sequential( - te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.Quantize(forward=False, backward=quantize_backward), 
te_ops.SwiGLU(), - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantize_forward, backward=False), ) - with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1533,12 +1556,11 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("weight_shape", ((32, 48), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (4, 2, 10, -1))) + @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, *, @@ -1547,9 +1569,8 @@ def test_forward_linear_bias_activation( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """Forward GEMM + bias + activation""" @@ -1559,18 +1580,9 @@ def test_forward_linear_bias_activation( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output" @@ -1581,13 +1593,16 @@ def test_forward_linear_bias_activation( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1608,7 +1623,8 @@ def test_forward_linear_bias_activation( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1624,7 +1640,7 @@ def test_forward_linear_bias_activation( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -1637,12 +1653,8 @@ def 
test_forward_linear_bias_activation( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1657,19 +1669,17 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_forward_linear_bias_add( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Forward GEMM + bias + add""" @@ -1679,21 +1689,10 @@ def test_forward_linear_bias_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1701,13 +1700,16 @@ def test_forward_linear_bias_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x1_test, QuantizedTensor): + with torch.no_grad(): + x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1720,7 +1722,6 @@ def test_forward_linear_bias_add( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_output, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1734,7 +1735,8 @@ def test_forward_linear_bias_add( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1751,7 +1753,7 @@ def test_forward_linear_bias_add( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -1764,12 
+1766,8 @@ def test_forward_linear_bias_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1785,18 +1783,16 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_backward_linear_add( self, *, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Backward dgrad GEMM + add""" @@ -1806,21 +1802,10 @@ def test_backward_linear_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1828,13 +1813,16 @@ def test_backward_linear_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -1855,7 +1843,8 @@ def test_backward_linear_add( (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.MakeExtraOutput(), te_ops.Linear( @@ -1869,7 +1858,7 @@ def test_backward_linear_add( with torch.no_grad(): model[1].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y1_test, y2_test = model(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -1882,12 +1871,8 @@ def test_backward_linear_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = 
dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[1].weight._fp8_dtype - if is_float8_tensor(model[1].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e9b6303933..2401f3ca95 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -13,7 +13,11 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + fp8_autocast, + fp8_model_init, +) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -35,13 +39,16 @@ Fp8Unpadding, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.common import recipe import transformer_engine_torch as tex -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -90,6 +97,11 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq mask_types = ["causal", "no_mask"] +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), +] + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -450,7 +462,8 @@ def __init__( self.fc2 = nn.Linear(ffn_hidden_size, hidden_size) def forward(self, x): - return self.fc2(self.gelu(self.fc1(self.ln(x)))) + t = self.gelu(self.fc1(self.ln(x))) + return self.fc2(t) class TorchGPT(nn.Module): @@ -480,7 +493,9 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): +def _test_e2e_selective_recompute( + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False +): reset_rng_states() FP8GlobalStateManager.reset() @@ -488,7 +503,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -515,7 +530,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -536,18 +551,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, 
fp8_model_params=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] outputs = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=False ) outputs_recompute = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=True ) # Check that results match @@ -556,6 +574,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par tols["atol"] = 1e-4 if fp8 or fp8_model_params: tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): torch.testing.assert_close( test, @@ -566,7 +585,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par def _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True ): reset_rng_states() FP8GlobalStateManager.reset() @@ -575,7 +594,7 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -603,7 +622,7 @@ def _test_e2e_full_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: te_out = te_checkpoint( block, @@ -641,11 +660,16 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant): +def test_gpt_full_activation_recompute( + dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant +): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] @@ -654,10 +678,24 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" outputs, names = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant + bs, + dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=False, + use_reentrant=use_reentrant, ) outputs_recompute, _ = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant + bs, + 
dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=True, + use_reentrant=use_reentrant, ) if not use_reentrant: @@ -741,7 +779,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) reset_rng_states() for p in block.parameters(): @@ -1267,9 +1305,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere torch.half: 2e-3, torch.bfloat16: 2e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) if model == "small": atol = { @@ -1335,8 +1378,14 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): torch.bfloat16: 5e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } + # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) # Check gradients, only for small model rtol = { @@ -1351,7 +1400,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): reset_rng_states() if fp8: FP8GlobalStateManager.reset() @@ -1365,16 +1414,22 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False inp_hidden_states.retain_grad() if num_gemms > 1: - m = config.seq_len // 16 + split_size = 1 + if fp8: + if recipe.delayed(): + split_size = 16 + if recipe.mxfp8(): + split_size = 128 + m = config.seq_len // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 + m_splits = m_splits * split_size assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms else: m_splits = torch.tensor([config.seq_len]) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) @@ -1401,18 +1456,23 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + 
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1442,9 +1502,11 @@ def test_grouped_linear_accuracy( sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8) outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, fp8 + sequential_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) # Shoule be bit-wise match @@ -1453,7 +1515,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) -def test_grouped_linear_accuracy_parallel_mode(parallel_mode): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1461,12 +1524,14 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, parallel_mode=parallel_mode, ) -def test_grouped_linear_accuracy_single_gemm(): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1474,11 +1539,12 @@ def test_grouped_linear_accuracy_single_gemm(): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, ) -def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): """Padding tensor shapes to multiples of 16.""" @@ -1546,7 +1612,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: @@ -1575,18 +1641,23 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -1597,7 +1668,7 @@ def test_padding_grouped_linear_accuracy( fp8=fp8, ).eval() - with 
fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1619,10 +1690,10 @@ def test_padding_grouped_linear_accuracy( ) outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, fp8 + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) # Shoule be bit-wise match @@ -1734,7 +1805,7 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(grads, graphed_grads, 1e-3) -def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): reset_rng_states() FP8GlobalStateManager.reset() @@ -1742,7 +1813,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_model_params): + with fp8_model_init(enabled=fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -1769,7 +1840,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=True): + with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -1785,14 +1856,17 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -def test_gpt_fp8_parameters(dtype, bs, model): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_gpt_fp8_parameters(dtype, bs, model, recipe): if not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] - outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) - outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe) # Check that results match tols = dict(rtol=0.125, atol=0.0675) @@ -2073,23 +2147,24 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): out_ref = [o.clone() for o in out] for i in range(z): - gemm( + general_gemm( A[i], B[i], - dtype, get_workspace(), + dtype, grad=grad, accumulate=accumulate, layout=layout, out=out_ref[i], ) - grouped_gemm( + general_grouped_gemm( A, - B, - out, + list(B), + list(out), dtype, get_multi_stream_cublas_workspace(), + m_splits=[k] * n, # TODO, not sure grad=grad, accumulate=accumulate, layout=layout, @@ -2124,64 +2199,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): out_ref = [o.clone() for o in out] # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda") - scale_inv = 1 / scale - amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda") + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - A_fp8 = [ - 
torch.ops.tex_ts.cast_to_fp8_ts( - A[i], - scale, - amax, - scale_inv, - i, # fp8 meta tensor index + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - B_fp8 = [ - torch.ops.tex_ts.cast_to_fp8_ts( - B[i], - scale, - amax, - scale_inv, - z + i, # fp8 meta tensor index - fp8_dtype, + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - fp8_grouped_gemm( - A_fp8, - [scale_inv], - 0, # A_offset - tex.DType.kFloat8E4M3, - B_fp8, - scale_inv, - z, # B_offset - fp8_dtype, - out, - dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate, - ) + A_fp8 = [] + B_fp8 = [] + + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) # baseline for i in range(z): - fp8_gemm( + general_gemm( A_fp8[i], - scale_inv, - i, - tex.DType.kFloat8E4M3, B_fp8[i], - scale_inv, - z + i, - fp8_dtype, - dtype, get_workspace(), + dtype, out=out_ref[i], accumulate=accumulate, ) + general_grouped_gemm( + A_fp8, + B_fp8, + out, + dtype, + get_multi_stream_cublas_workspace(), + m_splits=[k] * m_splits, + accumulate=accumulate, + ) # should be bit-wise match for o, o_ref in zip(out, out_ref): diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py deleted file mode 100644 index 46e888462a..0000000000 --- a/tests/pytorch/test_onnx_export.py +++ /dev/null @@ -1,1562 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -This file contains tests for exporting TransformerEngine models to ONNX. - -The purpose of these tests is validation that TE models are converted to their correct ONNX -representation. Toward this end, each test captures the output of a TE module forward pass, -converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and -validate the output against TE's output. - -Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented -using custom ORT operations. - -To run many repetitive tests use pytest-loop: - $ python3 -m pip install pytest-loop - $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm - -For reproducability use: torch.manual_seed(0) -""" - -import os -import tempfile -import pytest -import warnings -import numpy as np -import onnxruntime as ort -import torch -from torch import nn as nn -from typing import Optional, Union, Tuple, List -import transformer_engine.pytorch as te -from transformer_engine.common import recipe -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) -from transformer_engine.pytorch.module.base import get_workspace -import transformer_engine.pytorch.cpp_extensions as texcpp -import transformer_engine.pytorch.softmax as softmax_defs -from transformer_engine.pytorch.utils import get_default_init_method -from transformer_engine.pytorch.export import is_in_onnx_export_mode -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager - -# Global test configuration knobs. - -# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance). 
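# A minimal sketch of the quantizer-based casting pattern that the updated grouped-GEMM
# test above uses in place of the old torch.ops.tex_ts.cast_to_fp8_ts calls: a
# Float8Quantizer owns its scale/amax state and is simply called on a high-precision
# tensor. Assumes TE 2.x, transformer_engine_torch importable, and an FP8-capable CUDA GPU.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

scale = torch.ones(1, dtype=torch.float32, device="cuda").squeeze()
amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
quantizer = Float8Quantizer(scale, amax, tex.DType.kFloat8E4M3)

x = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
x_fp8 = quantizer(x)        # cast to FP8 (E4M3), carrying its own scale/amax state
x_hp = x_fp8.dequantize()   # back to high precision, as the tests do for reference tensors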
-SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0"))) - -if SAVE_TEST_IO: - from polygraphy.json import save_json - from polygraphy.comparator import RunResults - -# The directory where generated ONNX test models are stored. -NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR") -NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join( - tempfile.gettempdir(), "./gen_onnx_models" -) - - -# The directory where this file is stored. -TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) - -# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14. -TRILU_OPSET = 14 -# Opset used in the ONNX files generated by the tests. -OPSET = 17 -assert OPSET >= TRILU_OPSET - -# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). -ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so") - -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - -supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] - -all_normalizations = ["LayerNorm", "RMSNorm"] - - -@pytest.fixture() -def seed_default_rng(): - """Reseed the PRNG for test reproducibility""" - torch.manual_seed(1234) - - -@pytest.fixture() -def set_max_seq_len(max_seq_len=128): - """Set the maximum sequence length that can be used for attention masking""" - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - -def create_fp8_recipe(): - return recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) - - -def do_export( - model: torch.nn.Module, - inp: torch.Tensor, - fname: str, - use_fp8: bool = True, - opset: int = OPSET, - input_names: List[str] = None, - output_names: List[str] = None, - dynamic_axes: List[str] = None, -): - """Export to ONNX""" - fp8_recipe = create_fp8_recipe() - input_names = input_names or ["input"] - output_names = output_names or ["output"] - - with torch.inference_mode(), te.fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") - - model.cuda().eval() - os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,) - assert len(inps) == len(input_names) - inds_to_del = [i for i in range(len(inps)) if inps[i] is None] - input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del] - - with te.onnx_export(True): - torch.onnx.export( - model, - inps, - fname, - verbose=True, - dynamic_axes=dynamic_axes, - opset_version=opset, - input_names=input_names, - output_names=output_names, - do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, - ) - - -def to_numpy(tensor): - if isinstance(tensor, torch.Tensor): - if tensor.dtype == torch.bfloat16: - tensor = tensor.type(torch.float32) - tensor = tensor.detach().cpu().numpy() - return tensor - - -def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): - """Initialize the FP8 quantization scales in module""" - NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. 
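# A condensed sketch of the export-and-validate flow that the removed ONNX tests implement
# (eager TE forward, torch.onnx.export under te.onnx_export(), ONNX Runtime execution,
# np.isclose comparison). File name and tolerance are illustrative; FP8 custom ORT op
# registration (libcustom_ort_ops.so) is omitted.
import numpy as np
import onnxruntime as ort
import torch
import transformer_engine.pytorch as te

model = te.LayerNorm(32, 1e-6, params_dtype=torch.float32).cuda().eval()
inp = torch.randn(64, 32, device="cuda")

with torch.inference_mode():
    te_out = model(inp)
    with te.onnx_export(True):
        torch.onnx.export(
            model, (inp,), "te_layernorm.onnx",
            input_names=["input"], output_names=["output"],
            opset_version=17, do_constant_folding=True,
        )

sess = ort.InferenceSession(
    "te_layernorm.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
(ort_out,) = sess.run(None, {"input": inp.cpu().numpy()})
assert np.isclose(ort_out, te_out.cpu().numpy(), atol=1e-3).all()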
- nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.init_fp8_metadata(num_gemms) - module.fp8_meta["scaling_fwd"].scale = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale - ) - module.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale - ) - - -def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): - """Transformer Engine forward propagation.""" - fp8_recipe = create_fp8_recipe() - with torch.inference_mode(), te.fp8_autocast( - enabled=is_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) - if not isinstance(te_outputs, tuple): - te_outputs = (te_outputs,) - return te_outputs - - -def compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname -): - """Compare ORT and TE outputs.""" - assert len(onnx_outputs) == len(te_outputs) - # Compare ORT and PyTorch outputs. - for onnx_output, te_output in zip(onnx_outputs, te_outputs): - # np.isclose: abs(a - b) <= (atol + rtol * abs(b)) - te_output = to_numpy(te_output) - onnx_output = to_numpy(onnx_output) - ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol) - mismatches = ac.nonzero() - mismatched_ids = [loc for loc in zip(*mismatches)] - if mismatched_ids: - # Log some information in case of error. - print("*" * 100) - nb_errors = len(mismatched_ids) - nb_vals = min(nb_errors, max_errors_printed) - print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})") - print(f"Showing first {nb_vals} errors (ONNX -- TE):") - abs_err = np.abs(onnx_output - te_output) - errors = abs_err[mismatches] - for loc in mismatched_ids[:nb_vals]: - ref = te_output[loc] - print( - f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >" - f" {atol + rtol * abs(ref)}" - ) - print(f"Max error: {np.max(errors)}") - if nb_errors > allow_cnt_errors: - raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") - - -def serialize_inputs_outputs( - fname: str, - inputs: Union[Tuple[torch.Tensor], torch.Tensor], - te_outputs: List[torch.Tensor], - input_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, -): - if not SAVE_TEST_IO: - return - - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - input_names = input_names or ["input"] - output_names = output_names or ["output"] - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - named_inputs = zip(input_names, inputs) - input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] - json_fname = fname[: -len(".onnx")] + "_inputs.json" - save_json(input_data, json_fname, description="custom input data") - - json_fname = fname[: -len(".onnx")] + "_output.json" - named_outputs = zip(output_names, te_outputs) - output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} - custom_outputs = RunResults() - custom_outputs.add([output_data], runner_name="custom_runner") - custom_outputs.save(json_fname) - - -def validate_result( - fname: str, - inps: Union[Tuple[torch.Tensor], torch.Tensor], - model: torch.nn.Module, - atol: float = 1.0e-8, # np.isclose default atol - rtol: float = 1.0e-5, # np.isclose default rtol - max_errors_printed: int = 10, - is_fp8: bool = False, - allow_cnt_errors: int = 0, - input_names: List[str] = None, - output_names: List[str] = None, - te_outputs: List[torch.Tensor] = None, -): - """Compare the 
outputs of a Transformer Engine (TE) module vs the outputs of its ONNX - representation using ONNX Runtime (ORT) and ensure they are close. - - The purpose of the output comparison is to validate that TE models are converted to - their correct ONNX representation by testing that TE and ORT outputs match within some - small threshold (allowing for finite precision errors). - - Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring, - a very small number (0-3) of outliers. This is fine to do because these outliers are due to - small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX - representation (the tests assume both ORT or TE kernels are correct). - - Argument `te_outputs` can be used to provide pre-computed TE outputs. - """ - - def create_ort_session(fname: str, is_fp8: bool): - def load_custom_ops(session_opts: ort.SessionOptions): - """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension.""" - if not os.path.exists(ORT_CUSTOM_OPS_LIB): - raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}") - session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB) - print("registered custom FP8 Q/DQ ops!") - - """Create an ONNX Runtime session for validation.""" - kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]} - if is_fp8: - sess_options = ort.SessionOptions() - load_custom_ops(sess_options) - kwargs["sess_options"] = sess_options - - s = ort.InferenceSession(fname, **kwargs) - return s - - def create_ort_input_dict(session, inputs): - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - input_names = [x.name for x in session.get_inputs()] - inps = [to_numpy(x) for x in inputs if x is not None] - inp_dict = dict(zip(input_names, inps)) - return inp_dict - - input_names = input_names or ["input"] - output_names = output_names or ["output"] - - # Run ORT session and TE model. - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - if not te_outputs: - te_outputs = te_infer(model, inps, is_fp8) - ort_s = create_ort_session(fname, is_fp8) - input_feed = create_ort_input_dict(ort_s, inps) - onnx_outputs = ort_s.run(None, input_feed=input_feed) - compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname - ) - - -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - -def dtype2str(dtype: torch.dtype, fake_bf16_io=False): - if fake_bf16_io: - assert dtype == torch.bfloat16 - return "_fake_bf16" - return { - torch.float32: "_fp32", - torch.float16: "_fp16", - torch.bfloat16: "_bf16", - }[dtype] - - -def as_te_type(dtype: torch.dtype): - return { - torch.float32: tex.DType.kFloat32, - torch.float16: tex.DType.kFloat16, - torch.bfloat16: tex.DType.kBFloat16, - }[dtype] - - -def get_attn_mask_str(use_mask, attn_mask_type): - # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names. 
- if attn_mask_type is None: - return "_mask" if use_mask else "_no-mask" - attn_mask_str = "_arbitrary-no-mask" - attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str - attn_mask_str = ( - "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str - ) - return attn_mask_str - - -class FP8GemmModule(nn.Module): - def __init__(self, precision, use_bias, gelu, scale_factors, hidden_size, out_features): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales, nb_weight_scales = 1, out_features - act_scale_factor, weight_scale_factor = scale_factors - self.meta_inp = create_meta(act_scale_factor, nb_inp_scales) - self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret, _ = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - -""" -Tests cases begin here. -""" - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [1, 224]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-7], - [torch.float16, 1e-7], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_cast_ops( - seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_QDQ(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type) - - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). 
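# The fusible-ops and numerics tests added earlier in this patch pick a scaling recipe from
# a string via make_recipe(quantization). The helper itself is defined elsewhere in the
# test suite; this is a hypothetical reconstruction, assuming None means no recipe, "fp8"
# means delayed scaling, and "mxfp8" means MXFP8 block scaling.
from typing import Optional

import transformer_engine.common.recipe as te_recipe


def make_recipe(quantization: Optional[str]):
    """Hypothetical reconstruction of the recipe helper used by the tests above."""
    if quantization is None:
        return None
    if quantization == "fp8":
        return te_recipe.DelayedScaling()
    if quantization == "mxfp8":
        return te_recipe.MXFP8BlockScaling()
    raise ValueError(f"Unsupported quantization scheme: {quantization}")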
- in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_QDQ(fake_bf16_io) - - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [448]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-5], - [torch.float16, 1e-5], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_Gelu(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type) - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). - in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_Gelu(fake_bf16_io) - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, - inp, - model, - rtol=0, - atol=atol, - is_fp8=True, - allow_cnt_errors=2, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize( - "scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize( - "precision, use_fp8, use_bias, use_gelu", - [ - (torch.float32, False, False, False), - (torch.float16, False, False, False), - (torch.bfloat16, False, False, False), - (torch.float32, False, True, False), - (torch.float16, False, True, False), - (torch.bfloat16, False, True, False), - (torch.float32, False, True, True), - (torch.float16, False, True, True), - (torch.bfloat16, False, True, True), - # For FP8 GEMM GeLU is not used. 
- (torch.float32, True, False, False), - (torch.float16, True, False, False), - (torch.bfloat16, True, False, False), - # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) - (torch.float16, True, True, False), - (torch.bfloat16, True, True, False), - ], -) -def test_export_gemm( - seed_default_rng, - precision, # Precision of inputs, weights, output and bias - use_fp8, - use_bias, - use_gelu, - scale_factors, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class Test_GEMM(nn.Module): - def __init__(self, precision, use_bias=False, gelu=False): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - bias_size = out_features - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - def forward(self, inp, weight): - outp_type = self.precision - - # note: due to logic in lines 104:116 and L129 in cpp_extensions.py - # it appears either bias OR gelu can be activated, not both - ret, _, _ = gemm( - weight, - inp, - outp_type, - get_workspace(), - # test bias - bias=self.bias, - use_bias=self.use_bias, - # test gelu - gelu=self.gelu, - gelu_input=self.gelu_input, - grad=False, # only True for backward pass - accumulate=False, - ) - return ret - - # If gelu is applied then bias must be added, as defined by TE kernel. - if use_gelu: - assert use_bias - # Set dimensions (these are arbitrary). - out_features = 128 - hidden_size = 256 - in_features = 64 - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - weight = torch.randn(out_features, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - gelu_str = "_gelu" if use_gelu else "" - high_prec_str = dtype2str(precision) - fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - if use_fp8: - model = FP8GemmModule( - precision, use_bias, use_gelu, scale_factors, hidden_size, out_features - ) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - is_fp8=True, - input_names=input_names, - te_outputs=te_outputs, - ) - else: - model = Test_GEMM(precision, use_bias, use_gelu) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_layernorm( - seed_default_rng, - use_fp8: bool, - 
scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - inp_shape = [64, 32] - - class Test_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.LayerNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.bias = torch.zeros( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.layernorm_fwd_fp8_inf( - inp, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm() - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_rmsnorm( - seed_default_rng, - use_fp8: bool, - scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
- inp_shape = [64, 32] - - class Test_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.RMSNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.rmsnorm_fwd_fp8_inf( - inp, - self.weight, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm() - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [1]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) - (torch.bfloat16, True), - ], -) -def test_export_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - precision: torch.dtype, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
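# The new fusible-ops tests earlier in this patch consolidate their skip logic into
# maybe_skip_quantization(quantization, dims=..., device=...). A hypothetical
# reconstruction based on the checks it replaces (FP8 availability, CUDA device, dims
# divisible by 16); the divisibility requirement for MXFP8 is an assumption.
from typing import Iterable, Optional

import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()


def maybe_skip_quantization(
    quantization: Optional[str],
    *,
    dims: Optional[Iterable[int]] = None,
    device: Optional[torch.device] = None,
) -> None:
    """Hypothetical reconstruction of the skip helper used by the tests above."""
    if quantization is None:
        return
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if device is not None and torch.device(device).type != "cuda":
        pytest.skip("Quantization is only supported on CUDA devices")
    if dims is not None:
        # FP8 GEMMs need dims divisible by 16; assume 32 for MXFP8 block scaling.
        divisor = 32 if quantization == "mxfp8" else 16
        if any(d % divisor != 0 for d in dims if d > 0):
            pytest.skip(f"Quantization requires dims divisible by {divisor}")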
- in_features = 64 - out_features = 256 - hidden_size = 256 - - class Test_Linear(nn.Module): - def __init__(self, in_features, out_features, use_bias, return_bias, precision): - super().__init__() - self.linear = te.Linear( - in_features, - out_features, - bias=use_bias, - return_bias=return_bias, - params_dtype=precision, - ) - - def forward(self, inp): - ret = self.linear(inp) - return ret - - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( - device="cuda" - ) - if use_fp8: - set_layer_scale(model.linear, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - else: - validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs) - - -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
- in_features = 64 - out_features = 256 - hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" - - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormLinear( - hidden_size, - 3 * hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - elif precision != torch.bfloat16: - validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs) - - -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_mlp( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - activation: str, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). 
- in_features = 64 - out_features = 256 - hidden_size = 256 - ffn_hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormMLP( - hidden_size, - ffn_hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=2) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize( - "precision, use_mask, attn_mask_type", - [ - (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - ], -) -def test_export_core_attention( - seed_default_rng, - set_max_seq_len, - precision: torch.dtype, - use_mask: bool, - attn_mask_type: str, -): - # Set dimensions (these are arbitrary). - seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) - qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) - qkv_format = "sbhd" - - query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask"] - attention_mask = None - if use_mask: - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask) - - mask_str = get_attn_mask_str(use_mask, attn_mask_type) - high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" - - model = te.attention.DotProductAttention( - num_attention_heads=num_attention_heads, - kv_channels=kv_channels, - attention_dropout=0.5, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, use_fp8=True) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs - ) - - -test_configs_multihead_attention = [ - # "use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax -] -test_configs_attention_type = [ - # "input_layernorm, attention_type, fuse_qkv_params" - (True, "self", True), - (False, "self", True), - (True, "self", False), - (False, "self", False), - (True, "cross", True), - (False, "cross", True), - (True, "cross", False), - (False, "cross", False), -] - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type -) -def test_export_multihead_attention( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - precision: torch.dtype, - return_layernorm_output: bool, - input_layernorm: bool, - attention_type: str, - fuse_qkv_params: bool, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - hidden_size = 256 - sequence_length = 128 - batch_size = 4 - num_attention_heads = 32 - kv_channels = 8 - attention_dropout = 0.1 - layernorm_epsilon = 1e-5 - init_method = output_layer_init_method = get_default_init_method() - attention_args = ( - hidden_size, - num_attention_heads, - kv_channels, - attention_dropout, - layernorm_epsilon, - init_method, - output_layer_init_method, - ) - - hidden_states_context = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - - encoder_output = None - - if attention_type == "cross": - encoder_output = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - - fp8_str = "_fp8" if use_fp8 else "" - dtype_str = dtype2str(precision) - attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" - fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - input_ln_str = "_input-ln" if input_layernorm else "" - fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" - - model = te.MultiheadAttention( - *attention_args, - attn_mask_type=attn_mask_type, - params_dtype=precision, - return_layernorm_output=return_layernorm_output, - input_layernorm=input_layernorm, - attention_type=attention_type, - fuse_qkv_params=fuse_qkv_params, - return_bias=True, - ).to(device="cuda") - - inp_context = (hidden_states_context, attention_mask, encoder_output) - input_names = ["hidden_states", "attention_mask", "encoder_output"] - output_names = ["attention_output", "attention_bias"] - do_export( - model, - inp_context, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "hidden_states": {0: "seq", 1: "bs"}, - "attention_output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp_context, te_outputs, input_names=input_names, output_names=output_names - ) - if precision in (torch.bfloat16,): - return - - if not use_fp8: - validate_result( - fname, - inp_context, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - te_outputs=te_outputs, - ) - else: - validate_result( - fname, - inp_context, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - te_outputs=te_outputs, - ) - - # In GPT generative phase (inference) the input sequence is smaller than the maximum - # allowed sequence length and we want to test this condition. - # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). 
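The generative-phase check below reuses the ONNX file exported above; that only works because do_export registered a dynamic "seq" axis for the hidden states, so the same graph accepts a shorter sequence at inference time. A minimal sketch of that re-use with onnxruntime follows (the file name fname and the input names come from the surrounding test; the helper name and the CPU provider are illustrative, not the test's own validation path):

import onnxruntime as ort

def run_exported_model(fname, named_tensors):
    # Open the exported graph; the dynamic "seq" axis lets shorter
    # sequences pass the shape checks at session.run time.
    session = ort.InferenceSession(fname, providers=["CPUExecutionProvider"])
    graph_inputs = {node.name for node in session.get_inputs()}
    # Feed only inputs that exist in the graph; optional inputs passed as
    # None (e.g. the attention mask) are dropped at export time.
    feeds = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in named_tensors.items()
        if tensor is not None and name in graph_inputs
    }
    return session.run(None, feeds)

Feeding hidden states with sequence_length - seq_len_offset rows, as the test does below, exercises exactly this shorter-than-maximum path.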
- is_generative_phase = attn_mask_type == "causal" and attention_type == "self" - if is_generative_phase: - seq_len_offset = 8 - hidden_states_generative = torch.randn( - sequence_length - seq_len_offset, - batch_size, - hidden_size, - dtype=precision, - device="cuda", - ) - inp_generative = (hidden_states_generative, attention_mask, encoder_output) - if not use_fp8: - validate_result( - fname, - inp_generative, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - ) - else: - validate_result( - fname, - inp_generative, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - ) - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize( - "output_layernorm", - [ - # True, # TO DO: handle this - False - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fuse_qkv_params", [False, True]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -def test_export_transformer_layer( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - output_layernorm: bool, - precision: torch.dtype, - fuse_qkv_params: bool, - zero_centered_gamma: bool, - activation: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - input_names = ["input", "attention_mask"] - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. 
- probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask) - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - ).to(device="cuda") - do_export(model, inp, fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("use_fp8", [True]) -@pytest.mark.parametrize("ln_scale_factor", [448 * 2]) -@pytest.mark.parametrize( - "gemm_scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -def test_export_gemm_layernorm( - seed_default_rng, - use_fp8: bool, - ln_scale_factor: float, - gemm_scale_factors: Tuple[float, float], - precision: torch.dtype, - zero_centered_gamma: bool, -): - """This is a regression test for testing that all LN inputs have the same type. - - The test sets up GEMM with FP32 output which feeds into an LN that is configured - with FP16 or BF16 weights and bias. 
- """ - out_features = 128 - hidden_size = 128 - in_features = 128 - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class TestFP8_GemmLayernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") - self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(ln_scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - self.gemm = FP8GemmModule( - precision, - use_bias=False, - gelu=False, - scale_factors=gemm_scale_factors, - hidden_size=hidden_size, - out_features=out_features, - ) - - def forward(self, inp, weight): - x = self.gemm(inp, weight) - x = texcpp.layernorm_fwd_fp8_inf( - x, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - x = cast_from_fp8( - x, - self.meta, - self.fp8_tensor, - self.fp8_type, - tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16, - ) - return x - - inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") - weight = torch.randn(out_features, in_features, dtype=precision, device="cuda") - model = TestFP8_GemmLayernorm() - high_prec_str = dtype2str(precision) - fp8_str = f"_fp8" if use_fp8 else "" - fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - validate_result( - fname, - (inp, weight), - model, - atol=5e-2, - is_fp8=use_fp8, - allow_cnt_errors=2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@skip_FP8 -@pytest.mark.parametrize("use_fp8", [True, False]) -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [True]) -def test_export_gpt_generation( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - precision: torch.dtype, - zero_centered_gamma: bool, -): - """Test that the ONNX model can correctly handle inputs with different shapes and that - the attention mask it adjusted on-the-fly to different sequence lengths. 
- """ - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - attention_mask = None - use_mask = True - attn_mask_type = "causal" - fuse_qkv_params = True - output_layernorm = False - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - ).to(device="cuda") - - # "Context phase": use full input sequence length - input_names = ["input"] - output_names = ["output"] - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor,) - do_export( - model, - inp, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "input": {0: "seq", 1: "bs"}, - "output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp, te_outputs, input_names=input_names, output_names=output_names - ) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. 
- sequence_length = 1 if not use_fp8 else 8 - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor, attention_mask) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("enabled", [True, False]) -def test_export_ctx_manager(enabled): - assert is_in_onnx_export_mode() == False - with te.onnx_export(enabled): - assert is_in_onnx_export_mode() == enabled - assert is_in_onnx_export_mode() == False diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index c29c01b433..35c6266a3f 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -15,7 +15,7 @@ ) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer import transformer_engine_torch as tex @@ -246,20 +246,28 @@ def _test_permutation_index_map( unpermute_bwd_input = torch.rand( size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" ) - - permute_fwd_input = Float8Tensor.to_float8( - permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - permute_bwd_input = Float8Tensor.to_float8( - permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - unpermute_bwd_input = Float8Tensor.to_float8( - unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize().to(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize().to(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize().to(torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -333,10 +341,10 @@ def _test_permutation_index_map( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.from_float8(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) - te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + te_permute_output_ = 
te_permute_output.dequantize().to(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize().to(torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize().to(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize().to(torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 646dea552e..dcac5f1500 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -15,6 +15,7 @@ _amax_and_scale_update, get_default_fp8_recipe, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops import transformer_engine_torch as tex @@ -64,17 +65,17 @@ def test_fp8_scale_update_with_linear_module( forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) amax_history_forward = fp8_meta[forward_key].amax_history scale_forward = fp8_meta[forward_key].scale - scale_inv_forward = fp8_meta[forward_key].scale_inv + # scale_inv_forward = fp8_meta[forward_key].scale_inv backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) amax_history_backward = fp8_meta[backward_key].amax_history scale_backward = fp8_meta[backward_key].scale - scale_inv_backward = fp8_meta[backward_key].scale_inv + # scale_inv_backward = fp8_meta[backward_key].scale_inv # Tweak amax history and scaling factors amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5) amax_history_forward[0, :].zero_() scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5) - scale_inv_forward.copy_(torch.reciprocal(scale_forward)) + # scale_inv_forward.copy_(torch.reciprocal(scale_forward)) amax_history_backward[0, :].zero_() # Expected amax history after update @@ -100,11 +101,11 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) update_weight_amax = is_first_microbatch is None or is_first_microbatch - if not update_weight_amax: - ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # if not update_weight_amax: + # ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Perform forward, backward, and optimizer steps to update fp8_meta with te.fp8_autocast(enabled=True, fp8_recipe=recipe): @@ -133,8 +134,8 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Check that scale and scale inverse match expected values # Note: scale and scale inverse are only updated when amax is updated @@ -142,27 +143,15 @@ def test_fp8_scale_update_with_linear_module( scale_forward[0], 
ref_scale_forward[0], ) - torch.testing.assert_close( - scale_inv_forward[0], - ref_scale_inv_forward[0], - ) if update_weight_amax: torch.testing.assert_close( scale_forward[1], ref_scale_forward[1], ) - torch.testing.assert_close( - scale_inv_forward[1], - ref_scale_inv_forward[1], - ) torch.testing.assert_close( scale_backward[0], ref_scale_backward[0], ) - torch.testing.assert_close( - scale_inv_backward[0], - ref_scale_inv_backward[0], - ) @pytest.mark.parametrize("amax_history_len", [31, 1024]) @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) @@ -180,12 +169,23 @@ def test_fp8_scale_update_with_linear_fuser_op( # Construct linear op op = te_ops.BasicLinear(in_shape[-1], in_shape[-1]) - # Get FP8 meta tensors + # FP8 recipe forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=amax_history_len, + amax_compute_algo=amax_compute_algo, + ) + + # Get FP8 meta tensors + with te.fp8_autocast(fp8_recipe=recipe): + x_fp8_meta = op.get_quantizer("forward", 0) + w_fp8_meta = op.get_quantizer("forward", 1) + dy_fp8_meta = op.get_quantizer("backward", 0) # Perform training steps x_history = [] @@ -214,14 +214,6 @@ def test_fp8_scale_update_with_linear_fuser_op( op.weight.fill_(w_history[-1]) # Forward and backward pass - fp8_format = transformer_engine.common.recipe.Format.HYBRID - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - interval=1, - fp8_format=fp8_format, - amax_history_len=amax_history_len, - amax_compute_algo=amax_compute_algo, - ) with te.fp8_autocast(fp8_recipe=recipe): y = op(x) y.backward(dy) @@ -247,7 +239,7 @@ def check_amax_history( ) def check_scale( - fp8_meta: dict, + quantizer: Float8Quantizer, ref_amax_history: Iterable[float], stage: str, ): @@ -272,18 +264,11 @@ def check_scale( # Check values in FP8 meta tensors torch.testing.assert_close( - fp8_meta.scale.item(), + quantizer.scale.item(), ref_scale, ) - torch.testing.assert_close( - fp8_meta.scale_inv.item(), - 1 / ref_scale, - ) # Check that results match expected values - check_amax_history(x_fp8_meta, x_history) - check_amax_history(w_fp8_meta, w_history) - check_amax_history(dy_fp8_meta, dy_history) check_scale(x_fp8_meta, x_history, "forward") check_scale(w_fp8_meta, w_history, "forward") check_scale(dy_fp8_meta, dy_history, "backward") @@ -369,7 +354,6 @@ def setup_fp8_meta(): fp8_meta[forward_key].amax_history.clone().view(-1), [fp8_meta[forward_key].amax_history], [fp8_meta[forward_key].scale], - [fp8_meta[forward_key].scale_inv], recipe.amax_compute_algo, fp8_dtype, recipe.margin, @@ -378,12 +362,8 @@ def setup_fp8_meta(): _amax_and_scale_update( fp8_meta[forward_key].amax_history, fp8_meta[forward_key].scale, - fp8_meta[forward_key].scale_inv, fp8_max, recipe, ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) - torch.testing.assert_close( - fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale) - ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index daf8506593..d3bf34943d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -8,7 
+8,6 @@ import torch import pytest -import io import os from transformer_engine.pytorch.fp8 import ( @@ -34,19 +33,22 @@ ) from transformer_engine.common import recipe import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from test_onnx_export import create_meta +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from test_numerics import reset_rng_states, dtype_tols -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + + +def create_meta(scale_factor: float, size: int = 1): + meta = tex.FP8TensorMeta() + meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") + meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor + meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor + return meta def custom_amax_to_scale( @@ -96,13 +98,9 @@ def is_fp8_supported(self): fp8_recipes = [ None, # Handles non-FP8 case + recipe.MXFP8BlockScaling(), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), - recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.E4M3, - override_linear_precision=(False, False, True), - ), recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.E4M3, @@ -136,7 +134,7 @@ def is_fp8_supported(self): all_boolean = [True, False] batch_sizes_with_zero = [0, 1, 2] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"] +all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -236,6 +234,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp_hidden_states.grad is not None, "Gradient should not be empty" assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type." for name, p in block.named_parameters(): if p.requires_grad: @@ -272,11 +271,14 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci loss.backward() torch.cuda.synchronize() + failed_grads = [] for name, p in block.named_parameters(): if "layer_norm_weight" in name: continue elif "weight" in name and p.requires_grad: - assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated." + if not torch.count_nonzero(p.main_grad) > 0: + failed_grads.append(name) + assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): @@ -411,6 +413,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp.grad is not None, "Gradient should not be empty" assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type." 
for name, p in block.named_parameters(): if p.requires_grad: @@ -445,6 +448,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -474,6 +479,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -504,11 +511,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params): + with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_linear = Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -539,6 +548,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -587,6 +598,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -652,6 +665,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -709,6 +724,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -764,6 +781,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -797,6 +816,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -833,6 +854,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not 
config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -872,6 +895,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -912,6 +937,8 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -962,7 +989,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): inp = torch.reshape(scratchpad[offset:-offset], (N, N)) weight = torch.reshape(scratchpad[offset * 2 :], (N, N)) - _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace()) + _ = general_gemm(A=weight, B=inp, workspace=get_workspace()) torch.cuda.synchronize() @@ -971,35 +998,24 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) def test_sanity_fp8_gemm_with_unalignment(N, datatype): offset = 16 - scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype) + scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype) - fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT + scales = torch.ones(1).cuda().squeeze() + amaxes = torch.ones(1).cuda().squeeze() + dtype = tex.DType.kFloat8E4M3 + fp8_quantizer = Float8Quantizer(scales, amaxes, dtype) - nb_inp_scales, nb_weight_scales = 1, N - scale_factor = 1.0 - meta_inp = create_meta(scale_factor, nb_inp_scales) - meta_weight = create_meta(scale_factor, nb_weight_scales) - inp_type = tex.DType.kFloat8E4M3 - weights_type = tex.DType.kFloat8E4M3 outp_type = datatype - scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type) - inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) - weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) - _, _ = fp8_gemm( + scratchpad_fp8 = fp8_quantizer(scratchpad) + inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N)) + weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N)) + general_gemm( weight_fp8, - meta_weight.scale_inv, - fp8_tensor_weight, - inp_type, inp_fp8, - meta_inp.scale_inv, - fp8_tensor_inp, - weights_type, - outp_type, get_workspace(), + outp_type, bias=None, - use_bias=False, use_split_accumulator=False, ) torch.cuda.synchronize() @@ -1062,13 +1078,15 @@ def get_model(dtype, config): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_enabled): + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, fuse_qkv_params=True, params_dtype=dtype, device="cuda", diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py deleted file mode 100644 index 46ce33becc..0000000000 --- a/tests/pytorch/test_torch_save_load.py +++ /dev/null @@ -1,474 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# -# See LICENSE for license information. - -""" -This file contains tests for saving and loading TransformerEngine torch checkpoints. - -The purpose of this test is to validate the TransformerEngine hooks for saving FP8 metadata -in torch checkpoints, which are called as part of torch.save() and torch.load(). -The test verifies the values of FP8 metadata object after saving and loading a checkpoint -are identical to the original values. -""" - -import io -import tempfile -from typing import Iterable, Union - -import pytest -import torch -import transformer_engine.common -import transformer_engine.pytorch as te -import transformer_engine.pytorch.ops as te_ops -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() - - -def init_meta(size: int = 1): - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - return meta - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("scale_fwd", [224, 112, 66]) -@pytest.mark.parametrize("scale_bwd", [448, 33]) -@pytest.mark.parametrize("history_fwd", [1.23, 4.56]) -@pytest.mark.parametrize("history_bwd", [2.34, 5.67]) -def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd): - - tmp_filename = tempfile.NamedTemporaryFile().name - - precision = torch.float32 - - class Test_TE_Export(TransformerEngineBaseModule): - def __init__(self, precision, use_bias): - super().__init__() - self.use_bias = use_bias - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales = nb_weight_scales = 1 - self.meta_inp = init_meta(nb_inp_scales) - self.meta_weight = init_meta(nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def get_fp8_weights_scratchpad(self, is_first_microbatch): - raise RuntimeError( - "Method get_fp8_weights_scratchpad is dummy and should not be invoked." 
- ) - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - model_in = Test_TE_Export(precision, True) - with te.fp8_autocast(enabled=True): - model_in.init_fp8_metadata() - # scaling fwd - model_in.fp8_meta["scaling_fwd"].scale = ( - torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].amax_history = ( - torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd - ) - # scaling bwd - model_in.fp8_meta["scaling_bwd"].scale = ( - torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].scale_inv = ( - torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].amax_history = ( - torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd - ) - - torch.save(model_in.state_dict(), tmp_filename) - - model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) - model_out.eval() - - # scaling fwd - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].amax_history, - model_out.fp8_meta["scaling_fwd"].amax_history, - ) - # scaling bwd - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].amax_history, - model_out.fp8_meta["scaling_bwd"].amax_history, - ) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("save_fp8_model", [True, False]) -@pytest.mark.parametrize("load_fp8_model", [True, False]) -def test_fp8_model_checkpoint( - save_fp8_model: bool, - load_fp8_model: bool, - dims: Iterable[int] = [32, 32], - dtype: torch.dtype = torch.float32, - device: Union[torch.device, str] = "cuda", -): - - # Construct model - dims = list(dims) - hidden_dim = dims[-1] - with te.fp8_model_init(enabled=save_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - # Keep track of model output - x = torch.randn(dims, dtype=dtype, device=device) - with te.fp8_autocast(): - y_ref = model(x.detach().clone()).detach().clone() - - fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}} - with te.fp8_autocast(), torch.no_grad(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5 - fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal() - 
fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5 - fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal() - fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"]) - fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"]) - fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) - fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) - del fp8_meta_fwd, fp8_meta_bwd - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. - # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. - # It is essential for these values to be equal, so setting scale_inv only in the model metadata is insufficient. - model.weight.data.copy_(model.weight.float().cuda()) - # After copying, the tensor computes the meta scale_inv based on the amax history; we then reset these values. - model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"] - model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"] - - # Keep track of weights and FP8 scaling factors - weight_ref = model.weight.float().detach().clone() - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # Disturb and destroy model - with torch.no_grad(): - model.weight.zero_() - model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234} - del model - - # Construct new model - with te.fp8_model_init(enabled=load_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - - # Make sure new model does not match saved model - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - with pytest.raises(AssertionError): - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - model.init_fp8_metadata() - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - with pytest.raises(AssertionError): - torch.testing.assert_close(y, y_ref, **tols) - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # When save_fp8_model=True, we load a model with weights in high precision, - # which does not include _scale_inv, - # but has the fp8 scaling factor in the meta data. This scenario can occur - # when using te.fp8_autocast(enabled=False, calibrating=True). - # - # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, - # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior - # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, - # to load the fp8 metadata before loading tensors. 
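The ordering issue described in the comment above is general torch.nn.Module behaviour: the stock _load_from_state_dict restores parameters and buffers before handing any extra state back to the module. A minimal sketch of the "metadata first" pattern is shown here with placeholder names; it is not Transformer Engine's actual override, just the generic PyTorch mechanism the comment refers to:

import torch

class MetadataFirstLinear(torch.nn.Linear):
    """Illustrative layer that restores auxiliary metadata before its tensors.

    The attribute name "scale" is a placeholder, not TE's serialization format.
    """

    def get_extra_state(self):
        # Serialized alongside the parameters by torch.save(module.state_dict()).
        return {"scale": getattr(self, "scale", None)}

    def set_extra_state(self, state):
        self.scale = state["scale"]

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Restore the metadata first so anything that rebuilds tensors from it
        # (e.g. quantization scales) sees the checkpointed values rather than
        # the freshly initialized ones.
        extra = state_dict.get(prefix + "_extra_state", None)
        if extra is not None:
            self.set_extra_state(extra)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)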
- # - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) - del model_bytes - - # Check that loaded model matches saved model - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - torch.testing.assert_close(y, y_ref, **tols) - - if load_fp8_model: - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # We need to ensure that the tensor's scale_inv parameter matches its meta data. - # This is crucial to avoid confusion about which value is correct. - meta_index = model.weight._fp8_meta_index - torch.testing.assert_close( - model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item() - ) - - -@pytest.mark.parametrize("fp8", (False, True)) -@pytest.mark.parametrize("save_fp8_model", (False, True)) -@pytest.mark.parametrize("load_fp8_model", (False, True)) -def test_sequential_model( - *, - in_shape: Iterable[int] = (16, 16), - dtype: torch.dtype = torch.float32, - device: torch.device = "cuda", - save_steps: int = 2, - load_steps: int = 2, - fp8: bool, - save_fp8_model: bool, - load_fp8_model: bool, -) -> None: - - # Skip invalid configurations - if fp8 or save_fp8_model or load_fp8_model: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - - # FP8 recipe - margin = 2 - fp8_format = transformer_engine.common.recipe.Format.E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - fp8_format=fp8_format, - amax_history_len=8, - amax_compute_algo="max", - ) - - # Construct model to save to checkpoint - with te.fp8_model_init(enabled=save_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - with torch.no_grad(): - torch.rand(model[0].weight.size(), out=model[0].weight) - torch.rand(model[0].bias.size(), out=model[0].bias) - - # Synthetic data - xs_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - dys_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - - def train_step( - model: te_ops.Sequential, - x: torch.Tensor, - dy: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Helper function to perform training step""" - x = x.detach().clone().requires_grad_() - dy = dy.detach().clone() - with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe): - y = model(x) - y.backward(dy) - with torch.no_grad(): - for param in model.parameters(): - param += 0.125 - return ( - y.detach().clone(), - x.grad.detach().clone(), - model[0].weight.detach().float().clone(), - ) - - # Initial training steps with saved model - ys_ref = [] - dxs_ref = [] - ws_ref = [] - for step in range(save_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - 
ws_ref.append(w) - - # Keep track of FP8 metadata if needed - fp8_meta_ref = dict(input={}, param={}, grad_output={}) - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - m_ref["amax"] = m_model.amax_history.detach().clone() - m_ref["scale"] = m_model.scale.detach().clone() - m_ref["scale_inv"] = m_model.scale_inv.detach().clone() - del m_model, m_ref - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # More training steps with saved model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - ws_ref.append(w) - - # Disturb and destroy model - with torch.no_grad(): - for param in model.parameters(): - param.zero_() - model[0].basic_ops[0]._fp8_metas = None - del model - - # Construct new model to load from checkpoint - with te.fp8_model_init(enabled=load_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - - # Tolerances for numerical checks - tols = {} - if fp8 or save_fp8_model or load_fp8_model: - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - exact_tols = dict(rtol=0, atol=0) - - # Training steps with dummy data - for step in range(save_steps): - y, dx, w = train_step( - model, - torch.zeros_like(xs_ref[step]), - torch.zeros_like(dys_ref[step]), - ) - - # Make sure results don't match saved model - with pytest.raises(AssertionError): - torch.testing.assert_close(y, ys_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(dx, dxs_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(w, ws_ref[step], **tols) - - # Make sure new model's FP8 metadata doesn't match saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) - del model_bytes - - # Check that new model's FP8 metadata matches saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # More training steps with loaded model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - torch.testing.assert_close(y, 
ys_ref[step], **tols) - torch.testing.assert_close(dx, dxs_ref[step], **tols) - torch.testing.assert_close(w, ws_ref[step], **tols) diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index d97d9653e6..8b80364a3d 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -19,19 +19,9 @@ except (ImportError, StopIteration) as e: pass -try: - from . import paddle -except (ImportError, StopIteration) as e: - pass - try: import transformer_engine_jax except ImportError: pass -try: - import transformer_engine_paddle -except ImportError: - pass - __version__ = str(metadata.version("transformer_engine")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3afddcc48d..ed59153954 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -6,13 +6,17 @@ cmake_minimum_required(VERSION 3.21) # Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") endif() # Hide non-necessary symbols in shared object. @@ -78,6 +82,7 @@ list(APPEND transformer_engine_SOURCES util/cuda_runtime.cpp util/rtc.cpp util/system.cpp + swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index ddb786bd3a..708403f911 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -4,111 +4,71 @@ * See LICENSE for license information. ************************************************************************/ +/*! \file activation_template.h + * \brief Activation functions template. 
+ */ + +#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ +#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ + #include #include #include "../common.h" +#include "../util/cast_gated_kernels.cuh" +#include "../util/cast_kernels.cuh" +#include "../util/math.h" #include "../util/vectorized_pointwise.h" namespace transformer_engine { template -void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "act_lu_input"); - CheckOutputTensor(*output, "act_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); +void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } template -void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "dact_lu_input"); - CheckInputTensor(grad, "dact_lu_input_grad"); - CheckOutputTensor(*output, "dact_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); +void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } -template -void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); +template +void gated_act_fn(const NVTETensor 
input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = false; + constexpr NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], - output->data.shape[1], {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } -template -void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); +template +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = true; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), grad.data.shape[0], grad.data.shape[1], - {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } } // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index cb38b351e9..0cf43007a7 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -3,69 +3,58 @@ * * See LICENSE for license information. 
************************************************************************/ + #include "../util/math.h" #include "./activation_template.h" void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_gelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dgelu>(grad, input, output, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dqgelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index 7653991819..a794b7315f 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -10,63 +10,51 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_relu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + 
gated_act_fn>(input, output, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, drelu>(grad, input, output, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_srelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dsrelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 5a0e0ead84..8194964745 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -10,31 +10,25 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_silu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dsilu>(grad, input, output, stream); } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 003ea9588c..d988de6f66 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,8 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#define AS_VECTOR(shape) std::vector(shape.data, shape.data + shape.ndim) + using namespace std::placeholders; namespace transformer_engine { @@ -40,8 +42,9 
@@ bool ubuf_built_with_mpi() { CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm) { + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { if (myrank == 0) { @@ -59,9 +62,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; + if (gemm_priority == 0 && comm_priority == 0) { + transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); + } else { + _gemm_priority = gemm_priority; + _comm_priority = comm_priority; + } for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); } @@ -130,6 +139,73 @@ CommOverlapCore::~CommOverlapCore() { } } +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector &chunk_shape) { + TensorWrapper chunk; + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast(param.data_ptr); + auto param_dtype = static_cast(param.dtype); + auto param_shape = AS_VECTOR(param.shape); + + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData || + param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset data pointer + param_dptr += chunk_offset * typeToSize(param_dtype); + param_shape = chunk_shape; + + if (param_type == NVTETensorParam::kNVTEColumnwiseData && + source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // Columnwise shape for non-block scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && + (param_type == NVTETensorParam::kNVTERowwiseScaleInv || + param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate block scaling offset and size + auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? source.shape().data[0] + : source.columnwise_shape().data[0]; + auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? 
chunk_shape.front() + : chunk_shape.back(); + auto chunk_scale_start = chunk_offset / 32; + auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32; + auto chunk_scale_size = chunk_scale_end - chunk_scale_start; + param_dptr += chunk_scale_start * typeToSize(param_dtype); + param_shape = std::vector{chunk_scale_size}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, + size_t chunk_offset, + const std::vector &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ @@ -138,11 +214,14 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, - num_comm_sm, set_sm_margin, false, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, + atomic_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", @@ -155,7 +234,8 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType if (_ub_comm->myrank == 0) printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); } @@ -168,8 +248,8 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, @@ -196,7 +276,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } else { @@ -221,20 +301,20 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, +void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; // Get GEMM dimensions - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? 
A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -255,9 +335,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens assert(pre_gelu_out.numel() == 0); - auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), @@ -269,11 +348,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, @@ -282,11 +360,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens } } else if (_rs_kernel_type == 2) { if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, @@ -299,7 +376,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens if (_ubuf.element_size() == 1) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, _stream_comm);); } else { @@ -321,34 +398,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? 
A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t input_a_chunk_size = m_chunk * k; size_t output_chunk_size = n * m_chunk; - size_t bias_chunk_size = m_chunk; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); - char *bias_chunk_ptr = reinterpret_cast(bias.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (size_t i = 0; i < _stream_compute.size(); i++) { @@ -358,39 +425,23 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { - auto input_a_chunk = - TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = - TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = - TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = TensorWrapper( - workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * D.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, - D.dtype(), D.amax(), D.scale(), nullptr); - bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, bias.dtype(), - nullptr, nullptr, nullptr); - workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -401,11 +452,10 @@ void 
CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Communication chunk if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, @@ -422,12 +472,11 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, _stream_comm);); + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, @@ -435,20 +484,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } } else { for (int i = 0; i < _num_splits; i++) { - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), - {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, - bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -461,11 +502,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, @@ -473,11 +513,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += 
output_chunk_size * _ubuf.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } } } @@ -499,11 +534,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, - num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -552,8 +589,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + _stream_send.push_back(std::move(stream)); + } + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); } @@ -562,7 +604,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - cudaStreamDestroy(_stream_send); + for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); +} + +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; } /* @@ -570,12 +627,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
*/ -void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -583,8 +638,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Get GEMM dimensions between TN and NN input layouts const size_t m = (transa) ? A.size(0) : A.size(1); - const size_t n = _ubuf.size(0); - const size_t n_chunk = n / _tp_size; + const size_t n_chunk = _ubufs[0].size(0); assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes @@ -594,7 +648,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T void *D_buffer_ptr; int D_chunk_bytes = n_chunk * m * D.element_size(); NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); - auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); // Reset atomic counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -602,13 +657,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. Buffer under send is the input for the current @@ -649,8 +703,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T NVTE_CHECK_CUDA( cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); } @@ -674,11 +728,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
*/ -void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -691,24 +746,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.dptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); } if (_aggregate) { const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + input_chunk_size *= 2; + output_chunk_size *= 2; // Initial 1X input chunk exchange between neighboring peers int send_chunk_id = _tp_id; @@ -717,11 +768,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - _stream_send); + _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; @@ -736,27 +787,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW recv_offset = comm_bytes * recv_chunk_id; // GEMM - char *input_b_chunk_ptr = input_b_ptr + send_offset; auto input_b_chunk = - TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = - (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -766,11 +805,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < num_steps - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, _stream_send); + next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -778,7 +817,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } else { @@ -793,24 +832,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, - D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -820,11 +849,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < _tp_size - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, _stream_send); + _next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -832,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } @@ -842,7 +871,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -851,13 +880,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, - cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t 
stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -876,14 +903,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. - auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), - stream_main); + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, _counter.data(), stream_main); // P2P communication chunk for (int i = 1; i < _tp_size; i++) { @@ -907,10 +930,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); @@ -921,31 +943,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t k = A.size(1); - size_t n = B.size(0); // Get communication and GEMM input chunk sizes - size_t n_chunk = n / _tp_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? 
A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); // Catch up the main stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + for (size_t i = 0; i < _stream_send.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0)); + } NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); @@ -954,36 +978,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW // GEMM and send/recv chunks for (int i = 0; i < _tp_size; i++) { // GEMM chunk + int stream_id = i % _stream_compute.size(); int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - - auto input_b_chunk = TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, - B.dtype(), nullptr, nullptr, B.scale_inv()); - - auto output_chunk = - TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); - char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); + use_split_accumulator, _math_sms, _stream_compute[stream_id]); if (i > 0) { // P2P communication chunk + int prev_stream_id = (i - 1) % _stream_compute.size(); int send_offset = comm_bytes * (i - 1); int recv_offset = comm_bytes * (i - 1 + _tp_size); int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - _stream_send); + _stream_send[prev_stream_id]); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, _stream_recv); } @@ -993,8 +1011,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); 
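  // [Illustrative sketch; not part of the patch] The stream-drain code here
  // fans each auxiliary compute/send stream back into the main stream with a
  // single reusable event (record on the auxiliary stream, then wait on the
  // main stream). Pulled out as a standalone helper with hypothetical names,
  // the idiom is:
  //
  //   #include <vector>
  //   #include <cuda_runtime.h>
  //
  //   inline void join_streams(const std::vector<cudaStream_t> &aux_streams,
  //                            cudaStream_t main_stream, cudaEvent_t event) {
  //     for (cudaStream_t s : aux_streams) {
  //       cudaEventRecord(event, s);                   // capture work queued on s
  //       cudaStreamWaitEvent(main_stream, event, 0);  // block main stream on it
  //     }
  //   }
  //
  // Reusing one event across iterations is valid because cudaStreamWaitEvent
  // waits on the most recent record of the event at the time it is called;
  // later re-records do not affect waits that were already enqueued.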
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -1002,11 +1022,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index b2cd71f76b..735148a811 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -19,6 +19,7 @@ #include #include +#include "common/util/system.h" #include "userbuffers.h" #define MAX_THREADS 1024 diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 01b940f06a..cbeec66958 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -6,27 +6,138 @@ #include +#include + #include "./common.h" #include "./utils.cuh" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" namespace transformer_engine { namespace { __global__ void __launch_bounds__(1) - update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, - float* __restrict__ scale_inv_ptr) { + update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr, + float *__restrict__ scale_inv_ptr) { const float scale = scale_ptr == nullptr ? 
1 : *scale_ptr; reciprocal(scale_inv_ptr, scale); } } // namespace -void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { - if (t->scale_inv.dptr != nullptr) { +void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { + if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { + NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( - reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + reinterpret_cast(t->scale.dptr), + reinterpret_cast(t->scale_inv.dptr)); } } +void checkCuDriverContext(CUstream stream) { + CUcontext ctx; + const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); + switch (driver_status) { + case CUDA_SUCCESS: + break; + + case CUDA_ERROR_INVALID_CONTEXT: + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx); + break; + + default: + const char *desc_NVTE_CHECK_CUDA_DRIVER; + cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER); + NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); + } +} + +CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { + static const std::unordered_map dtypeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + return dtypeMapping.at(dtype); +} + +inline bool isPointerAligned(const void *const ptr, const int alignment) { + const uint64_t ptr_as_uint = reinterpret_cast(ptr); + return ptr_as_uint % alignment == 0; +} + +// Set up parameters to create TMA descriptor. 
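// [Illustrative sketch; not part of the patch] Caller-side view of the helper
// defined below, mapping a row-major [rows, cols] bf16 tensor into 128x64
// tiles for TMA transfers. The concrete tile sizes and the device_ptr/rows/
// cols names are hypothetical; the 16-byte pointer alignment and the globalX
// divisibility requirement are enforced inside create_2D_tensor_map itself.
//
//   CUtensorMap tensor_map{};
//   SimpleTensor input{device_ptr, {rows, cols}, DType::kBFloat16};
//   create_2D_tensor_map(tensor_map, input,
//                        /*globalY=*/rows, /*globalX=*/cols,
//                        /*shmemY=*/128, /*shmemX=*/64,
//                        /*stride_elems=*/cols, /*offset_elems=*/0,
//                        /*type_size=*/sizeof(bf16));
//
// The resulting descriptor is then consumed by kernels that issue TMA
// (cp.async.bulk.tensor) tile copies.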
+void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size) { + // Get a function pointer to the cuTensorMapEncodeTiled driver API + static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() { + void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(driver_ptr); + }(); + // rank is the number of dimensions of the array + constexpr uint32_t rank = 2; + uint64_t size[rank] = {globalX, globalY}; + + // The stride is the number of bytes to traverse from the first element of one row to the next + uint64_t stride[rank - 1] = {stride_elems * type_size}; + + // The boxSize is the size of the shared memory buffer that is used as the + // source/destination of a TMA transfer + uint32_t boxSize[rank] = {shmemX, shmemY}; + + // The distance between elements in units of sizeof(element) + uint32_t elemStride[rank] = {1, 1}; + + const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); + void *dataPtr = + reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + + constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address + NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment), + "Tensor data pointer must be 16B aligned"); + + const int TMA_needed_size = TMA_gmem_alignment / type_size; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, + "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + + // Create the tensor descriptor. + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, // CUtensorMap *tensorMap, + tensorDataType, + rank, // cuuint32_t tensorRank, + dataPtr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + boxSize, // const cuuint32_t *boxDim, + elemStride, // const cuuint32_t *elementStrides, + // Interleave patterns can be used to accelerate loading of values that + // are less than 4 bytes long. + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + + // L2 Promotion can be used to widen the effect of a cache-policy to a wider + // set of L2 cache lines. + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + + // Any element that is outside of bounds will be set to zero by the TMA transfer. 
+ CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + +bool is_supported_by_CC_100() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability >= 100; +} + } // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index d47ce472e5..ca9103532d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ +#include #include #include #include @@ -22,10 +23,29 @@ #include #include "./nvtx.h" +#include "./util/cuda_driver.h" #include "./util/logging.h" namespace transformer_engine { +inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { + NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", + end, " in a vector with ", shape.size(), " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + struct SimpleTensor { void *dptr; std::vector shape; @@ -33,20 +53,114 @@ struct SimpleTensor { SimpleTensor(void *dptr, const std::vector &shape, DType dtype) : dptr(dptr), shape(shape), dtype(dtype) {} + + SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT + : dptr(tensor.data_ptr), + shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), + dtype(static_cast(tensor.dtype)) {} + SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} + + operator NVTEBasicTensor() const { + const NVTEShape shape = {this->shape.data(), this->shape.size()}; + return {dptr, static_cast(dtype), shape}; + } + + int numel() const { + size_t acc = 1; + for (const auto &dim : shape) { + acc *= dim; + } + return acc; + } }; struct Tensor { SimpleTensor data; + SimpleTensor columnwise_data; SimpleTensor amax; SimpleTensor scale; SimpleTensor scale_inv; + SimpleTensor columnwise_scale_inv; + + NVTEScalingMode scaling_mode; Tensor() : data(), + columnwise_data(), amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), - scale_inv(nullptr, {1}, DType::kFloat32) {} + scale_inv(nullptr, {1}, DType::kFloat32), + columnwise_scale_inv(nullptr, {1}, DType::kFloat32), + scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + + int numel() const { + NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, + "Tensor does not hold any data!"); + size_t acc = 1; + if (data.dptr != nullptr) { + for (const auto &dim : data.shape) { + acc *= dim; + } + return acc; + } + // data is empty, use columnwise_data + for (const auto &dim : columnwise_data.shape) { + acc *= dim; + } + return acc; + } + + bool has_data() const noexcept { return data.dptr != nullptr; } + + bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + + DType dtype() const { + if (has_data()) return data.dtype; + if (has_columnwise_data()) return columnwise_data.dtype; + // Fallback, used e.g. in workspace + return data.dtype; + } + + /*! Matrix height after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. 
+ */ + size_t flat_first_dim() const { + if (!has_data() && has_columnwise_data()) { + const auto &data_shape = columnwise_data.shape; + if (data_shape.empty()) return 1; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + return product(data_shape, 1, data_shape.size()); + } else { + return product(data_shape, 0, data_shape.size() - 1); + } + } + const auto &data_shape = data.shape; + if (data_shape.empty()) return 1; + return product(data_shape, 0, data_shape.size() - 1); + } + + /*! Matrix width after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_last_dim() const { + if (!has_data() && has_columnwise_data()) { + const auto &data_shape = columnwise_data.shape; + if (data_shape.empty()) return 1; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + return data_shape.front(); + } else { + return data_shape.back(); + } + } + const auto &data_shape = data.shape; + if (data_shape.empty()) return 1; + return data_shape.back(); + } }; template @@ -62,6 +176,10 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +#if CUDA_VERSION >= 12080 +using fp8e8m0 = __nv_fp8_e8m0; +#endif +using e8m0_t = uint8_t; namespace detail { @@ -80,6 +198,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) +#if CUDA_VERSION >= 12080 +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) +#endif #undef TRANSFORMER_ENGINE_TYPE_NAME } // namespace detail @@ -150,6 +271,10 @@ struct TypeInfo { using type = fp8e5m2; \ { __VA_ARGS__ } \ } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -181,6 +306,25 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -236,15 +380,22 @@ struct TypeInfo { NVTE_ERROR("Invalid type for 16 bit."); \ } -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - size_t ret = 1; - for (const auto &elem : shape) { - ret *= elem; +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) 
\ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Invalid size of the MX scaling factor."); \ + } \ } - return ret; -} + +//////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { int log2_value = 0; @@ -269,13 +420,37 @@ struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; +// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + size_t typeToSize(const DType type); +void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); bool is_fp8_dtype(const DType t); +std::string to_string(const DType type); +std::string to_string(const NVTEScalingMode &type); + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { + return mode != NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + return is_tensor_scaling(mode); +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + /*! \brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated @@ -286,6 +461,20 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); +void checkCuDriverContext(CUstream stream); + +CUtensorMapDataType get_CUtensorMapDataType(DType dtype); + +inline bool isPointerAligned(const void *const ptr, const int alignment); + +// Set up parameters to create TMA descriptor. 
+void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size); + +bool is_supported_by_CC_100(); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 5d3e1d6097..01151a50db 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -93,17 +93,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && - (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && - (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && - (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || - ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && - ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || - (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) && + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; @@ -135,8 +149,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - if ( // architecture + if ( + // 
TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging + // special conditions for blackwell + // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 + !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) && + // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && // sequence length @@ -218,9 +236,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + // TODO(cyang): fix bug for BRCM + cross-attention on sm100 + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700)))) || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700))))) && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 20467af663..36ff5291a8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -227,7 +227,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_attn_scale(attn_scale); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_sliding_window_length(window_size_left + 1); + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } sdpa_options.set_alibi_mask(is_alibi); @@ -457,8 +457,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if (is_ragged && cudnn_runtime_version >= 90600) { @@ -667,7 +665,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_sliding_window_length(window_size_left + 1); + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } if (cudnn_runtime_version >= 90000) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 0044a94b2f..b4424d9bf6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 
fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -1798,36 +1796,33 @@ void fused_attn_fp8_fwd_impl_v1( // sdpa_options.set_bias(bias); // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); @@ -1919,29 +1914,28 @@ void fused_attn_fp8_fwd_impl_v1( {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + variant_pack[bias] = devPtrBias; + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / 
nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1974,8 +1968,6 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -2151,36 +2143,35 @@ void fused_attn_fp8_bwd_impl_v1( // } // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_backward_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_backward_options.set_padding_mask(is_padding) + .set_seq_len_q(seq_q) + .set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_backward_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, @@ -2308,34 +2299,32 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; - // if (is_bias) { - // 
variant_pack[bias] = devPtrBias; - // if ((bias_b == 1) && (bias_h == h)) { - // variant_pack[dBias] = devPtrdBias; - // } else { - // variant_pack[dBias] = nullptr; - // } - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + variant_pack[bias] = devPtrBias; + if ((bias_b == 1) && (bias_h == h)) { + variant_pack[dBias] = devPtrdBias; + } else { + variant_pack[dBias] = nullptr; + } + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ef7cdc0af9..52fa89b914 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -15,6 +15,7 @@ #include "../common.h" #include "../util/logging.h" +#include "common/util/cuda_runtime.h" namespace { @@ -46,6 +47,95 @@ uint32_t _getAlignment(uintptr_t address) { } } +struct GemmParam { + void *A; + void *B; + cublasOperation_t transA; + cublasOperation_t transB; + transformer_engine::DType Atype; + transformer_engine::DType Btype; + void *A_scale_inv; + void *B_scale_inv; + int lda; + int ldb; + + GemmParam(cublasOperation_t transA, cublasOperation_t transB) + : A(nullptr), + B(nullptr), + transA(transA), + transB(transB), + Atype(transformer_engine::DType::kNumTypes), + Btype(transformer_engine::DType::kNumTypes), + A_scale_inv(nullptr), + B_scale_inv(nullptr), + lda(0), + ldb(0) {} +}; + +GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, + const transformer_engine::Tensor &B, const cublasOperation_t transB, + const int k, const int lda, const int ldb) { + using namespace transformer_engine; + NVTE_CHECK(A.scaling_mode == B.scaling_mode, + "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); + NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); + GemmParam ret(transA, transB); + + ret.lda = lda; + ret.ldb = ldb; + + if 
(is_tensor_scaling(A.scaling_mode)) { + ret.A = A.data.dptr; + ret.A_scale_inv = A.scale_inv.dptr; + if (transA == CUBLAS_OP_T) { + ret.Atype = A.data.dtype; + } else { + ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; + if (is_fp8_dtype(ret.Atype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transA + NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } + } + } + ret.B = B.data.dptr; + ret.B_scale_inv = B.scale_inv.dptr; + if (transB == CUBLAS_OP_T) { + ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; + if (is_fp8_dtype(ret.Btype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transA + NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } + } + } else { + ret.Btype = B.data.dtype; + } + } else { + // If not tensor scaling (which includes also high precision types), we need to + // use the proper version of data + // We leave the transA/B values as is, since Blackwell supports transposes + ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; + ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; + ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transB ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + } + return ret; +} + } // namespace namespace transformer_engine { @@ -56,10 +146,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { - void *A = inputA->data.dptr; - void *A_scale_inverse = inputA->scale_inv.dptr; - void *B = inputB->data.dptr; - void *B_scale_inverse = inputB->scale_inv.dptr; + // Return immediately if GEMM is trivial + if (m <= 0 || n <= 0) { + return; + } + NVTE_CHECK(k > 0); + + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -72,15 +165,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, counter = inputCounter->data.dptr; } const bool gelu = pre_gelu_out != nullptr; - const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); - const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype); - const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype); + const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype); + + const cudaDataType_t A_type = get_cuda_dtype(param.Atype); + const cudaDataType_t B_type = get_cuda_dtype(param.Btype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype); - NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); - NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); // check consistency of arguments: @@ -117,17 +211,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k, - transa == CUBLAS_OP_N ? k : m, lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n, - transb == CUBLAS_OP_N ? n : k, ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &transa, sizeof(transa))); + ¶m.transA, sizeof(param.transA))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &transb, sizeof(transb))); + ¶m.transB, sizeof(param.transB))); // Set math SM count if (math_sm_count != 0) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -143,12 +237,53 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int8_t fastAccuMode = (use_split_accumulator) ? 
0 : 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - &A_scale_inverse, sizeof(A_scale_inverse))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &B_scale_inverse, sizeof(B_scale_inverse))); + + // Scaling factors. +#if CUDA_VERSION >= 12080 + cublasLtMatmulMatrixScale_t scaling_mode; +#endif + if ((is_delayed_tensor_scaling(inputA->scaling_mode) && + is_delayed_tensor_scaling(inputB->scaling_mode))) { + void *A_scale_inverse = param.A_scale_inv; + void *B_scale_inverse = param.B_scale_inv; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); +#if CUDA_VERSION >= 12080 + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. + // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. + if (cublasLtGetVersion() <= 120803) { + const int64_t dummy_a_vec_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, + sizeof(dummy_a_vec_stride))); + } +#endif + } else { + NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + + to_string(inputB->scaling_mode) + "."); + } + +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output C = nullptr; @@ -156,8 +291,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); - // For FP8 output, cuBLAS requires C_type to be same as bias_type - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd)); +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif + // For FP8 output, cuBLAS requires C_type to match bias_type and + // be FP16/BF16 + const cudaDataType_t C_type = bias ? 
bias_type : CUDA_R_16BF; + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd)); } else { NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); } @@ -235,8 +376,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); - const auto A_alignment = _getAlignment(reinterpret_cast(A)); - const auto B_alignment = _getAlignment(reinterpret_cast(B)); + const auto A_alignment = _getAlignment(reinterpret_cast(param.A)); + const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -260,8 +401,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, static_cast(&one), /* alpha */ - A, /* A */ - Adesc, B, /* B */ + param.A, /* A */ + Adesc, param.B, /* B */ Bdesc, static_cast(&beta), /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -270,7 +411,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor - if (is_fp8_dtype(outputD->data.dtype)) { + // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. + // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be + // calculated here. + if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) { update_tensor_scale_inv(outputD, stream); } @@ -309,9 +453,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; + const size_t A0 = inputA->flat_first_dim(); + const size_t A1 = inputA->flat_last_dim(); + const size_t B0 = inputB->flat_first_dim(); + const size_t B1 = inputB->flat_last_dim(); + + const int m = transa ? A0 : A1; + const int k = transa ? A1 : A0; + const int n = transb ? B1 : B0; int lda, ldb, ldd; if (transa && !transb) { // TN lda = k; @@ -357,6 +506,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = reinterpret_cast(counter); Tensor *wspace = reinterpret_cast(workspace); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && + is_delayed_tensor_scaling(inputB->scaling_mode), + "Atomic GEMM only supports delayed scaling."); + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; const int n = transb ? 
inputB->data.shape[1] : inputB->data.shape[0]; diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 53a66c25b5..49029ed588 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -19,7 +19,9 @@ extern "C" { /* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ -/*! \brief Compute activation of the input. +/*! \brief Computes activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor for activation. * \param[in,out] output Output tensor. @@ -39,17 +41,59 @@ enum class NVTE_Activation_Type { SREGLU, }; +/*! \brief Computes the GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute activation gradient. +/*! \brief Computes the GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] grad Incoming gradient. * \param[in] input Input tensor for activation. 
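For reference, the elementwise math behind these activation-backward entry points can be sketched as a plain host function. This is a minimal illustrative sketch assuming the tanh-based GeLU approximation; the helper name dgelu_ref is hypothetical, and the library's actual kernels are fused and vectorized and may additionally quantize the result to FP8/MXFP8 as described in the comments above.

#include <cmath>

// Illustrative reference for one element of nvte_dgelu:
// out = grad * d/dx GeLU(x), using the tanh approximation of GeLU.
inline float dgelu_ref(float grad, float x) {
  const float kBeta = 0.7978845608028654f;  // sqrt(2 / pi)
  const float kKappa = 0.044715f;
  const float inner = kBeta * (x + kKappa * x * x * x);
  const float tanh_inner = std::tanh(inner);
  const float dinner = kBeta * (1.0f + 3.0f * kKappa * x * x);
  const float dgelu = 0.5f * (1.0f + tanh_inner) +
                      0.5f * x * (1.0f - tanh_inner * tanh_inner) * dinner;
  return grad * dgelu;  // incoming gradient scaled by the local derivative
}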
@@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation of the input. +/*! \brief Computes the gated GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor of shape [N, H * 2]. * \param[in,out] output Output tensor of shape [N, H]. @@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu */ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input. 
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation gradient. +/*! \brief Computes the gated GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * * \param[in] grad Incoming gradient of shape [N, H]. * \param[in] input Forward input tensor of shape [N, H * 2]. * \param[in,out] output Outgoing gradient of shape [N, H * 2]. @@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation gradient. 
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 88a7dec251..d57975b2f4 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -5,7 +5,7 @@ ************************************************************************/ /*! \file cast.h - * \brief Functions to cast to/from FP8. + * \brief Functions to cast to/from FP8/MXFP8. */ #ifndef TRANSFORMER_ENGINE_CAST_H_ @@ -17,21 +17,200 @@ extern "C" { #endif -/*! \brief Cast tensor to FP8. +/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) + * The implementation is per the microscaling format MXFP8 defined by the OCP specification: + * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * - * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8 tensor. - * \param[in] stream CUDA stream used for the operation. + * Supported modes of scaling (live scaling): + * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * - the scaled output tensor + * - the corresponding scaling factors + * The scaling factors are computed for blocks of the shape [1,32] + * (i.e., each scaling factor spans 32 contiguous elements along rows). + * + * 2) Columwise scaling (along the dim=1) computes one set of the output data. + * The scaling factors are computed for blocks of the shape [32,1] + * (i.e., each scaling factor spans 32 contiguous elements along columns). + * + * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * computes two sets of the output data: both 1) and 2). + * + * The shape of the MX block must be specified in the 'output' argument, + * and can be either [1,32] or [32,1] as no other shapes are currently supported. + * + * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter + * of the output tensor should be set to 0. + */ + +/*! \brief Casts input tensor to FP8/MXFP8. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] noop Noop tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream); + +/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workplace, cudaStream_t stream); + +/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. 
+ * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. 
+ * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Cast tensor from FP8. +/*! \brief Casts input tensor from reduced to higher precision. + * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, + * the block dequantization (MXFP8) of the specified shape of the block will be used. + * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise + * data of the output tensor, regardless of whether the row- or columnwise scaling is used. * - * \param[in] input Input tensor to be cast. - * \param[out] output Output tensor. + * \param[in] input Input FP8/MXFP8 tensor to be cast. + * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index ea3bdcd14e..678ffe9191 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,11 +17,26 @@ extern "C" { #endif +/*! \brief Transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. 
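+ *
+ * Illustrative sketch (assumptions, not documented requirements: `noop` is taken to be a
+ * single-element float tensor in device memory whose non-zero value requests the early exit;
+ * `noop_dev_ptr`, `input`, `output` and `stream` are placeholders):
+ * \code{.cpp}
+ *   TensorWrapper noop(noop_dev_ptr, std::vector<size_t>{1}, DType::kFloat32);
+ *   nvte_cast_transpose_with_noop(input.data(), noop.data(), output.data(), stream);
+ * \endcode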
+ */ +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 8e0d017a0d..293c57526d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -53,6 +53,8 @@ class CommOverlapCore { int _cga_size; int _use_ce; int _ub_reg; + int _gemm_priority; + int _comm_priority; bool _atomic_gemm{false}; bool _is_p2p{false}; @@ -65,10 +67,13 @@ class CommOverlapCore { cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool use_ce, bool atomic_gemm); + int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); virtual ~CommOverlapCore(); @@ -77,25 +82,76 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper 
&workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _rs_overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); virtual ~CommOverlapBase(); @@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, 
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { protected: bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; - + bool _aggregate; int _next_rank; int _prev_rank; int _rank_round_tp; - int _aggregate; int _num_ubuf_chunks; int _self_chunk_id; - std::vector _ubufs; - - cudaStream_t _stream_send; + std::vector _stream_send; cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; public: + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, - int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, - bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); virtual ~CommOverlapP2PBase(); + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; /* ** Split AllGather + GEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
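+** Illustrative call sketch (the wrappers, `ub` and `stream_main` are assumed to already
+** exist; the boolean arguments below are placeholders, not a recommended configuration,
+** and correspond positionally to transa, transb, grad, accumulate, use_split_accumulator):
+**   ub.split_overlap_ag(A, true, B, false, D, bias, pre_gelu_out, workspace,
+**                       false, false, false, B_copy, stream_main);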
*/ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + cudaStream_t stream_main) override; }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index a076a4e89a..b30a6e1338 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -28,16 +28,10 @@ extern "C" { * \param[in] amax_history History of maximum absolute values. * Shape: [history_length, num_scales] * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] - * \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales] - * \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be - * empty, in which case all scale_inv entries are updated. - * Shape: [num_scales] * \param[out] updated_amax_history Updated history of maximum absolute values. * Shape: [history_length, num_scales] * \param[out] updated_scale Updated scaling factor for casting to FP8. * Shape: [num_scales] - * \param[out] updated_scale_inv Updated scaling factor for casting from FP8. - * Shape: [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. @@ -45,9 +39,8 @@ extern "C" { * \param[in] stream CUDA stream. */ void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); /*! 
\brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
@@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
 *  Operations performed include, updating the most recent amax history
 *  with the relevant segment of global reduction buffer if it's not 0,
 *  rotating the amax history based on the rule below, and updating the
- *  scales and scale_invs.
+ *  scales.
 *
 *  The amax history is rotated by -1 (e.g. the first entry shifts to
 *  the last, the last entry shifts to the second to last) and the
@@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
 *                                    Shape: num_tensors x [history_length, num_scales]
 *  \param[in,out] scales             List of scaling factors for casting to FP8.
 *                                    Shape: num_tensors x [num_scales]
- *  \param[in,out] scale_invs         List of scaling factors for casting from FP8.
- *                                    Shape: num_tensors x [num_scales]
 *  \param[in]  amax_compute_algo     Method to reduce amax history. Options are "max" and
 *                                    "most_recent".
 *  \param[in]  fp8_dtype             FP8 datatype.
@@ -79,8 +70,8 @@
 */
 void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
     const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
-    std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
-    const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream);
+    std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
+    float margin, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h
new file mode 100644
index 0000000000..de5a11eb73
--- /dev/null
+++ b/transformer_engine/common/include/transformer_engine/swizzle.h
@@ -0,0 +1,37 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file swizzle.h
+ *  \brief Functions handling the swizzling of scaling factors.
+ */
+
+#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
+#define TRANSFORMER_ENGINE_SWIZZLE_H_
+
+#include "transformer_engine.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+
+/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM
+ *
+ * \param[in]     input   Input tensor with non-swizzled scale_inv.
+ * \param[in,out] output  Output tensor which hosts swizzled scale_inv.
+ * \param[in]     stream  CUDA stream used for the operation.
+ *
+ * Requirements:
+ * - scale_inv is stored in row-major.
+ * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
+ * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
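+ *
+ * Illustrative sketch (placeholder tensors; assumes `mxfp8_in` already carries MXFP8 data
+ * with a row-major scale_inv, `swizzled` is a matching output tensor, and `stream` is a
+ * placeholder):
+ * \code{.cpp}
+ *   nvte_swizzle_scaling_factors(mxfp8_in.data(), swizzled.data(), stream);
+ *   // `swizzled` now holds scale_inv in the interleaved layout expected by the GEMM.
+ * \endcode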
+ */ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_SWIZZLE_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 99b3508362..e393dbffc4 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -30,6 +30,7 @@ enum NVTEDType { kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */ kNVTENumTypes /*!< Number of supported types */ }; @@ -43,6 +44,42 @@ struct NVTEShape { size_t ndim; }; +/*! \struct NVTEBasicTensor + * \brief A basic tensor type used to populate parameters of NVTETensor. + * It does not own the memory it points to. + */ +struct NVTEBasicTensor { + void *data_ptr; + NVTEDType dtype; + NVTEShape shape; +}; + +/*! \enum NVTETensorParam + * \brief Indicates the kind of the tensor parameter to set/get. + */ +enum NVTETensorParam { + kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ + kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ + kNVTEScale = 2, /*!< Scale tensor */ + kNVTEAmax = 3, /*!< Amax tensor */ + kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ + kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTENumTensorParams +}; + +/*! \enum NVTEScalingMode + * \brief Granularity of scaling: + */ +enum NVTEScalingMode { + /*! Single scale per tensor, computed in delayed manner. + Used also for high precision data, without scaling */ + NVTE_DELAYED_TENSOR_SCALING = 0, + /*! Single scale per block of 32 elements consecutive in either + rowwise or columnwise direction */ + NVTE_MXFP8_1D_SCALING = 1, + NVTE_INVALID_SCALING +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -53,21 +90,15 @@ typedef void *NVTETensor; /*! \brief Create a new TE tensor. * - * Create a new TE tensor with a given shape, datatype and data. + * Create a new TE tensor. Before use its parameters need to be set. * TE tensors are just wrappers on top of raw data and do not * own memory. * - * \param[in] dptr Pointer to the tensor data. - * \param[in] shape Shape of the tensor. - * \param[in] dtype Data type of the tensor. - * \param[in] amax_dptr Pointer to the AMAX value. - * \param[in] scale_dptr Pointer to the scale value. - * \param[in] scale_inv_dptr Pointer to the inverse of scale value. + * \param[in] scaling_mode Scaling mode of the tensor. * * \return A new TE tensor. */ -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, - float *amax_dptr, float *scale_dptr, float *scale_inv_dptr); +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); /*! \brief Destroy a TE tensor. * @@ -78,14 +109,22 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a raw pointer to the tensor's rowwise data. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return A raw pointer to tensor's rowwise data. */ void *nvte_tensor_data(const NVTETensor tensor); +/*! 
\brief Get a raw pointer to the tensor's columnwise data. + * + * \param[in] tensor Tensor. + * + * \return A raw pointer to tensor's columnwise data. + */ +void *nvte_tensor_columnwise_data(const NVTETensor tensor); + /*! \brief Get a tensor's data shape. * * \param[in] tensor Tensor. @@ -94,6 +133,14 @@ void *nvte_tensor_data(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); +/*! \brief Get a tensor's data shape. + * + * \param[in] tensor Tensor. + * + * \return A shape of the input tensor. + */ +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor); + /*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. @@ -159,6 +206,46 @@ float *nvte_tensor_scale(const NVTETensor tensor); */ float *nvte_tensor_scale_inv(const NVTETensor tensor); +/*! \brief Get a tensor's scale_inv shape. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor); + +/*! \brief Reset tensor value to zero. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); + +/*! \brief Set a parameter of the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set. + */ +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param); + +/*! \brief Get a value of the parameter of the tensor. + * + * \param[in] tensor Tensor. + * \param[in] param_name The parameter to be set. + */ +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name); + +/*! \brief Get the granularity of scaling of this tensor. + * + * \param[in] tensor Tensor. + * + * \return A struct containing the granularity of tensor's scaling. + */ +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor); + /*! \struct NVTETensorPack \brief Pack of tensors, generally used for auxiliary outputs. */ @@ -201,6 +288,7 @@ enum class DType { kBFloat16 = 5, kFloat8E4M3 = 6, kFloat8E5M2 = 7, + kFloat8E8M0 = 8, kNumTypes }; @@ -220,12 +308,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, - float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr) - : tensor_(nvte_create_tensor(dptr, shape, static_cast(dtype), amax_dptr, - scale_dptr, scale_inv_dptr)) {} + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, + const NVTEShape scale_inv_shape = defaultShape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { + tensor_ = nvte_create_tensor(scaling_mode); + NVTEBasicTensor data = {dptr, static_cast(dtype), shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); + NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); + NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); + NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); + } /*! 
\brief Constructs new TensorWrapper. * @@ -238,19 +337,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ TensorWrapper(void *dptr, const std::vector &shape, const DType dtype, float *amax_dptr = nullptr, float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr) + float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, - scale_inv_dptr) {} + scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + scaling_mode) {} /*! \brief Constructs new empty TensorWrapper. * * Create a new empty TE tensor which holds nothing. */ - TensorWrapper() : TensorWrapper(nullptr, std::vector(), DType::kFloat32) {} + explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_tensor(scaling_mode)) {} /*! \brief TensorWrapper destructor. */ ~TensorWrapper() { nvte_destroy_tensor(tensor_); } @@ -283,6 +386,70 @@ class TensorWrapper { return *this; } + // Parameter setters + template + TensorWrapper &set_parameter(const NVTETensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_tensor_param(&tensor_, param, &data); + return *this; + } + + template + TensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEScale, dptr, type, shape); + } + + template + TensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEAmax, dptr, type, shape); + } + + template + TensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseScaleInv, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); + } + + // Parameter getters + + NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { + return nvte_get_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTERowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEColumnwiseScaleInv); + } + /*! \brief Get an underlying NVTETensor. 
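+ *
+ * Illustrative sketch (hypothetical device pointers and shapes; the 32-element block and the
+ * resulting scale_inv shape are assumptions about the MXFP8 layout, and padding requirements
+ * are omitted): populating an MXFP8 wrapper with the setters above and handing the underlying
+ * NVTETensor to a C API call:
+ * \code{.cpp}
+ *   TensorWrapper out(NVTE_MXFP8_1D_SCALING);
+ *   out.set_rowwise_data(data_ptr, DType::kFloat8E4M3, std::vector<size_t>{N, H});
+ *   out.set_rowwise_scale_inv(sinv_ptr, DType::kFloat8E8M0, std::vector<size_t>{N, H / 32});
+ *   nvte_quantize(input.data(), out.data(), stream);
+ * \endcode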
* * \return NVTETensor held by this TensorWrapper. @@ -298,6 +465,15 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the shape of this TensorWrapper. + * + * \return Shape of this TensorWrapper. + */ + const NVTEShape columnwise_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_columnwise_shape(tensor_); + } + /*! \brief Get the size of this TensorWrapper in the given dimension. * * \param[in] size_t Dimension index. @@ -366,6 +542,15 @@ class TensorWrapper { return nvte_tensor_data(tensor_); } + /*! \brief Get a raw pointer to the tensor's data. + * + * \return A raw pointer to tensor's data. + */ + void *columnwise_dptr() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_columnwise_data(tensor_); + } + /*! \brief Get a pointer to the tensor's amax data. * * \return A pointer to tensor's amax data. @@ -393,7 +578,34 @@ class TensorWrapper { return nvte_tensor_scale_inv(tensor_); } + /*! \brief Get the scale_inv_shape of this TensorWrapper. + * + * \return scale_inv_shape of this TensorWrapper. + */ + const NVTEShape scale_inv_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_scale_inv_shape(tensor_); + } + + /*! \brief Get a scaling mode of the tensor. + * + * \return Scaling mode of the tensor. + */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_tensor_scaling_mode(tensor_); + } + + void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = {&defaultData, 1}; + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; }; diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 781f171cd8..a7db5cba47 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -20,16 +20,16 @@ extern "C" { /*! \brief Cast and transpose the input. * * This function casts the input and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data in `output` is the result of the cast + * - columnwise data in `output` is the transposed result of the cast. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream); +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Transpose the input. * @@ -41,25 +41,24 @@ void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaSt /*! \brief Cast and transpose the input. 
Additionally, reduce the input along the first dimension. * - * This function casts the input and produces 3 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * This function casts the input and produces 2 results: + * - `output` is the result of the cast (rowwise data) and transposed cast (columnwise data) * - `dbias` is the result of the reduction of the input along the first dimension. * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the input along the - * first dimension. Shape: [H]. - * \param[out] workspace Workspace tensor. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[out] dbias Result of the reduction of the input along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream); +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream); /*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension. * @@ -82,102 +81,242 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp /*! \brief Cast and transpose multiple tensors. * - * This function casts each input tensor and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. - * - * \param[in] num_tensors Number of tensors. - * \param[in] input_list List of 2D input tensors. - * \param[in,out] cast_output_list List of casted tensors. Dimensions - * match tensors in input_list. - * \param[in,out] transposed_output_list List of casted and transposed - * tensors. Dimensions are transpose - * of tensors in input_list. - * \param[in] stream CUDA stream used for the operation. + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of casted tensors. Dimensions + * of their rowwise data members match + * tensors in input_list. Dimensions of + * their columnwise data members are + * transposed. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream); + NVTETensor* output_list, cudaStream_t stream); -/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally, - * reduce the result of the SiLU backward along the first dimension. +/*! \brief Compute backward of GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the GeLU backward along the first dimension. 
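+ *
+ * Illustrative sketch (placeholder pointers/shapes; setting the columnwise view through the
+ * TensorWrapper setters is an assumption about typical usage, not a requirement of this API):
+ * \code{.cpp}
+ *   TensorWrapper output(NVTE_DELAYED_TENSOR_SCALING);
+ *   output.set_rowwise_data(row_ptr, DType::kFloat8E4M3, std::vector<size_t>{N, H});
+ *   output.set_columnwise_data(col_ptr, DType::kFloat8E4M3, std::vector<size_t>{H, N});
+ *   // ...scale, amax and scale_inv are set similarly...
+ *   nvte_cast_transpose_dbias_dgelu(grad.data(), act_in.data(), output.data(),
+ *                                   dbias.data(), workspace.data(), stream);
+ * \endcode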
* - * This function produces 3 results: - * - `cast_output` is equal to `cast(dact(input))` - * - `transposed_output` is equal to `transpose(cast(dact(input)))` + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * - `dbias` is equal to `reduce(dact(input), axis=0)` * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] input Input tensor of shape [N, H]. - * \param[in] act_input Tensor used as input to the forward of SiLU operation. + * \param[in] act_input Tensor used as input for the operation of forward activation. * Shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the dSiLU(input) along the + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the * first dimension. Shape: [H]. * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. - - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the SiLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the ReLU backward along the first dimension. 
+ * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Quick GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Quick GeLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Squared ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Squared ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. 
+ * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output. +/*! \brief Computes the gated GeLU activation of the input, additionally casts and transposes + * the output. * * This function produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * * \param[in] input Input tensor of shape [N, H]. - * \param[in] gated_act_input Tensor used as input to the forward of GeGLU operation. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. * Shape [N, H * 2]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H * 2]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. * \param[in] stream CUDA stream used for the operation. - - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. 
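+ *
+ * Illustrative sketch (placeholder tensors; `grad` has shape [N, H] and `gated_in` holds the
+ * forward gated input of shape [N, H * 2], matching the parameters above):
+ * \code{.cpp}
+ *   nvte_dreglu_cast_transpose(grad.data(), gated_in.data(), output.data(), stream);
+ * \endcode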
+*/ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 89e2e9feec..7ef3ac44e7 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -15,6 +15,7 @@ #include #include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" /* @@ -38,13 +39,21 @@ Compute always in FP32 namespace transformer_engine { namespace normalization { -TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, - DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, - bool zero_centered_gamma, bool is_tuned) { +cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { + return training ? 
cudnn_frontend::NormFwdPhase_t::TRAINING + : cudnn_frontend::NormFwdPhase_t::INFERENCE; +} + +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode, bool training) { + // TODO: Add scaling_mode to general_key is needed uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | - (uint32_t(zero_centered_gamma) << 16); + (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | + (uint32_t(mode) << 19) | (uint32_t(training) << 22); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -64,8 +73,8 @@ TeNormalizationPlan::TeNormalizationPlan( kernel_params.fp8_out = is_fp8_dtype(otype); } // TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those - auto key = - get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned); + auto key = get_key(NVTE_Norm_Backend::Te, NormType, NormStage, wtype, itype, otype, ctype, 0, + hidden_size, false, is_tuned); _kernel = KernelRegistry::getKernel(key); this->_build(); @@ -179,13 +188,25 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor DType wtype, DType itype, DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, - const bool zero_centered_gamma) - : _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) { + const bool zero_centered_gamma, + const NVTEScalingMode mode, bool training) + : _fp8_out(is_fp8_dtype(otype)), + _zero_centered(zero_centered_gamma), + _training(training), + _norm_stage(NormStage), + _norm_type(NormType) { static_assert(CUDNN_FRONTEND_VERSION >= 10601, "CUDNN_FRONTEND_VERSION should be at least 1.6.1!"); namespace fe = cudnn_frontend; + if (is_tensor_scaling(mode)) { + _ndim_scale_block = 0; + } else { + NVTE_CHECK(mode == NVTE_MXFP8_1D_SCALING, "Unsupported scaling mode."); + _ndim_scale_block = 1; + } + _scalar_dptr = std::make_unique(typeToSize(wtype)); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); @@ -213,7 +234,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_dim({1, hidden_dim, 1, 1}) .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) .set_data_type(get_cudnn_fe_dtype(wtype))); - if (zero_centered_gamma) { + if (_zero_centered) { _scalar_offset = _graph.tensor(fe::graph::Tensor_attributes() .set_name("one") .set_dim({1, 1, 1, 1}) @@ -230,59 +251,97 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor } // Create graph computation nodes - if (NormStage == NVTE_Norm_Stage::Forward) { + if (_norm_stage == NVTE_Norm_Stage::Forward) { _eps = _graph.tensor(fe::graph::Tensor_attributes() .set_name("epsilon") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ctype)) .set_is_pass_by_value(true)); - if (NormType == NVTE_Norm_Type::LayerNorm) { + if (_norm_type == NVTE_Norm_Type::LayerNorm) { _beta = _graph.tensor(fe::graph::Tensor_attributes() .set_name("bias") .set_dim({1, hidden_dim, 1, 1}) .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) .set_data_type(get_cudnn_fe_dtype(wtype))); auto norm_options = fe::graph::Layernorm_attributes() 
- .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_forward_phase(get_cudnn_forward_phase(_training)) .set_epsilon(_eps) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options); std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]); - _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); - } else if (NormType == NVTE_Norm_Type::RMSNorm) { + if (_training) _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else { auto norm_options = fe::graph::Rmsnorm_attributes() - .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_forward_phase(get_cudnn_forward_phase(_training)) .set_epsilon(_eps) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto ret = _graph.rmsnorm(_x, _gamma, norm_options); std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]); } - _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + if (_training) _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); const auto ZDtype = _fp8_out ? ctype : otype; _z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype)); if (_fp8_out) { - // create a scale node - _z_scale = _graph.tensor(fe::graph::Tensor_attributes() - .set_name("z_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ctype))); - auto z_scale_options = fe::graph::Pointwise_attributes() - .set_mode(fe::PointwiseMode_t::MUL) - .set_compute_data_type(get_cudnn_fe_dtype(ctype)); - _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); - - _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); - - // create an amax reduction node - _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() - .set_mode(fe::ReductionMode_t::AMAX) - .set_compute_data_type(get_cudnn_fe_dtype(ctype))); - _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + if (_ndim_scale_block == 0) { // tensor_scaling + // create a scale node + _z_scale = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("z_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + auto z_scale_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); + + _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + + // create an amax reduction node + _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(get_cudnn_fe_dtype(ctype))); + _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + _one_for_div = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one_for_div") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + auto div_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::DIV) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_scale_inv = _graph.pointwise(_one_for_div, _z_scale, div_options); + _z_scale_inv->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else if (_ndim_scale_block == 1) { // 1d block scaling + auto z_2d = _graph.reshape(_z, fe::graph::Reshape_attributes()); + z_2d->set_dim({batch_dim, hidden_dim}); + + auto mx_quantize_row_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(1) + 
.set_transpose(false); + auto bs_row_ret = _graph.block_scale_quantize(z_2d, mx_quantize_row_opts); + std::tie(_z_mx_row, _sf_row) = std::make_tuple(bs_row_ret[0], bs_row_ret[1]); + _z_mx_row->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_row->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); //TODO + + if (_training) { + auto mx_quantize_col_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(0) + .set_transpose(false); + auto bs_col_ret = _graph.block_scale_quantize(z_2d, mx_quantize_col_opts); + std::tie(_z_mx_col, _sf_col) = std::make_tuple(bs_col_ret[0], bs_col_ret[1]); + _z_mx_col->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_col->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); + } + } else { + NVTE_ERROR("Unsupported scaling mode."); + } } } else { _dz = _graph.tensor(fe::graph::Tensor_attributes() @@ -299,7 +358,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_dim({batch_dim, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ctype))); - if (NormType == NVTE_Norm_Type::LayerNorm) { + if (_norm_type == NVTE_Norm_Type::LayerNorm) { auto norm_options = fe::graph::Layernorm_backward_attributes() .set_saved_mean_and_inv_variance(_mean, _rsigma) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); @@ -341,10 +400,14 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { // Binding data pointers to graph tensors - _variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}}; + _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}}; - // layernorm should have valid mean_dptr and beta_dptr - if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}}); + if (_training) _variant_pack.insert({{_rsigma, rsigma_dptr}}); + + if (_norm_type == NVTE_Norm_Type::LayerNorm) { + _variant_pack.insert({{_beta, beta_dptr}}); + if (_training) _variant_pack.insert({{_mean, mean_dptr}}); + } if (_zero_centered) _variant_pack.insert( @@ -352,16 +415,24 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, else _variant_pack.insert({{_gamma, gamma_dptr}}); - if (_fp8_out) - _variant_pack.insert( - {{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}}); - else + if (_fp8_out && _ndim_scale_block == 0) { + _variant_pack.insert({{_one_for_div, reinterpret_cast(_one_dptr.get())}, + {_z_scale, z->scale.dptr}, + {_z_scale_inv, z->scale_inv.dptr}, + {_amax, z->amax.dptr}, + {_z_fp8, z->data.dptr}}); + } else if (_fp8_out && _ndim_scale_block == 1) { + _variant_pack.insert({{_z_mx_row, z->data.dptr}, {_sf_row, z->scale_inv.dptr}}); + if (_training) + _variant_pack.insert( + {{_z_mx_col, z->columnwise_data.dptr}, {_sf_col, z->columnwise_scale_inv.dptr}}); + } else { _variant_pack.insert({{_z, z->data.dptr}}); + } // Execute the computation NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); - if (_fp8_out) update_tensor_scale_inv(z, stream); } void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, @@ -389,11 +460,12 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType 
wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, - const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) { + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode, const bool training) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); - auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, - zero_centered_gamma, is_tuned); + auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, + hidden_size, zero_centered_gamma, is_tuned, mode, training); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { @@ -404,7 +476,7 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( if (NormBackend == NVTE_Norm_Backend::Cudnn) { plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma); + zero_centered_gamma, mode, training); } else if (NormStage == NVTE_Norm_Stage::Forward) { plan = std::make_unique>( NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index f366ba26db..ea0450f1c2 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -154,9 +154,12 @@ struct TupleHash { } }; -TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, - DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, - bool zero_centered_gamma, bool is_tuned); +// Note: the default mode here should match with the default mode with QTensor +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, + bool training = true); template class TeNormalizationRegistry { @@ -257,7 +260,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, - const bool zero_centered_gamma); + const bool zero_centered_gamma, const NVTEScalingMode mode, + const bool training); std::vector getWorkspaceShape() const override; @@ -273,10 +277,17 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { void _build() override; const bool _zero_centered, _fp8_out; + int _ndim_scale_block; + const NVTE_Norm_Stage _norm_stage; + const NVTE_Norm_Type _norm_type; std::unique_ptr _scalar_dptr; + std::unique_ptr _one_dptr = std::make_unique(1.0f); // FWD std::shared_ptr _x, _gamma_zero, _scalar_offset, _gamma, _beta, - _eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8; + _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8; + // MX FWD + std::shared_ptr _z_mx_row, _z_mx_col, _sf_row, _sf_col; + const bool _training; // BWD std::shared_ptr _dz, _dx, _dgamma, _dbeta; @@ -292,12 +303,11 @@ class NormalizationPlanRegistry { return instance; } - NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend, - NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, - DType wtype, DType 
itype, DType otype, - const size_t batch_size, const size_t hidden_size, - const size_t sm_count, const bool zero_centered_gamma, - const bool is_aligned); + NormalizationPlanBase* getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); private: NormalizationPlanRegistry() {} @@ -356,15 +366,12 @@ struct TypeToDType { static int \ register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \ TeNormalizationRegistry::registerFunction( \ - (get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \ - (TypeToDType::value), (TypeToDType::value), \ - (TypeToDType::value), (TypeToDType::value), 0, HIDDEN_SIZE, \ - 0, IS_TUNED(LAUNCH_TYPE))), \ + (get_key(NVTE_Norm_Backend::Te, NVTE_Norm_Type::NORM_TYPE, \ + NVTE_Norm_Stage::NORM_STAGE, (TypeToDType::value), \ + (TypeToDType::value), (TypeToDType::value), \ + (TypeToDType::value), 0, HIDDEN_SIZE, 0, IS_TUNED(LAUNCH_TYPE))), \ FUNC_NAME) -// For FP8 only -void ComputeScaleInv(void* scale, void* scale_inv); - // Alignment check template bool is_ptr_aligned(const Args*... ptrs) { @@ -375,7 +382,6 @@ bool use_cudnn_norm_fwd(); bool use_cudnn_norm_bwd(); } // namespace normalization - } // namespace transformer_engine #endif diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index a412bae745..dae39d82bf 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include #include @@ -25,6 +26,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(gamma.data.shape == beta.data.shape); NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); @@ -51,7 +57,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - if (use_cudnn_norm_fwd()) { + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; } else { @@ -59,6 +67,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, rsigma->data.dptr); } + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward, gamma.data.dtype, // wtype @@ -66,18 +78,31 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // 
hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); workspace->data.dtype = DType::kByte; return; - } else { - NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, - reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, - workspace->data.dptr, stream); } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); + } + return; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index dd4c8e580d..8519fe1b64 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -13,6 +13,7 @@ #include "../../common.h" #include "../common.h" #include "transformer_engine/normalization.h" +#include "transformer_engine/transpose.h" namespace transformer_engine { @@ -21,6 +22,11 @@ using namespace normalization; void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); @@ -39,17 +45,21 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens CheckOutputTensor(*rsigma, "rsigma"); } - Tensor empty; - NVTE_Norm_Backend norm_backend; bool is_aligned = true; - if (use_cudnn_norm_fwd()) { + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + + if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, gamma.data.dtype, // wtype @@ -57,17 +67,29 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); if 
(workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); workspace->data.dtype = DType::kByte; return; - } else { - NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr, - reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, - workspace->data.dptr, stream); + } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr /*beta*/, nullptr /*mu*/, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); } return; @@ -101,8 +123,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const CheckOutputTensor(*dgamma, "dgamma"); } - Tensor empty; - NVTE_Norm_Backend norm_backend; bool is_aligned = true; if (use_cudnn_norm_bwd()) { @@ -128,8 +148,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const return; } else { NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); - plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream); } return; } diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 2c9944439d..f68edf155c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -39,19 +39,22 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) -class _OverrideLinearPrecision(NamedTuple): +class Recipe: """ - Whether or not the execute the `fprop`, `dgrad`, and `wgrad` - GEMMs in higher precision when using FP8. + Base recipe class. """ - fprop: bool = False - dgrad: bool = False - wgrad: bool = False + def mxfp8(self): + """Whether the given recipe is MXFP8 block scaling.""" + return isinstance(self, MXFP8BlockScaling) + + def delayed(self): + """Whether the given recipe is delayed scaling.""" + return isinstance(self, DelayedScaling) @dataclass() -class DelayedScaling: +class DelayedScaling(Recipe): """ Use the delayed scaling factor strategy. Use scale factor from previous iteration and record amax history of `amax_history_len` steps. @@ -92,9 +95,6 @@ def scaling_factor_compute(amax: Tensor, recipe: DelayedScaling) -> Tensor where `Tensor` is a framework tensor type. - override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False) - Whether or not to execute the `fprop`, `dgrad`, and `wgrad` - GEMMs (respectively) in higher precision when using FP8. 
reduce_amax: bool, default = `True` By default, if `torch.distributed` is initialized, the `amax` value for FP8 tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` @@ -137,7 +137,6 @@ def scaling_factor_compute(amax: Tensor, fp8_format: Format = Format.HYBRID amax_history_len: int = 1024 amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" - override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() scaling_factor_compute_algo: Optional[Callable] = None reduce_amax: bool = True fp8_dpa: bool = False @@ -145,10 +144,6 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert self.override_linear_precision in ( - (False, False, False), - (False, False, True), - ), "Only wgrad GEMM override is currently supported." if self.interval >= 0: warnings.warn( "`interval` argument is deprecated and unused. " @@ -161,7 +156,32 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " - f"wgrad_override={self.override_linear_precision.wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class MXFP8BlockScaling(Recipe): + """ + Use the current scaling factor strategy. + + Parameters + ---------- + margin : int, default = 0 + Margin for the scaling factor computation. + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID + Controls the FP8 data format used during forward and backward + pass. + """ + + margin: int = 0 + fp8_format: Format = Format.E4M3 + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
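
As a quick illustration of how downstream code might use the new `Recipe` base class and its `delayed()`/`mxfp8()` helpers introduced in this hunk, here is a minimal, hypothetical sketch (not part of this patch). It assumes the module is importable as `transformer_engine.common.recipe`, matching this file's path; the class names, fields, and helper methods are the ones defined above.

```python
# Illustrative sketch only: branching on the new Recipe base class.
# Assumes the package is importable as `transformer_engine.common.recipe`
# (matching this file's path); adjust the import to your installation.
from transformer_engine.common.recipe import (
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Recipe,
)


def describe(recipe: Recipe) -> str:
    """Pick a code path based on the recipe type using the new helpers."""
    if recipe.delayed():
        return f"delayed scaling, amax_history_len={recipe.amax_history_len}"
    if recipe.mxfp8():
        return f"MXFP8 block scaling, format={recipe.fp8_format}"
    return "unknown recipe"


print(describe(DelayedScaling(fp8_format=Format.HYBRID)))
print(describe(MXFP8BlockScaling()))  # fp8_format defaults to Format.E4M3
```
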
+ + def __repr__(self) -> str: + return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index b16bad9e6a..658ce054da 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -46,7 +46,6 @@ struct AmaxParam { int num_scale = 0; float* amax_history = nullptr; float* scale = nullptr; - float* scale_inv = nullptr; }; // dummy struct for kernel_bulk's other params @@ -83,10 +82,9 @@ constexpr size_t bsize = 256; * Grid dims: num_scales x 1 x 1 */ __global__ void __launch_bounds__(bsize) - kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, - const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr, - float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length, - size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) { + kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr, + float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { const size_t tid = threadIdx.x; const size_t bid = blockIdx.x; @@ -135,7 +133,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Update scale float scale; @@ -152,15 +150,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } updated_scale_ptr[bid] = scale; - - // Update scale inverse - float scale_inv; - if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { - scale_inv = 1 / scale; - } else { - scale_inv = scale_inv_ptr[bid]; - } - updated_scale_inv_ptr[bid] = scale_inv; } } @@ -232,7 +221,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Computing the scaling factor requires consideration of the following scenarios: // 1. 
amax == 0: @@ -259,7 +248,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } p.param[bid].scale[count] = scale; - p.param[bid].scale_inv[count] = 1 / scale; } } } @@ -268,23 +256,12 @@ __global__ void __launch_bounds__(bsize) } // namespace -void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv, - const Tensor& scale_inv_mask, Tensor* updated_amax_history_, - Tensor* updated_scale_, Tensor* updated_scale_inv_, +void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, + Tensor* updated_amax_history_, Tensor* updated_scale_, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { auto& updated_amax_history = *updated_amax_history_; auto& updated_scale = *updated_scale_; - auto& updated_scale_inv = *updated_scale_inv_; - - // Number of elements in tensor - auto numel = [](const Tensor& tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; // Check tensors NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(), @@ -293,18 +270,9 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons const size_t num_scales = amax_history.data.shape[1]; NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ", dtype_name(amax_history.data.dtype), "."); - NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale), "."); + NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ", + scale.numel(), "."); NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), "."); - if (scale_inv_mask.data.dptr != nullptr) { - NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale_inv), "."); - NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); - NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv_mask), "."); - NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", - dtype_name(scale_inv_mask.data.dtype), "."); - } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", updated_amax_history.data.shape.size(), " dims."); NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ", @@ -313,14 +281,10 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons "but found ", updated_amax_history.data.shape[1]); NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_amax_history.data.dtype), "."); - NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale), "."); + NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ", + "but found ", updated_scale.numel(), "."); NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_scale.data.dtype), "."); - NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale_inv), "."); - NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ", - dtype_name(updated_scale_inv.data.dtype), "."); // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; @@ -340,11 +304,8 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons 
const size_t grid_size = num_scales; amax_and_scale_update_impl::kernel<<>>( static_cast(amax_history.data.dptr), static_cast(scale.data.dptr), - static_cast(scale_inv.data.dptr), - static_cast(scale_inv_mask.data.dptr), static_cast(updated_amax_history.data.dptr), - static_cast(updated_scale.data.dptr), - static_cast(updated_scale_inv.data.dptr), amax_history_length, num_scales, + static_cast(updated_scale.data.dptr), amax_history_length, num_scales, amax_compute_algo_, scaled_max); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -352,7 +313,6 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { using namespace transformer_engine; @@ -370,15 +330,6 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, // Expected maximum value after scale is applied const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); - // Number of elements in tensor - auto numel = [](const Tensor* tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor->data.shape) { - acc *= dim; - } - return acc; - }; - // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); size_t num_remaining_tensors = num_tensors; @@ -404,22 +355,21 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, dtype_name(amax_histories[i]->data.dtype), "."); NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ", amax_histories[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ", + NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ", amax_history_length * num_scale, " elements, ", "but found ", - numel(amax_histories[i]), "."); + amax_histories[i]->numel(), "."); NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ", dtype_name(scales[i]->data.dtype), "."); NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", - numel(scales[i]), "."); + NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ", "Found ", + scales[i]->numel(), "."); // amax parameters kernel_num_scales += num_scale; p.param[pi].num_scale = num_scale; p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); p.param[pi].scale = static_cast(scales[i]->data.dptr); - p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); } // Launch CUDA kernel @@ -441,34 +391,30 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, } // namespace transformer_engine void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; 
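
For readers tracking the signature change, the sketch below is a rough Python rendering of the per-tensor scale update this kernel performs. It is illustrative only: the corner cases (zero or non-finite amax, an overflowing scale) are simplified relative to the scenarios the kernel comments enumerate, and with this patch the update produces only `scale`; the separate `scale_inv` outputs and mask have been removed.

```python
# Rough reference sketch (not the CUDA kernel) of the delayed-scaling update.
import math

FLOAT32_MAX = 3.402823466e38


def update_scale(amax_history, prev_scale, fp8_max, margin=0.0, algo="max"):
    # Reduce the amax history with the configured algorithm.
    amax = max(amax_history) if algo == "max" else amax_history[-1]
    scaled_max = fp8_max * 2.0 ** (-margin)  # headroom below the FP8 max
    if math.isfinite(amax) and amax > 0.0:
        scale = scaled_max / amax
    else:
        scale = prev_scale  # keep the previous scale when amax is unusable
    if math.isinf(scale):
        scale = FLOAT32_MAX  # kernel clamps an overflowing scale
    return scale  # note: this update no longer produces scale_inv


# Example: E4M3 max is 448; a history amax of 3.5 gives scale = 128.
print(update_scale([1.0, 3.5, 2.0], prev_scale=1.0, fp8_max=448.0))
```
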
delayed_scaling_recipe::amax_and_scale_update( *reinterpret_cast(amax_history), *reinterpret_cast(scale), - *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), - reinterpret_cast(updated_scale_inv), amax_compute_algo, - static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); using namespace transformer_engine; size_t num_tensors = amax_histories.size(); - std::vector t_amax_histories, t_scales, t_scale_invs; + std::vector t_amax_histories, t_scales; for (size_t i = 0; i < num_tensors; i++) { t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); t_scales.push_back(reinterpret_cast(scales[i])); - t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, - t_scale_invs, amax_compute_algo, static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu new file mode 100644 index 0000000000..a0fffc783c --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -0,0 +1,338 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace { + +constexpr int TB_DIM = 32; +constexpr int NEW_SF_TILE_DIM_K = 16; +constexpr int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr int NEW_SF_TILE_DIM_M_I32 = 32; + +template +__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { + // inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + // out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i * N_SF_PER_TD_PER_TILE + j] = + (((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) | + (((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) | + (((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) | + (((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} + +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // input is in M-major + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (blockIdx.x == gridDim.x - 1) { + k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (blockIdx.y == gridDim.y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32 = reinterpret_cast(input) + + blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + + blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + int32_t* output_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + + (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + extern __shared__ int slm[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // local shuffle + regs_shuffle_with_bit_shifts(regs_vec); + + // store, regs -> shared + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + /* TODO rotate_i */ + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] = + 
reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* output_v4i = reinterpret_cast(output_i32[i]); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j += blockDim.x * blockDim.y) { + output_v4i[j] = slm_v4i[j]; + } + } +} + +template +__device__ inline void regs_shuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // input is in K-major + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (blockIdx.x == gridDim.x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int* input_i32 = reinterpret_cast(input) + + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; + int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + + blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + + extern __shared__ int4 slm_v4i[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // shuffle regs + regs_shuffle(regs_vec); + +// store, regs -> shared +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + /* TODO rotate i */ + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + __align__(16) int4* output_v4i = reinterpret_cast(output_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + output_v4i[i] = slm_v4i[i]; + } +} + +} // namespace + +namespace transformer_engine { + +void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); + } + + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + + auto& scaling_mode = input->scaling_mode; + + // 1D block scaling, row-wise or colum-wise + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int m = + input->has_data() ? 
input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; + const int k = + input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + dim3 block_size(TB_DIM, TB_DIM); + if (input->has_data()) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + if (input->has_columnwise_data()) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + + // 2D block scaling + } else { + NVTE_ERROR("Not implemented for scaling_mode " + 
to_string(input->scaling_mode) + ", trans."); + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + exit(-1); + } +} +} // namespace transformer_engine + +/* + * WIP (Phuong): + * - Opt for bank conflicts + * - Adding swizzle for 2d-block scaling. +*/ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_scaling_factors); + using namespace transformer_engine; + swizzle_scaling_factors(reinterpret_cast(input), reinterpret_cast(output), + stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 11e0e319ed..faf6ec990d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -6,71 +6,196 @@ #include +#include + #include "common.h" namespace transformer_engine { -size_t typeToSize(const transformer_engine::DType type) { +size_t typeToSize(const DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, return TypeInfo::size;); // NOLINT(*) } -bool is_fp8_dtype(const transformer_engine::DType t) { - return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2; +bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } + +std::string to_string(const DType type) { + switch (type) { + case DType::kByte: + return "Byte"; + case DType::kBFloat16: + return "BFloat16"; + case DType::kFloat16: + return "Float16"; + case DType::kFloat32: + return "Float32"; + case DType::kFloat8E4M3: + return "Float8E4M3"; + case DType::kFloat8E5M2: + return "Float8E5M2"; + case DType::kFloat8E8M0: + return "Float8E8M0"; + case DType::kInt32: + return "Int32"; + case DType::kInt64: + return "Int64"; + default: + return concat_strings("Invalid type ", static_cast(type)); + } +} + +std::string to_string(const NVTEScalingMode &mode) { + switch (mode) { + case NVTE_DELAYED_TENSOR_SCALING: + return "Delayed Tensor Scaling"; + case NVTE_MXFP8_1D_SCALING: + return "MXFP8 1D Scaling"; + case NVTE_INVALID_SCALING: + return "Invalid Scaling"; + } + return "Invalid Scaling"; +} + +void CheckNoopTensor(const Tensor &t, const std::string &name) { + if (t.data.dptr != nullptr) { + NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), + "."); + NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name, + " noop. 
Expected kFloat32."); + } +} + +void CheckScaleTensorShape(const Tensor &t, const std::string &name) { + NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); + if (is_tensor_scaling(t.scaling_mode)) { + // per-tensor scaling + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected (1), got ", + t.columnwise_scale_inv.shape, ")"); + } + } else { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + // Need (4, 128) alignment even for e8 scaling factor + auto block_alignment = std::vector{128ul, 4ul}; + size_t expected_x, expected_y, alignment; + + if (t.has_data()) { + alignment = block_alignment[0]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; + alignment = block_alignment[1]; + expected_y = + DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + alignment = block_alignment[1]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + alignment = block_alignment[0]; + expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } + } + } } void CheckInputTensor(const Tensor &t, const std::string &name) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 input " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); + NVTE_CHECK(t.amax.dptr == 
nullptr, "Amax is not supported for non-FP8 input ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); } - NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); + + CheckScaleTensorShape(t, name); } void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { - // FP8 output needs to have scale, amax and scale_inv - NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor."); - NVTE_CHECK(t.amax.dtype == DType::kFloat32); - NVTE_CHECK(t.amax.shape == std::vector{1}); - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); - NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale.dtype == DType::kFloat32); - NVTE_CHECK(t.scale.shape == std::vector{1}); + // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor"); + NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", + to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); + NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, + " (expected 1 entry, got shape=", t.amax.shape, ")"); + } + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 output " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); } if (!allow_empty) { - NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!"); } + + 
CheckScaleTensorShape(t, name); } } // namespace transformer_engine -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax, - float *scale, float *scale_inv) { +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { transformer_engine::Tensor *ret = new transformer_engine::Tensor; - ret->data.dptr = dptr; - ret->data.shape = std::vector(shape.data, shape.data + shape.ndim); - ret->data.dtype = static_cast(dtype); - ret->amax.dptr = amax; - ret->scale.dptr = scale; - ret->scale_inv.dptr = scale_inv; + ret->scaling_mode = scaling_mode; return ret; } @@ -81,30 +206,65 @@ void nvte_destroy_tensor(NVTETensor tensor) { } NVTEDType nvte_tensor_type(const NVTETensor tensor) { + if (tensor == nullptr) return kNVTEFloat32; return static_cast( - reinterpret_cast(tensor)->data.dtype); + reinterpret_cast(tensor)->dtype()); } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; const auto &t = *reinterpret_cast(tensor); NVTEShape ret; + + // FP8 tensor keeps shape in rowwise data + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + + // Get shape based on what data is available + if (t.has_data()) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + if (t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; + } + + // Tensor has no data ret.data = t.data.shape.data(); ret.ndim = t.data.shape.size(); return ret; } +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; +} + size_t nvte_tensor_ndim(const NVTETensor tensor) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); return t.data.shape.size(); } size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); return t.data.shape[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { + if (tensor == nullptr) return 0; const auto &t = *reinterpret_cast(tensor); size_t numel = 1; for (auto size : t.data.shape) { @@ -114,16 +274,25 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { } size_t nvte_tensor_element_size(const NVTETensor tensor) { + if (tensor == nullptr) return sizeof(float); const auto &t = *reinterpret_cast(tensor); return transformer_engine::typeToSize(t.data.dtype); } void *nvte_tensor_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); return t.data.dptr; } +void *nvte_tensor_columnwise_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_data.dptr; +} + float *nvte_tensor_amax(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, "Tensor's amax must have Float32 type!"); @@ -131,6 +300,7 @@ float *nvte_tensor_amax(const NVTETensor tensor) { } float *nvte_tensor_scale(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = 
*reinterpret_cast(tensor); NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, "Tensor's scale must have Float32 type!"); @@ -138,12 +308,83 @@ float *nvte_tensor_scale(const NVTETensor tensor) { } float *nvte_tensor_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32, - "Tensor's inverse of scale must have Float32 type!"); return reinterpret_cast(t.scale_inv.dptr); } +void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_scale_inv.dptr; +} + +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.scale_inv.shape.data(); + ret.ndim = t.scale_inv.shape.size(); + return ret; +} + +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param) { + NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); + NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); + auto &t = *reinterpret_cast(*tensor); + switch (param_name) { + case kNVTERowwiseData: + t.data = *param; + break; + case kNVTEColumnwiseData: + t.columnwise_data = *param; + break; + case kNVTEScale: + t.scale = *param; + break; + case kNVTEAmax: + t.amax = *param; + break; + case kNVTERowwiseScaleInv: + t.scale_inv = *param; + break; + case kNVTEColumnwiseScaleInv: + t.columnwise_scale_inv = *param; + break; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { + if (tensor == nullptr) { + return {nullptr, kNVTEFloat32, {nullptr, 0}}; + } + const auto &t = *reinterpret_cast(tensor); + switch (param_name) { + case kNVTERowwiseData: + return t.data; + case kNVTEColumnwiseData: + return t.columnwise_data; + case kNVTEScale: + return t.scale; + case kNVTEAmax: + return t.amax; + case kNVTERowwiseScaleInv: + return t.scale_inv; + case kNVTEColumnwiseScaleInv: + return t.columnwise_scale_inv; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.scaling_mode; +} + void nvte_tensor_pack_create(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); @@ -156,3 +397,18 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) { delete t; } } + +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { + const auto &t = *reinterpret_cast(tensor); + // Zero out tensor data if allocated + if (t.data.dptr != nullptr) { + size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + } + // Set amax to 0 if allocated + if (t.amax.dptr != nullptr) { + float zero = 0.0f; + cudaMemcpyAsync(t.amax.dptr, &zero, sizeof(float), cudaMemcpyHostToDevice, stream); + } + cudaStreamSynchronize(stream); +} diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index b49c61195e..4cdb39b70a 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -10,12 +10,12 @@ #include 
-#include "../common.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" -namespace transformer_engine { +namespace transformer_engine::detail { namespace { @@ -217,159 +217,143 @@ __global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( } // namespace -void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_, - Tensor *transposed_output_, cudaStream_t stream) { - Tensor &cast_output = *cast_output_; - Tensor &transposed_output = *transposed_output_; +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) { + Tensor &output = *output_; - // Check no-op flag - if (noop.data.dptr != nullptr) { - size_t numel = 1; - for (const auto &dim : noop.data.shape) { - numel *= dim; - } - NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); - NVTE_CHECK(noop.data.dtype == DType::kFloat32); - NVTE_CHECK(noop.data.dptr != nullptr); - } - - // Check tensor dims + CheckNoopTensor(noop, "cast_transpose_noop"); CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(cast_output, "cast_output"); - CheckOutputTensor(transposed_output, "transposed_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); - NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); - NVTE_CHECK(transposed_output.data.shape[0] == row_length, - "Wrong dimension of transposed output."); - NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output."); - - // Check tensor pointers - NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); - NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated."); - NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated."); - NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype, + CheckOutputTensor(output, "cast_transpose_output"); + + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output.has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output.has_columnwise_data(), "Output columnwise is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output.data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output.data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); + NVTE_CHECK(output.flat_first_dim() == num_rows && output.flat_last_dim() == row_length, + "Invalid output dimensions (expected ", std::vector{num_rows, row_length}, + ", got ", std::vector{output.flat_first_dim(), output.flat_last_dim()}, ")"); + + // Check that cast and transposed output data matches + NVTE_CHECK(output.data.dtype == output.columnwise_data.dtype, "Cast and transposed output types must match."); - NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr, - "Cast and transposed outputs need to share amax tensor."); - NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, - "Cast and transposed 
outputs need to share scale tensor."); - NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + NVTE_CHECK(output.scale_inv.dptr == output.columnwise_scale_inv.dptr, "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output.data.dtype, OutputType, - constexpr const char *itype_name = TypeInfo::name; - constexpr const char *otype_name = TypeInfo::name; - constexpr size_t itype_size = sizeof(InputType); - constexpr size_t otype_size = sizeof(OutputType); - - // Choose between runtime-compiled or statically-compiled kernel - const bool aligned = - (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); - if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - const size_t sm_count = static_cast(cuda::sm_count()); - auto add_config = [&](size_t load_size, size_t store_size) { - kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, - store_size, sm_count); - }; - add_config(8, 8); - add_config(4, 8); - add_config(8, 4); - add_config(4, 4); - add_config(2, 8); - add_config(8, 2); - add_config(2, 4); - add_config(4, 2); - add_config(2, 2); - add_config(1, 8); - add_config(8, 1); - add_config(1, 4); - add_config(4, 1); - add_config(1, 2); - add_config(2, 1); - add_config(1, 1); - const auto &kernel_config = - *std::min_element(kernel_configs.begin(), kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - const size_t load_size = kernel_config.load_size; - const size_t store_size = kernel_config.store_size; - const size_t num_blocks = kernel_config.num_blocks; - - // Compile NVRTC kernel if needed and launch - auto &rtc_manager = rtc::KernelManager::instance(); - const std::string kernel_label = concat_strings( - "cast_transpose" - ",itype=", - itype_name, ",otype=", otype_name, ",load_size=", load_size, - ",store_size=", store_size); - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_cast_transpose_cu; - code = regex_replace(code, "__ITYPE__", itype_name); - code = regex_replace(code, "__OTYPE__", otype_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); - code = regex_replace(code, "__BLOCK_SIZE__", block_size); - rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, - "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + output.dtype(), OutputType, + if (is_delayed_tensor_scaling(output.scaling_mode)) { + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = + (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, + 
store_size, sm_count); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = + *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose" + ",itype=", + itype_name, ",otype=", otype_name, ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = + (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); } - rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, - num_rows); - } else { // Statically-compiled general kernel - constexpr size_t load_size = 4; - constexpr size_t store_size = 4; - constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; - constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; - const int num_blocks = - (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); - cast_transpose_general_kernel - <<>>( - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), 
- static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, num_rows); + } else { + NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace transformer_engine +} // namespace transformer_engine::detail -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream) { +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; auto noop = Tensor(); - cast_transpose(*reinterpret_cast(input), noop, - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), noop, + reinterpret_cast(output), stream); } -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_with_noop); using namespace transformer_engine; - cast_transpose(*reinterpret_cast(input), *reinterpret_cast(noop), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h new file mode 100644 index 0000000000..ed9bd5f5f7 --- /dev/null +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -0,0 +1,28 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine::detail { + +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream); + +template +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream); + +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, + cudaStream_t stream); + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index ed919c8b94..8347e117ce 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -8,18 +8,19 @@ #include #include -#include +#include +#include #include -#include "../common.h" #include "../util/math.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" namespace transformer_engine { -namespace { +namespace detail { // String with RTC kernel implementation #include "string_code_transpose_rtc_cast_transpose_fusion_cu.h" @@ -177,16 +178,31 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ Tensor *workspace, const int nvec_out) { - const size_t row_length = cast_output.data.shape[1]; - const size_t num_rows = cast_output.data.shape[0]; + const size_t row_length = cast_output.flat_last_dim(); + const size_t num_rows = cast_output.flat_first_dim(); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -248,11 +264,13 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reduce_dbias_num_rows); } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length, const size_t num_rows, const size_t num_tiles) { + static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are 
mutually exclusive"); using IType = typename Param::InputType; using IType2 = typename Param::InputType2; using OType = typename Param::OutputType; @@ -373,6 +391,8 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) if constexpr (IS_DACT) { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * OP(act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + after_dact[j].data.elt[k] = OP(in[current_in ^ 1][j].data.elt[k], {}); } else { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); } @@ -449,78 +469,96 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } static const char *ActTypeToString[] = { - "NoAct", // 0 - "Sigmoid", // 1 - "GeLU", // 2 - "QGeLU", // 3 - "SiLU", // 4 - "ReLU", // 5 - "SReLU" // 6 + "none", // 0 + "sigmoid", // 1 + "dsigmoid", // 2 + "gelu", // 3 + "dgelu", // 4 + "qgelu", // 5 + "dqgelu", // 6 + "silu", // 7 + "dsilu", // 8 + "relu", // 9 + "drelu", // 10 + "srelu", // 11 + "dsrelu" // 12 }; template -int get_dactivation_type() { - if (OP == &sigmoid) { - return 1; - } else if (OP == &dgelu) { - return 2; - } else if (OP == &dqgelu) { - return 3; - } else if (OP == &dsilu) { - return 4; - } else if (OP == &drelu) { - return 5; - } else if (OP == &dsrelu) { - return 6; - } else { - return 0; +constexpr int get_activation_type() { + constexpr decltype(OP) ActivationList[] = { + nullptr, // 0 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 + }; +#pragma unroll + for (int i = 0; i < sizeof(ActivationList) / sizeof(ActivationList[0]); ++i) { + if (OP == ActivationList[i]) { + return i; + } } + return 0; } -template -void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output, - Tensor *transposed_output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (workspace->data.dptr != nullptr) { +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + // Check tensors, unless querying dbias workspace + if (!IS_DBIAS || workspace->data.dptr != nullptr) { CheckInputTensor(input, "cast_transpose_fused_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - if constexpr (IS_DBIAS) CheckOutputTensor(*dbias, "dbias"); - if constexpr (IS_DACT) CheckInputTensor(act_input, "act_input"); + CheckOutputTensor(*output, "output"); + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr && dbias->has_data()); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr && act_input->has_data()); + CheckInputTensor(*act_input, "act_input"); + } } - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output->has_columnwise_data(), "Output 
columnwise data is not allocated"); - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output->data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); + // Check that cast and transposed output data matches + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); } if constexpr (IS_DACT) { - NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match."); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; using Param = CTDBiasDActParam; constexpr int itype_size = sizeof(InputType); @@ -584,8 +622,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * if (!jit_compiled) { num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); } if constexpr (IS_DBIAS) { + // Check workspace size + populate_cast_transpose_dbias_workspace_config(*output, workspace, nvec_out); if (workspace->data.dptr == nullptr) { - populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out); return; } } @@ -631,15 +670,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * Param param; param.input = reinterpret_cast(input.data.dptr); - param.output_c = reinterpret_cast(cast_output->data.dptr); - param.output_t = reinterpret_cast(transposed_output->data.dptr); - param.scale_ptr = reinterpret_cast(transposed_output->scale.dptr); - param.amax = reinterpret_cast(transposed_output->amax.dptr); - param.scale_inv = reinterpret_cast(cast_output->scale_inv.dptr); + param.output_c = reinterpret_cast(output->data.dptr); + param.output_t = reinterpret_cast(output->columnwise_data.dptr); + param.scale_ptr = reinterpret_cast(output->scale.dptr); + param.amax = reinterpret_cast(output->amax.dptr); + param.scale_inv = reinterpret_cast(output->scale_inv.dptr); if constexpr (IS_DBIAS) { param.workspace = reinterpret_cast(workspace->data.dptr); } if constexpr (IS_DACT) { - param.act_input = reinterpret_cast(act_input.data.dptr); + 
param.act_input = reinterpret_cast(act_input->data.dptr); } // Runtime-compiled tuned kernel @@ -648,9 +687,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * constexpr const char *itype2_name = TypeInfo::name; constexpr const char *otype_name = TypeInfo::name; - int dActType = 0; - if constexpr (IS_DACT) { - dActType = get_dactivation_type(); + int actType = 0; + if constexpr (IS_DACT || IS_ACT) { + actType = get_activation_type(); } // Compile NVRTC kernel if needed and launch @@ -660,7 +699,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * ",itype=", itype_name, ",itype2=", itype2_name, ",otype=", otype_name, ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS, - ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]); + ",IS_DACT=", IS_DACT, ",IS_ACT=", IS_ACT, + ",activationType=", ActTypeToString[actType]); if (!rtc_manager.is_compiled(kernel_label)) { std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu; @@ -673,7 +713,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads); code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); code = regex_replace(code, "__IS_DACT__", IS_DACT); - code = regex_replace(code, "__DACTIVATION_TYPE__", dActType); + code = regex_replace(code, "__IS_ACT__", IS_ACT); + code = regex_replace(code, "__ACTIVATION_TYPE__", actType); rtc_manager.compile( kernel_label, "cast_transpose_fusion_kernel_optimized", code, @@ -695,11 +736,11 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); cudaFuncSetAttribute( - cast_transpose_fused_kernel_notaligned, + cast_transpose_fused_kernel_notaligned, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - cast_transpose_fused_kernel_notaligned + cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); } @@ -1101,43 +1142,39 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) template -void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, - Tensor *cast_output, Tensor *transposed_output, +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "dgated_act_cast_transpose_input"); CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); - CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output"); - CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output"); + CheckOutputTensor(*output, "dgated_act_cast_transpose_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + NVTE_CHECK(output->has_data() && output->has_columnwise_data(), + "Both rowwise and columnwise data need to be allocated."); + NVTE_CHECK(output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(output->columnwise_data.shape.size() == 2, "T output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; NVTE_CHECK(gated_act_input.data.shape[0] == 
num_rows, "Wrong dimension of output."); NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + NVTE_CHECK(output->data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(output->data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(output->columnwise_data.shape[0] == row_length * 2, "Wrong dimension of T output."); + NVTE_CHECK(output->columnwise_data.shape[1] == num_rows, "Wrong dimension of T output."); NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr, + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, "C and T outputs need to share scale inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; /* dact fusion kernel uses more registers */ constexpr int desired_load_size_dact = 4; constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType); @@ -1168,11 +1205,11 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); } else { cudaFuncSetAttribute( @@ -1184,194 +1221,193 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace + +// Explicit template instantiation +template void 
cast_transpose_fused( + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +#define NVTE_INSTANTIATE_ACTIVATION(op) \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +NVTE_INSTANTIATE_ACTIVATION(relu); +NVTE_INSTANTIATE_ACTIVATION(srelu); +NVTE_INSTANTIATE_ACTIVATION(gelu); +NVTE_INSTANTIATE_ACTIVATION(qgelu); +NVTE_INSTANTIATE_ACTIVATION(silu); +#undef NVTE_INSTANTIATE_ACTIVATION + +} // namespace detail } // namespace transformer_engine using ComputeType = typename transformer_engine::fp32; -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; constexpr const NVTETensor activation_input = nullptr; - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(activation_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused( + *reinterpret_cast(input), reinterpret_cast(activation_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dgelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(act_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsilu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(silu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(silu_input), + 
reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &drelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(relu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(relu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsrelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(srelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(srelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dqgelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(qgelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(qgelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dgelu; - constexpr auto Activation = &gelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, gelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - 
reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsilu; - constexpr auto Activation = &silu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, silu>( *reinterpret_cast(input), *reinterpret_cast(swiglu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &drelu; - constexpr auto Activation = &relu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, relu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsrelu; - constexpr auto Activation = &srelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, srelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dqgelu; - constexpr auto Activation = &qgelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, qgelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 16894ad4b5..5cf316f45e 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -195,42 +195,44 @@ __global__ void __launch_bounds__(threads_per_block) } // namespace -void multi_cast_transpose(const std::vector input_list, - std::vector cast_output_list, - std::vector transposed_output_list, cudaStream_t stream) { +void multi_cast_transpose(const std::vector input_list, std::vector output_list, + cudaStream_t stream) { // Check that number of tensors is valid - NVTE_CHECK(cast_output_list.size() == input_list.size(), - "Number of input 
and C output tensors must match"); - NVTE_CHECK(transposed_output_list.size() == input_list.size(), - "Number of input and T output tensors must match"); + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); if (input_list.empty()) { return; } // Check that tensor properties are valid DType itype = input_list[0]->data.dtype; - DType otype = cast_output_list[0]->data.dtype; + DType otype = output_list[0]->dtype(); for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { const auto& input = *input_list[tensor_id]; - const auto& cast_output = *cast_output_list[tensor_id]; - const auto& transposed_output = *transposed_output_list[tensor_id]; + const auto& output = *output_list[tensor_id]; CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id)); - CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id)); - CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_cast_transpose_output_" + std::to_string(tensor_id)); + //std::cout << *static_cast(output.data.dptr) << std::endl; + NVTE_CHECK(output.has_data() && output.has_columnwise_data(), + "Both rowwise and columnwise output data needs to be allocated."); NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); - NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match."); - NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "C output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "T output tensor types do not match."); - NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape == input.data.shape, - "C output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1], - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0], - "T output tensor shape does not match input tensor."); + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions, but shape is ", + input.data.shape); + NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape, + "does not match input tensor shape ", input.data.shape); + NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); } // Input matrices are divided into tiles @@ -287,11 +289,11 @@ void multi_cast_transpose(const std::vector input_list, // Add tensor to kernel argument struct const int pos = kernel_args.num_tensors; kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); - kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr; - kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; - 
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; - kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; - kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; + kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr; + kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr; + kernel_args.amax_list[pos] = output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; @@ -327,15 +329,13 @@ void multi_cast_transpose(const std::vector input_list, } // namespace transformer_engine void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream) { + NVTETensor* output_list, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_cast_transpose); using namespace transformer_engine; - std::vector input_list_, cast_output_list_, transposed_output_list_; + std::vector input_list_, output_list_; for (size_t i = 0; i < num_tensors; ++i) { input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); - cast_output_list_.push_back(reinterpret_cast(cast_output_list[i])); - transposed_output_list_.push_back(reinterpret_cast(transposed_output_list[i])); + output_list_.push_back(reinterpret_cast(output_list[i])); } - multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream); + multi_cast_transpose(input_list_, output_list_, stream); } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index 2424247bbe..34359561aa 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -22,7 +22,9 @@ constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; constexpr bool IS_DBIAS = __IS_DBIAS__; constexpr bool IS_DACT = __IS_DACT__; -constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; +constexpr bool IS_ACT = __IS_ACT__; +static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); +constexpr size_t ACT_TYPE = __ACTIVATION_TYPE__; constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); @@ -33,14 +35,20 @@ using OVec = Vec; using Param = CTDBiasDActParam; using OP = CType (*)(const CType, const Empty &); -constexpr OP Activation[] = { +constexpr OP ActivationList[] = { nullptr, // 0 - &dsigmoid, // 1 - &dgelu, // 2 - &dqgelu, // 3 - &dsilu, // 4 - &drelu, // 5 - &dsrelu // 6 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 }; } // namespace @@ -175,7 +183,10 @@ __global__ void __launch_bounds__(BLOCK_SIZE) if constexpr (IS_DACT) { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]) * - Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + ActivationList[ACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + in_cast_fp32[j].data.elt[k] = + ActivationList[ACT_TYPE](in[current_in ^ 1][j].data.elt[k], 
{}); } else { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]); } diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 339748ead0..26740a3837 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -205,17 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); - // Number of elements in tensor - auto numel = [](const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto &dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; - if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), "."); + NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index 39c702dade..fba3710beb 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include "../common.h" #include "../utils.cuh" @@ -376,8 +376,24 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + // Set workspace size + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -426,10 +442,9 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor constexpr int nvec_in = desired_load_size / type_size; constexpr int nvec_out = desired_store_size / type_size; - if (workspace->data.dptr == nullptr) { - populate_transpose_dbias_workspace_config(input, workspace, nvec_out); - return; - } + // Check workspace size + populate_transpose_dbias_workspace_config(input, workspace, nvec_out); + if (workspace->data.dptr == nullptr) { return; } NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index e0c92c22cb..22a50025df 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -4,88 +4,144 
@@ * See LICENSE for license information. ************************************************************************/ +#include +#include +#include #include +#include +#include +#include + #include "../common.h" +#include "../transpose/cast_transpose.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" +#include "cast_kernels.cuh" +#include "dequantize_kernels.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; -namespace transformer_engine { + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -namespace detail { + detail::quantize_helper(input, grad, nullptr, output, + dbias, workspace, stream); +} -struct Empty {}; +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; -__device__ inline fp32 identity(fp32 value, const Empty &) { return value; } + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -struct DequantizeParam { - const fp32 *scale_inv; -}; + detail::quantize_helper(input, grad, noop, output, + dbias, workspace, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; -__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr const NVTETensor activation_input = nullptr; + + detail::quantize_helper( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace detail - -void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision."); - - NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool 
IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace transformer_engine +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_quantize); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); using namespace transformer_engine; - fp8_quantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_dequantize); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - fp8_dequantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + 
detail::dequantize_helper(*reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh new file mode 100644 index 0000000000..e2240ba658 --- /dev/null +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -0,0 +1,1091 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_gated_kernels.cuh + * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" + +namespace transformer_engine { + +template +__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + return DIVUP(static_cast(N), static_cast(M)) * M; +} + +namespace gated_kernels { + +constexpr size_t ALIGNMENT_SIZE = 128; +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t out_gate_mem = buff_size_aligned_out; + constexpr size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // uint64_t *mbar = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
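// One barrier per iteration lets the copy and compute stages overlap: while the
// threads process the buffer of iteration `it`, the bulk copy for `it + 1` is
// already in flight and completes the phase of mbar[it + 1] once its bytes land.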
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + if (next_it < ITERATIONS) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, {}) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After 
syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % 
THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + + OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + OType *out_act_colwise_sh = out_act_rowwise_sh; + OType *out_gate_colwise_sh = out_gate_rowwise_sh; + + if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + out_gate_colwise_sh = + reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + } + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act_rowwise = + reinterpret_cast(&tensor_map_output_act_rowwise); + const uint64_t *TMAP_output_gate_rowwise = + reinterpret_cast(&tensor_map_output_gate_rowwise); + const uint64_t *TMAP_output_act_colwise = + reinterpret_cast(&tensor_map_output_act_colwise); + const uint64_t *TMAP_output_gate_colwise = + reinterpret_cast(&tensor_map_output_gate_colwise); + + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + const bool is_master_thread = (threadIdx.x == 0); + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. 
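// (mbarrier_init goes through the generic proxy; the fence_proxy_async above
// publishes the initialized state to the async proxy used by the TMA copies,
// and the barrier below publishes it to the other threads of the block.)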
+ __syncthreads(); + + int parity = 0; + + // Prefetch data of the first stage + if (is_master_thread) { + // Initiate bulk tensor copy + // Grad + if constexpr (IS_DGATED) { + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), + TMAP_grad_in, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + } + + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), + TMAP_in_act, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[0]); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; + if (next_it < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + if constexpr (IS_DGATED) { + // Grad + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + } + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
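// (The expected-transaction byte count is added to the barrier; its phase
// completes only after every participating thread has arrived and all of the
// announced TMA bytes have been delivered to shared memory.)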
+ ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_it]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; + OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; + OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; + OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; + + // Assuming one iteration covers exactly 32 rows + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + + float after_dact_reg[BUFFER_STAGES_NUM]; + float after_dgate_reg[BUFFER_STAGES_NUM]; + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; + after_dgate_reg[stage] = act_x * grad_elt; + } else { + after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + } + + if constexpr (USE_ROWWISE_SCALING) { + if constexpr (IS_DGATED) { + // dgate + float amax = fabsf(after_dgate_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_gate_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dgate_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + float amax = fabsf(after_dact_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + 
out_act_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dact_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + + if constexpr (USE_COLWISE_SCALING) { + __builtin_assume(thread_Y_mx_block_amax >= 0); + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool row_out_of_bounds = (row_base >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + if constexpr (IS_DGATED) { + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_gate_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dgate_reg[stage]); + } + } + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr 
(!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_act_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } + } // endif USE_COLWISE_SCALING + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_rowwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_rowwise_sh_curr)); + } + } + + // dGeLU + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_colwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_colwise_sh_curr)); + } + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. 
+ if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem); // + mbar_mem; + + cudaFuncSetAttribute( + cast_fp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_fp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + // TODO: Make more general + const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; + const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, + sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, + sizeof(OType)); + } + + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + cols, sizeof(OType)); + } + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + + const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + + cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(input.data.shape[0] == output->data.shape[0], + "Input shape[0] must be equal to output shape[0]."); + NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, + "Input shape[1] must be 2x larger than output shape[1]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], + output->data.shape[1], {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. 
Input shape: ", input.data.shape, + ", output shape: ", output->data.shape, "."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + checkCuDriverContext(stream); + constexpr bool allow_empty = false; + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + if constexpr (IS_DGATED) { + CheckInputTensor(grad, "grad"); + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); + } + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; + + if (is_delayed_tensor_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_fp8_gated(grad, gated_input, output, stream); + } else { + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } + } + } else if (is_mxfp_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_mxfp8_gated(grad, gated_input, output, stream); + } else { + NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", + "by 32, got input of shape ", gated_input.data.shape); + } + } else { + NVTE_ERROR("Not supported scaling mode"); + } +} +} // namespace gated_kernels + +namespace detail { + +template +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, + cudaStream_t stream) { + using namespace gated_kernels; + Tensor grad_empty_tensor; + const Tensor &grad_tensor = + IS_DGATED ? *(reinterpret_cast(grad)) : grad_empty_tensor; + const Tensor gated_input_tensor = *reinterpret_cast(gated_input); + Tensor *output_tensor = reinterpret_cast(output); + + if (is_supported_by_CC_100()) { + quantize_gated(grad_tensor, gated_input_tensor, + output_tensor, stream); + } else { + if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { + if constexpr (IS_DGATED) { + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + } else { + cast_gated(gated_input_tensor, output_tensor, stream); + } + } else { + // MX scaling + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + } +} +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh new file mode 100644 index 0000000000..404babc745 --- /dev/null +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -0,0 +1,1251 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_kernels.cuh + * \brief CUDA kernels to cast to/from FP8/MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ + +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) return; + } + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr 
size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_rowwise[i].clear(); + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_colwise[i] = 0; + } + } + } + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
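// As in the gated kernels, one barrier per iteration tracks the in-flight TMA
// copies; MXFP8_PREFETCH_BUFFERS_NUM stages are issued ahead of the compute
// loop so the next buffer is loading while the current one is being quantized.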
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec act_in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + 
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + } + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Only single thread writes the computed scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + float in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } + + // Wait for shared memory writes to be visible 
to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + partial_dbias_rowwise[c].store_to( + &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); + } + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + Vec other_row_dbias; + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + + const int left_bound = dbias_rowwise_offset_X; + const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; ++i) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; + } + } + + // Vectorized store when all elements are inside the boundaries + if (right_bound < cols) { + partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + // Element-by-element store when some elements cross the boundaries + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); + } + } + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + 
destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const int dbias_offset_Y = blockIdx.y + tid_Y; + const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const int dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const int chunk_offset_Y = block_offset_Y; + const int chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const int buff = iter % FP8_BUFFERS_NUM; + const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const int next_buff = next_iter % FP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
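The per-element math in this delayed-scaling stage is: take the input element, multiply it by the activation derivative when IS_DACT is set, fold it into the column bias gradient and the running amax, then store elt * scale; the reciprocal of scale is published separately as scale_inv. A minimal host-side sketch of that epilogue, with a hypothetical dact standing in for OP and the FP8 conversion left as a plain float store (reference only):

// Host-side sketch of the delayed-scaling cast + dbias + amax epilogue (reference only).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void cast_fp8_ref(const std::vector<float> &grad, const std::vector<float> &act_in,
                  size_t cols, float scale, float (*dact)(float),
                  std::vector<float> &out, std::vector<float> &dbias,
                  float &amax, float &scale_inv) {
  amax = 0.f;
  for (size_t i = 0; i < grad.size(); ++i) {
    float elt = grad[i] * dact(act_in[i]);  // IS_DACT path: dgrad = grad * OP(act_in)
    dbias[i % cols] += elt;                 // per-column bias gradient accumulation
    amax = std::max(amax, std::fabs(elt));  // running amax for the delayed-scaling recipe
    out[i] = elt * scale;                   // value cast to FP8 by the device kernel
  }
  scale_inv = 1.f / scale;                  // what reciprocal(scale_inv_ptr, scale) publishes
}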
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const int dbias_offset_X = my_column; + const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % SHMEM_BUFFERS; + const int it_offset = iter * SHMEM_DIM; + + const int next_iter = iter + 1; + const int next_buff = next_iter % SHMEM_BUFFERS; + const int next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, + const int cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); +} + +template +static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); 
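Around this launcher, the dbias handling is a two-pass scheme: each chunk row writes a partial per-column sum into a float32 workspace of shape (blocks_Y, cols), and reduce_dbias above collapses that workspace over its first dimension into the final bias gradient. A minimal host-side equivalent of the reduction pass (reference only, not part of the patch):

// Reference for reduce_dbias_kernel: sum the (rows x cols) FP32 workspace over
// rows to produce the per-column bias gradient.
#include <cstddef>
#include <vector>

std::vector<float> reduce_dbias_ref(const std::vector<float> &workspace,
                                    size_t rows, size_t cols) {
  std::vector<float> dbias(cols, 0.f);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      dbias[c] += workspace[r * cols + c];  // accumulate in FP32
  return dbias;  // cast back to the input dtype when stored on device
}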
+ const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + checkCuDriverContext(stream); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + const auto &input_shape = input.data.shape; + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 
32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(IType)); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, + cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + cast_mxfp8_2D_kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { + +using Empty = transformer_engine::Empty; + +__device__ inline float identity(float value, const Empty &) { return value; } + +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +} // namespace detail + +template +void 
CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace { + +static bool is_full_tile_1D_tensor(const Tensor *const t) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + return isFullTile; +} + +bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr int TMA_bytes = 16; + const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + return cols % alignment_requirement == 0; +} + +} // namespace + +// Supported by the Arch >= 10.0 +template +void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DBIAS && !IS_DACT) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 + cast_fp8_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 (+dAct) + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize(input, act_input, noop, output, dbias, + workspace, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling 
mode: " + to_string(output->scaling_mode) + "."); + } +} + +// Supported by the Arch < 10.0 +template +void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } +} + +template +void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + fp8_quantize_arch_ge_100(input, act_input, noop, output, + dbias, workspace, stream); + } else { + // Supported by the Arch < 10.0 + fp8_quantize_arch_l_100(input, act_input, noop, output, + dbias, workspace, stream); + } +} + +namespace detail { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = reinterpret_cast(grad); + activation_input_tensor = reinterpret_cast(input); + } else { + // forward = input is activation input + input_tensor = reinterpret_cast(input); + activation_input_tensor = nullptr; + } + auto output_tensor = reinterpret_cast(output); + auto dbias_tensor = reinterpret_cast(dbias); + auto workspace_tensor = reinterpret_cast(workspace); + const auto noop_tensor = noop != nullptr ? 
*(reinterpret_cast(noop)) : Tensor(); + + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + } else if (output_tensor->has_data()) { + fp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index cc9a659b5b..8b6bb52397 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -81,6 +81,26 @@ int sm_count(int device_id) { return cache[device_id]; } +void stream_priority_range(int *low_priority, int *high_priority, int device_id) { + static std::vector> cache(num_devices()); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + int ori_dev = current_device(); + if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id)); + int min_pri, max_pri; + NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri)); + if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev)); + cache[device_id] = std::make_pair(min_pri, max_pri); + }; + std::call_once(flags[device_id], init); + *low_priority = cache[device_id].first; + *high_priority = cache[device_id].second; +} + bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 33c2aea8d4..072eacd623 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -38,6 +38,16 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief Minimum and maximum stream priorities supported on device + * + * \param[in] device_id CUDA device (default is current device) + * + * \param[out] low_priority Lowest priority value on device. + * + * \param[out] high_priority Highest priority value on device. 
+ */ +void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1); + /* \brief CUDA Multicast support status for device * * \param[in] device_id CUDA device (default is current device) diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh new file mode 100644 index 0000000000..e529289640 --- /dev/null +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -0,0 +1,360 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_kernels.cuh + * \brief CUDA kernels to cast from MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ + +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +namespace transformer_engine { + +namespace dequantization { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t BUFFERS_NUM = 2; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, + const size_t scales_stride) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = 
threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + constexpr int iteration_zero = 0; + constexpr int buffer_zero = 0; + if (is_master_thread) { + const int chunk_stage_offset_Y = chunk_offset_Y; + const int chunk_stage_offset_X = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buffer_zero]), + reinterpret_cast(&tensor_map_input), chunk_stage_offset_X, + chunk_stage_offset_Y, &mbar[iteration_zero]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size); + + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[iteration_zero]); + } + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % BUFFERS_NUM; + const int next_iter = iter + 1; + if (next_iter < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_buff]), + reinterpret_cast(&tensor_map_input), chunk_it_offset_x, + chunk_it_offset_y, &mbar[next_iter]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_iter]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + const int scale_offset_Y = + USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + + const int scale_offset_X = + USE_ROWWISE_SCALING + ? 
(scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) + : (scales_colwise_chunk_offset_X + tid_colwise_X); + + const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + const e8m0_t biased_exponent = scales_ptr[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out; + + const int shmem_offset_y = thread_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } + out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); + } else { +#pragma unroll + for (int i = 0; i < BUFFER_DIM_Y; ++i) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. 
+ if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + detail::DequantizeParam p; + p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} + +static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + bool use_rowwise_scaling = input.has_data(); + bool use_colwise_scaling = input.has_columnwise_data(); + checkCuDriverContext(stream); + + const auto &input_shape = input.data.shape; + NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions."); + + if (use_rowwise_scaling) { + NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + } + + if (use_colwise_scaling) { + NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data."); + NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); + } + + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + const size_t unpadded_scales_Y_rowwise = rows; + const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); + const size_t unpadded_scales_X_colwise = cols; + + const size_t scales_Y_rowwise = + DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * + scale_tensor_alignment_Y_rowwise; + const size_t scales_X_rowwise = + DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * + scale_tensor_alignment_X_rowwise; + const size_t scales_Y_colwise = + DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * + scale_tensor_alignment_Y_colwise; + const size_t scales_X_colwise = + DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * + scale_tensor_alignment_X_colwise; + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); + + const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + const SimpleTensor &input_data = use_rowwise_scaling ? 
input.data : input.columnwise_data; + + const dim3 block(THREADS_PER_CHUNK); + const dim3 grid(chunks_X, chunks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(OType)); + + dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, scales_ptr, + rows, cols, scales_stride);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace dequantization + +namespace detail { + +void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if (is_tensor_scaling(input.scaling_mode)) { + dequantization::fp8_dequantize(input, output, stream); + } else if (is_mxfp_scaling(input.scaling_mode)) { + if (is_supported_by_CC_100()) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh new file mode 100644 index 0000000000..a22b930ecd --- /dev/null +++ b/transformer_engine/common/util/ptx.cuh @@ -0,0 +1,300 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! 
\file ptx.cuh + * \brief BW PTX + */ + +#ifndef TRANSFORMER_ENGINE_PTX_CUH_ +#define TRANSFORMER_ENGINE_PTX_CUH_ + +#include +#include + +namespace transformer_engine { +namespace ptx { + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 
1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, + const uint64_t *src_shmem, + const uint32_t size) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), + "r"(src_shmem_ptr), "r"(size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, + uint64_t *src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( + tensor_map_ptr), + "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) + : "memory"); +} + +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile( + "{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group 0;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} + +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { + asm volatile("cp.async.bulk.wait_group.read 1;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { + asm volatile("cp.async.bulk.wait_group.read 2;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { + asm volatile("cp.async.bulk.wait_group.read 4;"); +} + +// Proxy fence (bi-directional): +__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } +__device__ 
__forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +} // namespace ptx + +namespace { + +template +__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), + num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, + chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, + const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, + const size_t chunk_X2, const size_t chunk_Y2, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx3( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2, + const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3, + const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst3), + reinterpret_cast(src3), + chunk_X3, chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PTX_CUH_ diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 97c5bee2b1..b3087d1fb7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -14,66 +14,98 @@ #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", 
NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + 
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + "get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", 
&transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); #endif diff --git a/transformer_engine/common/util/system.h b/transformer_engine/common/util/system.h index e3a7164932..71c7ef3216 100644 --- a/transformer_engine/common/util/system.h +++ b/transformer_engine/common/util/system.h @@ -9,8 +9,6 @@ #include -#include "../common.h" - namespace transformer_engine { /*! \brief Get environment variable and convert to type diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index faf3ea0a61..420b9ed3bb 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -44,6 +44,13 @@ class VectorizedStorage { return *this; } inline __device__ ~VectorizedStorage() {} + + /* \brief Access to separate elements. */ + inline __device__ DType *separate() { return scratch_.separate; } + + inline __device__ const DType *separate() const { return scratch_.separate; } + + inline __device__ LType &aligned() { return scratch_.aligned; } }; // Returns const LType is DType is const @@ -167,9 +174,11 @@ constexpr int unary_kernel_threads = 512; template __launch_bounds__(unary_kernel_threads) __global__ - void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, - const size_t num_aligned_elements) { + void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, + const size_t N, const size_t num_aligned_elements) { + if (noop != nullptr && noop[0] == 1.0f) return; + VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; @@ -322,9 +331,9 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... 
ptrs) template -void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, - cudaStream_t stream) { +void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, + const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, + const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -337,16 +346,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, output, scale, amax, scale_inv, params, N, N); + input, noop, output, scale, amax, scale_inv, params, N, N); break; } } @@ -395,18 +404,19 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader loader0(input + id_y * n * 2, n); VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); - ComputeType max = 0; - ComputeType s = 1; - if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; - } - const int warp_id = threadIdx.x / THREADS_PER_WARP; loader0.load(id_x, n); loader1.load(id_x, n); @@ -423,21 +433,20 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); - - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); - } + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -482,9 +491,17 @@ template __launch_bounds__(unary_kernel_threads) __global__ void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr 
(is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; @@ -507,12 +524,35 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(after_dgelu), max); + after_dgelu = after_dgelu * s; + max = fmaxf(fabsf(after_dgate), max); + after_dgate = after_dgate * s; + } + storer0.separate()[i] = static_cast(after_dgelu); storer1.separate()[i] = static_cast(after_dgate); } storer0.store(id_x, n); storer1.store(id_x, n); } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } } template void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, - OutputType *output, const size_t m, const size_t n, - const Param &p, cudaStream_t stream) { + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t m, const size_t n, const Param &p, + cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -532,18 +573,19 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { case Alignment::SAME_ALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, m, n, p, n); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 6267baf19e..63ce369892 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -819,6 +819,21 @@ __device__ __forceinline__ float warp_reduce_max(const float m) { return tmp; } +__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero); + return val_tmp; +} + template __device__ __forceinline__ compute_t 
reduce_max(const compute_t m, const int warpid) { __shared__ float staging[num_warps]; @@ -837,6 +852,29 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war return result; } +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8 scaling factors. + * To compute an actual scaling factor for 32 consecutive elements, only 8 threads need to participate, + * thus splitting the warp into four subwarps of 8 threads each. + * A shuffle-down tree reduction is used inside each subwarp. + */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + // Works only on positive values __device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) { atomicMax(reinterpret_cast(addr), __float_as_int(value)); @@ -857,6 +895,79 @@ __device__ __forceinline__ void reciprocal(float *value_inv, const float *value_inv = __frcp_rn(value); } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +using e8m0_t = uint8_t; + +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; + +template +struct Numeric_Traits; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 8; + static constexpr double maxNorm = 448; +}; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 15; + static constexpr double maxNorm = 57344; +}; + +template +struct Quantized_Limits { + static constexpr int max_unbiased_exponent = Numeric_Traits::maxUnbiasedExponent; + static constexpr float max_norm = Numeric_Traits::maxNorm; + static constexpr float max_norm_rcp = 1.0 / max_norm; + static constexpr float emax = 1 << max_unbiased_exponent; + static constexpr float emax_rcp = 1.0 / emax; +}; + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { + // TODO: NaN/Inf needs to be set for any NaN/Inf value + // in the input, not just for the amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 41a6846a7c..a5457fa032 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -6,6 +6,7 @@ #include "transformer_engine/activation.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "transformer_engine/transpose.h" #include "xla/ffi/api/c_api.h" @@ -332,18 +333,27 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias_dgelu chooses its internal impl based + // on what pointers are allocated, e.g. whether to output with + // column-wise data. However, we don't have access to any allocated + // buffers in this function. We pass a dummy pointer as a + // workaround. + int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto dact_input_tensor = + TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(); + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; // For now, all dbias_dact(-s) have the same workspace size - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -384,37 +394,32 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); switch (act_enum) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), 
workspace.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -468,37 +473,32 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + 
nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -555,29 +555,29 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); switch (act_enum) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -622,30 +622,30 @@ Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: 
nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 569dfd3baa..71d1456287 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -25,7 +25,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -48,7 +48,7 @@ Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a auto input_tensor = TensorWrapper(input, shape, in_dtype); auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } @@ -76,7 +76,7 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -96,7 +96,7 @@ Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 516930c529..af347f45b2 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/transpose.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "xla/ffi/api/ffi.h" namespace transformer_engine { @@ -89,13 +90,12 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size auto input_trans_shape = std::vector{n, m}; auto 
input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto input_cast_tensor = + auto output_tensor = TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype, - amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); - nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), - stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -131,11 +131,11 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); + + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); - nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - stream); return ffi_with_cuda_error_check(); } @@ -159,15 +159,22 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias chooses its internal impl based on what + // pointers are allocated, e.g. whether to output with column-wise + // data. However, we don't have access to any allocated buffers in + // this function. We pass a dummy pointer as a workaround. 
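For context, a minimal standalone sketch of the calling pattern these bindings now target might look as follows; the buffer names (d_input, d_cast, d_cast_t, d_amax, d_scale, d_scale_inv) and the BF16-to-FP8E4M3 choice are illustrative placeholders, not part of this change:

#include <cuda_runtime.h>
#include <vector>
#include "transformer_engine/cast.h"
#include "transformer_engine/transformer_engine.h"

// Sketch only: quantize an (m x n) BF16 buffer to FP8, producing both the
// row-wise (cast) and column-wise (transposed) outputs in a single call.
void cast_transpose_sketch(void *d_input, void *d_cast, void *d_cast_t,
                           float *d_amax, float *d_scale, float *d_scale_inv,
                           size_t m, size_t n, cudaStream_t stream) {
  using namespace transformer_engine;
  const std::vector<size_t> shape{m, n};
  const std::vector<size_t> shape_t{n, m};
  auto input = TensorWrapper(d_input, shape, DType::kBFloat16);
  // One output tensor carries both data layouts plus the scale-inverse.
  auto output = TensorWrapper(d_cast, shape, DType::kFloat8E4M3, d_amax, d_scale, d_scale_inv);
  output.set_columnwise_data(d_cast_t, DType::kFloat8E4M3, shape_t);
  output.set_columnwise_scale_inv(d_scale_inv, DType::kFloat32, std::vector<size_t>{1});
  // Single entry point replaces the previous cast + transpose tensor pair.
  nvte_quantize(input.data(), output.data(), stream);
}

The workspace-size query below follows the same tensor setup, but with a dummy non-null pointer in place of real buffers, since kernel selection depends only on which pointers are set, not on their contents.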
+ int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -203,14 +210,14 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); } Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -253,13 +260,13 @@ Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buf auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index e7ee350b46..f2dbd3b131 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -354,11 +354,6 @@ def fp8_autocast( assert ( fp8_recipe.scaling_factor_compute_algo is None ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." - assert fp8_recipe.override_linear_precision == ( - False, - False, - False, - ), "DelayedScaling override_linear_precision isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." 
if mesh_resource is None: diff --git a/transformer_engine/paddle/MANIFEST.in b/transformer_engine/paddle/MANIFEST.in deleted file mode 100644 index 0c814f95da..0000000000 --- a/transformer_engine/paddle/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include build_tools *.* -recursive-include common_headers *.* -recursive-include csrc *.* diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py deleted file mode 100644 index 583c4a7a7a..0000000000 --- a/transformer_engine/paddle/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Transformer Engine bindings for Paddle""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import logging -from importlib.metadata import version - -from transformer_engine.common import is_package_installed - - -def _load_library(): - """Load shared library with Transformer Engine C extensions""" - module_name = "transformer_engine_paddle" - - if is_package_installed(module_name): - assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." - assert is_package_installed( - "transformer_engine_cu12" - ), "Could not find `transformer-engine-cu12`." - assert ( - version(module_name) - == version("transformer-engine") - == version("transformer-engine-cu12") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[paddle]==VERSION'" - ) - - if is_package_installed("transformer-engine-cu12"): - if not is_package_installed(module_name): - logging.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[paddle]==VERSION'", - module_name, - ) - - from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import - - -_load_library() -from .fp8 import fp8_autocast -from .layer import ( - Linear, - LayerNorm, - LayerNormLinear, - LayerNormMLP, - FusedScaleMaskSoftmax, - DotProductAttention, - MultiHeadAttention, - TransformerLayer, - RotaryPositionEmbedding, -) -from .recompute import recompute diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py deleted file mode 100644 index dee8a70c38..0000000000 --- a/transformer_engine/paddle/constants.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Constants""" - -from enum import Enum - -import paddle - -from transformer_engine import transformer_engine_paddle as tex - - -class FP8FwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GEMM1_INPUT = 0 - GEMM1_WEIGHT = 1 - GEMM1_OUTPUT = 2 - GEMM2_INPUT = 3 - GEMM2_WEIGHT = 4 - GEMM2_OUTPUT = 5 - - -class FP8BwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GRAD_OUTPUT1 = 0 - GRAD_INPUT1 = 1 - GRAD_OUTPUT2 = 2 - GRAD_INPUT2 = 3 - - -""" -Map from paddle dtype to TE dtype -""" -TE_DType = { - paddle.uint8: tex.DType.kByte, - paddle.int32: tex.DType.kInt32, - paddle.float32: tex.DType.kFloat32, - paddle.float16: tex.DType.kFloat16, - paddle.bfloat16: tex.DType.kBFloat16, -} - -AttnMaskTypes = ("causal", "padding", "no_mask") - -AttnTypes = ("self", "cross") - -LayerTypes = ("encoder", "decoder") - -GemmParallelModes = ("row", "column", None) - -dist_group_type = paddle.distributed.collective.Group - -RecomputeFunctionNames = ("unpack", "backward") - -AttnBiasType = { - "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS, - "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, - "post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS, -} - -AttnMaskType = { - "no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK, - "padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK, - "causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, -} - -FusedAttnBackend = { - "F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, - "F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - "No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, -} diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py deleted file mode 100644 index 293c62a2fd..0000000000 --- a/transformer_engine/paddle/cpp_extensions.py +++ /dev/null @@ -1,1199 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""TE FP8 extensions and GEMMs""" - -import math -from typing import Optional, Tuple, Union -import paddle -import paddle.nn.functional as F -from transformer_engine import transformer_engine_paddle as tex -from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors -from .fp8 import FP8TensorMeta, get_global_fp8_state - -BACKEND_F16m512_THREADS_PER_CTA = 128 -BACKEND_F16arb_ELTS_PER_THREADS = 16 - - -def gemm( - A: paddle.Tensor, - B: paddle.Tensor, - dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - gelu_input: Optional[paddle.Tensor] = None, - grad: bool = False, - accumulate: bool = False, - layout: str = "TN", - out: Optional[paddle.Tensor] = None, - out_dtype: Optional[paddle.dtype] = None, - bias: Optional[paddle.Tensor] = None, - use_bias: bool = False, -) -> Tuple[Union[paddle.Tensor, None], ...]: - """Non FP8 GEMM.""" - - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." 
- transa = layout[0] == "T" - transb = layout[1] == "T" - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - - if gelu and not grad: - gelu_input = paddle.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = None - - if grad and use_bias: - grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype) - else: - grad_bias = None - - bias = bias if use_bias else None - - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype - - tex.te_gemm( - A, - None, - B, - None, - grad_bias if grad else bias, - out, - None, # out_scale - None, # out_amax - gelu_input, - workspace, - 0, # A_index - 0, # B_index - 0, # D_index - int(input_dtype), - int(input_dtype), - int(output_dtype), - int(bias_dtype), - transa, - transb, - grad, - workspace.shape[0], - accumulate, - False, # use_split_accumulator - 0, # math_sm_count - ) - - return out, grad_bias, gelu_input - - -def fp8_gemm( - A: paddle.Tensor, - A_scale_inv: paddle.Tensor, - A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - A_dtype: tex.DType, - B: paddle.Tensor, - B_scale_inv: paddle.Tensor, - B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[paddle.Tensor] = None, - out_index=None, - fp8_meta_tensor: FP8TensorMeta = None, - bias: Optional[paddle.Tensor] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> paddle.Tensor: - """TN layout GEMM with fp8 inputs.""" - - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - - # Use bfloat16 as default bias_dtype - bias_dtype = paddle.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = paddle.empty_like(out, dtype=bias_dtype) - else: - gelu_input = None - bias_dtype = TE_DType[bias_dtype] - - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype - - tex.te_gemm( - A, - A_scale_inv, - B, - B_scale_inv, - bias if use_bias else None, - out, - None if out_index is None else fp8_meta_tensor.scale, - None if out_index is None else fp8_meta_tensor.amax_history, - gelu_input, # this is pre_gelu_out - workspace, - A_fp8_tensor.value, - B_fp8_tensor.value, - 0 if out_index is None else out_index, - int(A_dtype), - int(B_dtype), - int(out_dtype), - int(bias_dtype), - True, # transa - False, # transb - False, # grad - workspace.shape[0], - accumulate, - use_split_accumulator, - 0, # math_sm_count - ) - - return out, gelu_input - - -def cast_to_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - 
otype: tex.DType, - out: Optional[paddle.Tensor] = None, -) -> paddle.Tensor: - """Cast input to FP8""" - if out is None: - out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert out.shape == inp.shape, "Output shape does not match input shape." - assert out.dtype == paddle.uint8, "Output should be of uint8 dtype." - - tex.cast_to_fp8( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - return out - - -def cast_from_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - itype: tex.DType, - otype: tex.DType, -) -> paddle.Tensor: - """Cast input from FP8""" - return tex.cast_from_fp8( - inp, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(itype), - int(otype), - ) - - -def transpose( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Transpose input""" - return tex.te_transpose( - inp, - int(otype), - ) - - -def cast_transpose( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - cast_out: Optional[paddle.Tensor] = None, - transpose_out: Optional[paddle.Tensor] = None, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]: - """Cast + Transpose with FP8 output""" - if cast_out is None: - cast_out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert cast_out.shape == inp.shape, "cast_out shape does not match input shape." - assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype." - - if transpose_out is None: - transpose_out = paddle.empty( - shape=[inp.shape[1], inp.shape[0]], - dtype=paddle.uint8, - ) - else: - assert transpose_out.shape == [ - inp.shape[1], - inp.shape[0], - ], "Transposed output shape does not match input shape." - assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype." 
- - tex.te_cast_transpose( - inp, - fp8_meta_tensor.scale, - cast_out, - transpose_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_out, transpose_out - - -def cast_transpose_bgrad( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]: - """Fused Cast + Transpose + Bias Grad""" - grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return grad_bias, cast_out, transpose_out - - -def te_gelu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 GELU""" - return tex.te_gelu( - inp, - int(otype), - ) - - -def gelu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """GELU + FP8 cast""" - out, _, _ = tex.te_gelu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def swiglu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 SWIGLU""" - return tex.te_swiglu( - inp, - int(otype), - ) - - -def swiglu_pd( - inp: paddle.Tensor, -) -> paddle.Tensor: - """Native SWIGLU""" - gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1) - out = F.silu(gate_out) * up_out - return out - - -def swiglu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """SWIGLU + FP8 cast""" - out, _, _ = tex.te_swiglu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def dswiglu( - grad_output: paddle.Tensor, - swiglu_input: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """dSWIGLU""" - return tex.te_dswiglu( - grad_output, - swiglu_input, - int(otype), - ) - - -def dgelu_cast_transpose_bgrad_fp8( - grad_output: paddle.Tensor, - gelu_input: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """ - Fused dgelu + cast / transpose / reduce the result of - the GELU backward along the first dimension - """ - cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_dgelu, transpose_dgelu, dbias - - -def layernorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - bias: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """LayerNorm with FP8 output""" - out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8( - inp, - weight, - bias, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, mu, rsigma - - -def layernorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - 
bias: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm forward""" - return tex.te_layernorm_fwd(inp, weight, bias, eps, int(otype), sm_margin, zero_centered_gamma) - - -def layernorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - mu: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm backward""" - return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm forward""" - return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """RMSNorm with FP8 output""" - out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8( - inp, - weight, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, rsigma - - -def rmsnorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm backward""" - return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def mask_to_cu_seqlens( - mask: paddle.Tensor, - need_kv: bool = False, -) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - # mask shape: [b, 1, s_q, s_kv] - if get_global_fp8_state().is_cudagraph_enabled(): - raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.") - q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3] - q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - q_cu_seqlens[0] = 0 - kv_cu_seqlens = None - if need_kv: - kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - kv_cu_seqlens[0] = 0 - tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv) - return q_cu_seqlens, kv_cu_seqlens - - -def fused_attn_fwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - is_training: bool, - max_seqlen: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." 
- - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen, - max_seqlen, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
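# --- Illustrative note (not part of the original sources) -------------------
# The fused-attention wrappers above derive qkv_format by keeping only the
# letters of the first '_'-separated group of the layout string; 'thd'
# layouts then force set_zero=True because outputs are written sparsely.
# Standalone sketch of that string manipulation (hypothetical helper name):
def qkv_format_from_layout(qkv_layout: str) -> str:
    return "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())

assert qkv_format_from_layout("bs3hd") == "bshd"
assert qkv_format_from_layout("bshd_bs2hd") == "bshd"
assert qkv_format_from_layout("thd_thd_thd") == "thd"
# -----------------------------------------------------------------------------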
- - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) - else: - dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype) - - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = None - # execute kernel - dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - o, - d_o, - softmax_aux, - dqkv, - dbias, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - - return dqkv, dbias - - -def fused_attn_fwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
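# --- Illustrative note (not part of the original sources) -------------------
# When attn_scale is None, the wrappers above fall back to the standard
# scaled-dot-product factor 1/sqrt(head_dim). Worked example (helper name
# is hypothetical):
import math

def default_attn_scale(head_dim: int) -> float:
    return 1.0 / math.sqrt(head_dim)

assert default_attn_scale(64) == 0.125  # e.g. head_dim=64 -> scale 1/8
# -----------------------------------------------------------------------------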
- - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - - return out, softmax_aux, rng_state - - -def fused_attn_bwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
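# --- Illustrative note (not part of the original sources) -------------------
# For the F16_max512_seqlen backend, rng_elts_per_thread above is a ceiling
# division of the attention-matrix size (max_seqlen_q * max_seqlen_kv) by
# BACKEND_F16m512_THREADS_PER_CTA. Sketch of the same arithmetic; the
# constant 128 below is only an illustrative stand-in, not the real value:
def ceil_div(numerator: int, denominator: int) -> int:
    return (numerator + denominator - 1) // denominator

THREADS_PER_CTA = 128  # hypothetical placeholder
assert ceil_div(512 * 512, THREADS_PER_CTA) == 2048
assert ceil_div(1, THREADS_PER_CTA) == 1
# -----------------------------------------------------------------------------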
- - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dkv, - dbias, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dkv, dbias - - -def fused_attn_fwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for unpacked QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." - b = cu_seqlens_q.shape[0] - 1 - - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
- - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." - - b = cu_seqlens_q.shape[0] - 1 - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." 
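# --- Illustrative note (not part of the original sources) -------------------
# Summary of the softmax_aux buffers allocated by the forward wrappers above
# when is_training is True: the max512 backend keeps the full softmax matrix
# in the QKV dtype, while the arbitrary-seqlen backend keeps per-row softmax
# statistics in float32. The helper below is hypothetical and only restates
# that selection:
def softmax_aux_spec(backend, b, h, max_seqlen_q, max_seqlen_kv, qkv_dtype):
    if backend == "F16_max512_seqlen":
        return [b, h, max_seqlen_q, max_seqlen_kv], qkv_dtype
    if backend == "F16_arbitrary_seqlen":
        return [b, h, max_seqlen_q, 1], "float32"
    raise ValueError("Unsupported fused attention backend.")

assert softmax_aux_spec("F16_arbitrary_seqlen", 2, 8, 128, 128, "bfloat16") == (
    [2, 8, 128, 1],
    "float32",
)
# -----------------------------------------------------------------------------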
- - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) - dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dk = paddle.empty(shape=k.shape, dtype=k.dtype) - dv = paddle.empty(shape=v.shape, dtype=v.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dk, - dv, - dbias, - rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dk, dv, dbias - - -def scaled_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax forward""" - return tex.te_scaled_softmax_forward(inp, scale_factor) - - -def scaled_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax backward""" - tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad - - -def scaled_masked_softmax_forward( - inp: paddle.Tensor, - mask: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax forward""" - - return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor) - - -def scaled_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax backward""" - tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad - - -def scaled_upper_triang_masked_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax forward""" - return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor) - - -def scaled_upper_triang_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax backward""" - tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp deleted file mode 100644 index d65fbb2b50..0000000000 --- a/transformer_engine/paddle/csrc/common.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type) { - return TensorWrapper(const_cast(data_ptr), shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) { - return TensorWrapper(data_ptr, shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) { - return TensorWrapper(data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), - reinterpret_cast(scale_inv_ptr)); -} - -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT - return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype())); -} - -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) { - return MakeNvteTensor(const_cast(tensor.data()), GetShapeArray(tensor), - Paddle2NvteDType(tensor.dtype())); -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 2) { - return paddle::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 1 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } else if (size == 1) { - return paddle::empty({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } - NVTE_CHECK(false, "Should never reach here! func: AllocateSpace"); -} - -// MHA utils -// convert QKV layout to enum -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) { - static const std::unordered_map layout_map = { - {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, - {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, - {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, - {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, - {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, - {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, - {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, - {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, - {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, - {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, - {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, - {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, - {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, - {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, - {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, - }; - - auto it = layout_map.find(qkv_layout); - if (it != layout_map.end()) { - return it->second; - } else { - NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h deleted file mode 100644 index 83737c0d21..0000000000 --- a/transformer_engine/paddle/csrc/common.h +++ /dev/null @@ -1,185 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "common/util/logging.h" -#include "paddle/extension.h" -#include "paddle/phi/backends/all_context.h" - -namespace transformer_engine { -namespace paddle_ext { -// Paddle Tensor Utils -template -inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) { - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline const void *GetOptionalDataPtr(const paddle::optional &x, int64_t index) { - return x ? GetDataPtr(*x, index) : nullptr; -} - -template -inline void *GetOptionalDataPtr(paddle::optional &x, int64_t index) { // NOLINT - return x ? GetDataPtr(*x, index) : nullptr; -} - -inline const void *GetOptionalDataPtr(const paddle::optional &x) { - return x ? x->data() : nullptr; -} - -inline void *GetOptionalDataPtr(paddle::optional &x) { // NOLINT - return x ? x->data() : nullptr; -} - -inline std::vector GetShapeArray(const paddle::Tensor &x) { - std::vector shapes; - for (auto dim : x.shape()) { - shapes.push_back(static_cast(dim)); - } - return shapes; -} - -inline std::vector GetShapeArray(const paddle::optional &x) { - if (x) return GetShapeArray(x.get()); - return {0}; -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros = 0); - -// DType Utils -inline paddle::DataType Nvte2PaddleDType(DType t) { - switch (t) { - case DType::kInt32: - case DType::kFloat32: - return paddle::DataType::FLOAT32; - case DType::kFloat16: - return paddle::DataType::FLOAT16; - case DType::kBFloat16: - return paddle::DataType::BFLOAT16; - case DType::kByte: - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - return paddle::DataType::UINT8; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Paddle2NvteDType(paddle::DataType t) { - switch (t) { - case paddle::DataType::FLOAT16: - return DType::kFloat16; - case paddle::DataType::FLOAT32: - return DType::kFloat32; - case paddle::DataType::BFLOAT16: - return DType::kBFloat16; - case paddle::DataType::BOOL: - return DType::kByte; - case paddle::DataType::UINT8: - return DType::kByte; - case paddle::DataType::INT32: - return DType::kInt32; - case paddle::DataType::INT64: - return DType::kInt64; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Int2NvteDType(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} - -// get the fused attention backend -inline NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, 
max_seqlen_q, max_seqlen_kv, - head_dim, head_dim, -1, -1); - return fused_attention_backend; -} - -// CUDA Utils -class cudaDevicePropertiesManager { - public: - static cudaDevicePropertiesManager &Instance() { - static thread_local cudaDevicePropertiesManager instance; - return instance; - } - - int GetMultiProcessorCount() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.multiProcessorCount; - } - - int GetMajor() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.major; - } - - private: - bool prop_queried_ = false; - cudaDeviceProp prop_; -}; - -// NVTE Tensor Utils -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr); -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor); - -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout); - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu deleted file mode 100644 index 460f4575e6..0000000000 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ /dev/null @@ -1,1776 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common.h" -#include "common/common.h" -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" - -namespace transformer_engine { -namespace paddle_ext { - -// convert bias type to enum -NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { - if (bias_type == "no_bias") { - return NVTE_Bias_Type::NVTE_NO_BIAS; - } else if (bias_type == "pre_scale_bias") { - return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; - } else if (bias_type == "post_scale_bias") { - return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - } else { - NVTE_ERROR("Invalid bias type. \n"); - } -} - -// convert attn mask type to enum -NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { - if (mask_type == "padding") { - return NVTE_Mask_Type::NVTE_PADDING_MASK; - } else if (mask_type == "causal") { - return NVTE_Mask_Type::NVTE_CAUSAL_MASK; - } else if (mask_type == "no_mask") { - return NVTE_Mask_Type::NVTE_NO_MASK; - } else { - NVTE_ERROR("Invalid attention mask type. 
\n"); - } -} - -void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), shape, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); -} - -std::vector cast_from_fp8(const paddle::Tensor &input, - const paddle::Tensor &scale_inv, int64_t index, - int64_t itype, int64_t otype) { - auto shape = GetShapeArray(input); - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype))); - auto input_cu = - MakeNvteTensor(const_cast(input.data()), shape, Int2NvteDType(itype), nullptr, - nullptr, const_cast(GetDataPtr(scale_inv, index))); - auto output_cu = MakeNvteTensor(output); - - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_transpose(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(const_cast(input.data()), {M, N}, Int2NvteDType(otype)); - auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype)); - - nvte_transpose(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output_cast, // NOLINT - paddle::Tensor &output_transpose, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto input_cu = MakeNvteTensor(input); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - input.stream()); -} - -std::vector te_cast_transpose_bgrad(const paddle::Tensor &grad_output, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - auto grad_output_cast = - paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - auto grad_output_transpose = - paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - - auto input_cu = 
MakeNvteTensor(grad_output); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto output_transpose_cu = - MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -void te_gemm(const paddle::Tensor &A, const paddle::optional &A_scale_inverse, - const paddle::Tensor &B, const paddle::optional &B_scale_inverse, - const paddle::optional &bias, paddle::Tensor &D, // NOLINT - paddle::optional &D_scale, // NOLINT - paddle::optional &D_amax, // NOLINT - paddle::optional &pre_gelu_out, paddle::Tensor &workspace, // NOLINT - int64_t A_index, int64_t B_index, int64_t D_index, int64_t A_type, int64_t B_type, - int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad, - int64_t workspace_size, bool accumulate, bool use_split_accumulator, - int64_t math_sm_count) { - auto te_A = MakeNvteTensor( - const_cast(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(A_scale_inverse, A_index))); - auto te_B = MakeNvteTensor( - const_cast(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(B_scale_inverse, B_index))); - auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type), - GetOptionalDataPtr(D_amax, D_index), - GetOptionalDataPtr(D_scale, D_index), nullptr); - - auto te_bias = MakeNvteTensor(const_cast(GetOptionalDataPtr(bias)), GetShapeArray(bias), - Int2NvteDType(bias_type)); - - DType gelu_dtype = pre_gelu_out ? 
Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type); - auto te_pre_gelu_out = - MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype); - auto te_workspace = - MakeNvteTensor(workspace.data(), {static_cast(workspace_size)}, DType::kByte); - - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, A.stream()); -} - -std::vector te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_gelu(const paddle::Tensor &input, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input, - int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype())); - auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype())); - auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype())); - - 
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output, - const paddle::Tensor &gelu_input, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - // DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - - auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place()); - - auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(DType::kByte), grad_output.place()); - - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - - TensorWrapper workspace; - - auto gelu_input_cu = MakeNvteTensor(gelu_input); - auto input_cu = MakeNvteTensor(grad_output); - auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - return {dgelu, dgelu_transpose, grad_bias}; -} - -std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - 
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, - zero_centered_gamma, input.stream()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &mu, const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - auto dbeta_cu = MakeNvteTensor(dbeta); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. 
- nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - num_sm - sm_margin, zero_centered_gamma, dz.stream()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. - nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), - dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), - num_sm - sm_margin, zero_centered_gamma, dz.stream()); - - return {dx, dgamma, dbeta}; -} - -std::vector te_rmsnorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace tensor with the required config - 
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, - dz.stream()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // Actual call to bwd kernel. 
- nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, - dz.stream()); - - return {dx, dgamma}; -} - -__global__ void set_rng_state( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - std::pair seed_offset, int64_t *rng_state_ptr) { - rng_state_ptr[0] = static_cast(seed_offset.first); - rng_state_ptr[1] = static_cast(seed_offset.second); -} - -void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread, - paddle::Tensor &rng_state) { - // extract random number generator seed and offset - const phi::DeviceContext *dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - - phi::Generator *gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - int64_t *rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif -} - -void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed QKV -void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, int64_t qkv_type, - bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQKV = MakeNvteTensor(dQKV); - } else { - NVTE_ERROR("Fused 
attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen), static_cast(max_seqlen)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens; - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd_kvpacked( - const paddle::Tensor &Q, const paddle::Tensor &KV, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t total_seqs_kv, - int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. 
\n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor( - Q.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - te_KV = MakeNvteTensor( - KV.data(), - {static_cast(total_seqs_kv), 2, static_cast(h), static_cast(d)}, - qkv_dtype); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor( - O.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed KV -void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::Tensor &O, - const paddle::Tensor &dO, const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor 
&dKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, - int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_KV = MakeNvteTensor(KV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dKV = MakeNvteTensor(dKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), 
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - bool is_training, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, - int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // extract random number generator seed and offset - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - auto stream = Q.stream(); - auto rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif - - auto te_rng_state = MakeNvteTensor(rng_state); - - // create 
auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor &dK, // NOLINT - paddle::Tensor &dV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dK = MakeNvteTensor(dK); - te_dV = MakeNvteTensor(dV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -std::vector te_scaled_softmax_forward(const paddle::Tensor &input, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - input.stream()); - - return {softmax_results}; -} - -void 
te_scaled_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_masked_softmax_forward(const paddle::Tensor &input, - const paddle::Tensor &mask, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int pad_batches = mask.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - NVTE_CHECK(pad_batches == 1 || pad_batches == batches); - NVTE_CHECK(mask.shape()[1] == 1); - NVTE_CHECK(mask.shape()[2] == query_seq_len); - NVTE_CHECK(mask.shape()[3] == key_seq_len); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto mask_cu = MakeNvteTensor(mask); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. 
- nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_upper_triang_masked_softmax_forward( - const paddle::Tensor &input, float scale_factor) { - NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int attn_batches = input.shape()[0]; - const int seq_len = input.shape()[1]; - NVTE_CHECK(seq_len <= 2048); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, - float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_upper_triang_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - softmax_results.stream()); -} - -__global__ void UpdateFP8MetaKernel( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - const float *amax, const float *rolled_amax_history, const bool *non_weight_mask, - float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv, float margin, - float fp8_max, size_t history_numel, size_t amax_numel) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx >= history_numel) { - return; - } - - amax_history[idx] = rolled_amax_history[idx]; - - if (idx < amax_numel) { - float sf = (fp8_max / amax[idx]) / powf(2.0f, margin); - float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? 
sf : scale[idx]; - scale[idx] = scale_reg; - if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg; - amax_history[idx] = 0.0f; - } -} - -constexpr int BLOCK_SIZE = 512; - -void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, int64_t fp8_dtype, - float margin, const std::string &amax_compute) { - auto amax_history_ = MakeNvteTensor(amax_history); - auto scale_ = MakeNvteTensor(scale); - auto scale_inv_ = MakeNvteTensor(scale_inv); - const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask); - nvte_delayed_scaling_recipe_amax_and_scale_update( - amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(), - amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(), - static_cast(fp8_dtype), margin, amax_history.stream()); -} - -void amax_and_scale_update_inplace_legacy( - paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, - const paddle::optional ¤t_step_id_tensor, bool update_weight_scale_inv, - bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) { -#if PADDLE_VERSION > 261 - NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); - - paddle::Tensor amax; - - if (amax_compute == "max") { - amax = amax_history.max({0}); - } else { - amax = amax_history.slice(0, 1); - } - - const auto rolled_amax_history = amax_history.roll({-1}, {0}); - - auto amax_history_numel = amax_history.numel(); - auto amax_numel = amax.numel(); - size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE; - - const int *current_step_id_ptr = - reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); - auto parameterSetter = [current_step_id_ptr, - fwd_update](phi::backends::gpu::gpuKernelParams ¶ms) { - if (fwd_update) { - int current_step_id = *current_step_id_ptr; - params.As(7) = (current_step_id == 0); - } - }; - - const float *amax_ptr = amax.data(); - const float *rolled_amax_history_ptr = rolled_amax_history.data(); - const bool *non_weight_mask_ptr = non_weight_mask.data(); - float *amax_history_ptr = amax_history.data(); - float *scale_ptr = scale.data(); - float *scale_inv_ptr = scale_inv.data(); - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&UpdateFP8MetaKernel); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - UpdateFP8MetaKernel<<>>( - id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr, scale_ptr, - scale_inv_ptr, update_weight_scale_inv, margin, fp8_max, amax_history_numel, - amax_numel); - NVTE_CHECK_CUDA(cudaGetLastError()); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - NVTE_ERROR( - "amax_and_scale_update_inplace_legacy is not supported in old version of PaddlePaddle\n"); -#endif -} - -void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT - const paddle::Tensor &amax) { - // Copy amax to history[0] - NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()), - cudaMemcpyDeviceToDevice, amax.stream())); -} - -__global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel( - const 
bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen, - int kv_seqlen, bool need_kv) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage q_smem; - __shared__ typename BlockReduce::TempStorage kv_smem; - unsigned int tid = threadIdx.x; - unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen; - - // load mask, convert to 1/0, do accumulation - int q = 0, kv = 0; - for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen; - q_idx += BLOCK_SIZE * kv_seqlen) { - q += (mask[q_idx + batch_offset] ? 0 : 1); - } - - if (need_kv) { - for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) { - kv += (mask[kv_idx + batch_offset] ? 0 : 1); - } - } - __syncthreads(); - - // compute cub::BlockReduce - int q_sum, kv_sum; - q_sum = BlockReduce(q_smem).Sum(q); - if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv); - - // write result for this block to global mem - if (tid == 0) { - q_actual_seqlen[blockIdx.x + 1] = q_sum; - if (need_kv) { - kv_actual_seqlen[blockIdx.x + 1] = kv_sum; - } - } -} - -__global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage smem; - // +1 to ignore the first element - int i = blockIdx.x * blockDim.x + threadIdx.x + 1; - - // load data - int32_t thread_data[1]; - thread_data[0] = i < n ? x[i] : 0; - __syncthreads(); - - // CUB block prefix sum - BlockScan(smem).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - // write result - if (i < n) { - x[i] = thread_data[0]; - } -} - -void mask_to_cu_seqlens(const paddle::Tensor &mask, - paddle::Tensor &q_cu_seqlen, // NOLINT - paddle::optional &kv_cu_seqlen, // NOLINT - int q_seqlen, int kv_seqlen, bool need_kv) { - if (need_kv) { - NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr, - "kv_cu_seqlen must be provided when need_kv is true"); - } - mask_to_actual_seqlens_kernel<<>>( - mask.data(), q_cu_seqlen.data(), - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv); - // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block - // to do prefix sum - NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail"); - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data(), - q_cu_seqlen.numel()); - if (need_kv) { - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>( - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel()); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine - -PD_BUILD_OP(te_gemm) - .Inputs({"A", paddle::Optional("A_scale_inverse"), "B", paddle::Optional("B_scale_inverse"), - paddle::Optional("bias"), "_D", paddle::Optional("_D_scale"), - paddle::Optional("_D_amax"), paddle::Optional("_pre_gelu_out"), "_workspace"}) - .Outputs({"D", paddle::Optional("D_scale"), paddle::Optional("D_amax"), - paddle::Optional("pre_gelu_out"), "workspace"}) - .Attrs({"A_index: int64_t", "B_index: int64_t", "D_index: int64_t", "A_type: int64_t", - "B_type: int64_t", "D_type: int64_t", "bias_type: int64_t", "transa: bool", - "transb: bool", "grad: bool", "workspace_size: int64_t", "accumulate: bool", - "use_split_accumulator: bool", "math_sm_count: int64_t"}) - .SetInplaceMap({{"_D", "D"}, - {paddle::Optional("_D_scale"), paddle::Optional("D_scale")}, - {paddle::Optional("_D_amax"), paddle::Optional("D_amax")}, - {paddle::Optional("_pre_gelu_out"), 
paddle::Optional("pre_gelu_out")}, - {"_workspace", "workspace"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm)); - -PD_BUILD_OP(cast_to_fp8) - .Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8)); - -PD_BUILD_OP(cast_from_fp8) - .Inputs({"Input", "ScaleInv"}) - .Outputs({"Output"}) - .Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8)); - -PD_BUILD_OP(te_transpose) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose)); - -PD_BUILD_OP(te_cast_transpose) - .Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"}) - .Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_CastedOutput", "CastedOutput"}, - {"_TransposedOutput", "TransposedOutput"}, - {"_Amax", "Amax"}, - {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose)); - -PD_BUILD_OP(te_cast_transpose_bgrad) - .Inputs({"GradOutput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"dBias", "CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad)); - -PD_BUILD_OP(te_gelu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu_fp8)); - -PD_BUILD_OP(te_gelu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu)); - -PD_BUILD_OP(te_swiglu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu)); - -PD_BUILD_OP(te_swiglu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8)); - -PD_BUILD_OP(te_dswiglu) - .Inputs({"Grad", "Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu)); - -PD_BUILD_OP(te_cast_transpose_bgrad_dgelu) - .Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad_dgelu)); - -PD_BUILD_OP(te_layernorm_fwd_fp8) - .Inputs({"Input", "Weight", "Bias", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Mu", "Rsigma", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - 
"zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd_fp8)); - -PD_BUILD_OP(te_layernorm_fwd) - .Inputs({"Input", "Weight", "Bias"}) - .Outputs({"Output", "Mu", "Rsigma"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd)); - -PD_BUILD_OP(te_layernorm_bwd) - .Inputs({"Dz", "X", "Mu", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma", "Dbeta"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd)); - -PD_BUILD_OP(te_rmsnorm_fwd) - .Inputs({"Input", "Weight"}) - .Outputs({"Output", "InvVariance"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd)); - -PD_BUILD_OP(te_rmsnorm_fwd_fp8) - .Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "InvVariance", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8)); - -PD_BUILD_OP(te_rmsnorm_bwd) - .Inputs({"Dz", "X", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd)); - -PD_BUILD_OP(te_fused_attn_fwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"), - "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"), - "rng_state"}) - .Outputs({"dQKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - 
"rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV", - paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dKV", "dKV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd)); - -PD_BUILD_OP(te_fused_attn_bwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK", - "_dV", paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dK", "dK"}, - {"_dV", "dV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd)); - -PD_BUILD_OP(te_scaled_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_forward)); - -PD_BUILD_OP(te_scaled_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_backward)); - -PD_BUILD_OP(te_scaled_masked_softmax_forward) - .Inputs({"input", "mask"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_backward)); - 
-PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); - -PD_BUILD_OP(amax_and_scale_update_inplace_legacy) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", - paddle::Optional("current_step_id_tensor")}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float", "margin: float", - "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy)); - -PD_BUILD_OP(amax_and_scale_update_inplace) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"fp8_dtype: int64_t", "margin: float", "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace)); - -PD_BUILD_OP(update_latest_amax_history_inplace) - .Inputs({"_history", "amax"}) - .Outputs({"history"}) - .SetInplaceMap({{"_history", "history"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace)); - -PD_BUILD_OP(mask_to_cu_seqlens) - .Inputs({"mask", "_q_cu_seqlen", paddle::Optional("_kv_cu_seqlen")}) - .Outputs({"q_cu_seqlen", paddle::Optional("kv_cu_seqlen")}) - .Attrs({"q_seqlen: int", "kv_seqlen: int", "need_kv: bool"}) - .SetInplaceMap({{"_q_cu_seqlen", "q_cu_seqlen"}, - {paddle::Optional("_kv_cu_seqlen"), paddle::Optional("kv_cu_seqlen")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::mask_to_cu_seqlens)); diff --git a/transformer_engine/paddle/csrc/extensions.cpp b/transformer_engine/paddle/csrc/extensions.cpp deleted file mode 100644 index 44ad2e7511..0000000000 --- a/transformer_engine/paddle/csrc/extensions.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -size_t get_cublasLt_version() { return cublasLtGetVersion(); } - -PYBIND11_MODULE(transformer_engine_paddle, m) { - // Misc - m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); - m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); - // Data structures - py::enum_(m, "DType", py::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); -} -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py deleted file mode 100644 index 0e91341b80..0000000000 --- a/transformer_engine/paddle/distributed.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Methods needed for distributed training.""" - -import os -import warnings -from contextlib import contextmanager -from typing import Any, Optional, Union, Tuple - -import paddle - -import paddle.distributed.fleet.base.topology as tp -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -from paddle.distributed.fleet.layers.mpu import mp_ops - -try: - # This feature is not supported as of Paddle 2.6. 
- from paddle.distributed.fleet.meta_parallel import ( - PipelineParallelMicroStepLocations, - register_global_pipeline_parallel_hook, - ) -except ImportError: - print("Cannot find register_global_pipeline_parallel_hook !") - register_global_pipeline_parallel_hook = None - -from .constants import dist_group_type - -_weight_split_axis = { - "transformer_engine": {"row": 1, "column": 0}, - "paddle": {"row": 0, "column": 1}, -} - - -def get_tp_group_and_world_size( - tp_group: Union[dist_group_type, None], enable_tp: bool = True -) -> Tuple[Union[dist_group_type, None], int]: - """Get TP group and world size using Fleet API""" - if not (paddle.distributed.is_initialized() and enable_tp): - return None, 1 - model_parallel_group = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group - ) - world_size = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() - if tp_group is None - else tp_group.nranks - ) - """ - When using TP, the NCCL communication needs to be scheduled - before the GEMM for a guaranteed overlap. From the host side - in TE, the comm calls are always launched first, but to ensure - that the GEMM isn't scheduled first, the environment variable - `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a - single channel. - """ - num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) - if num_cuda_work_queues != 1: - warnings.warn( - "To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" - ) - - return model_parallel_group, world_size - - -def is_pp_enabled() -> bool: - """Check if pipeline parallel is enabled""" - if not paddle.distributed.is_initialized(): - return False - - return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1 - - -def register_pp_fwd_begin_hook(forward_begin_hook): - """Register the pp hook if register_global_pipeline_parallel_hook exist""" - if register_global_pipeline_parallel_hook is not None: - register_global_pipeline_parallel_hook( - PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook - ) - - -@contextmanager -def track_rng_state(enable: bool, **kwargs) -> None: - """ - Applies get_rng_state_tracker().rng_state() to the context. - If not enabled, it does nothing. - """ - if enable: - with get_rng_state_tracker().rng_state(**kwargs): - yield - else: - yield - - -def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None: - """Set distributed attributes for the input tensor""" - tensor.is_distributed = is_parallel - if is_parallel: - tensor.split_axis = axis - - -def set_weight_tensor_dist_attr( - tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str -) -> None: - """Set distributed attributes for the weight tensor""" - if not is_parallel or parallel_mode is None: - return - set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode]) - - -def allreduce( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> Tuple[paddle.Tensor, Any]: - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_ - - # All-reduce. 
- if sync_op: - output = mp_ops._mp_allreduce( - input_, - group=tp_group, - use_calc_stream=True, - use_model_parallel=True, - ) - return output, None - - wait_handle = paddle.distributed.all_reduce( - input_, - op=paddle.distributed.ReduceOp.SUM, - group=tp_group, - sync_op=False, - ) - - output = input_ - - return output, wait_handle - - -def allgather( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, - axis: int = 0, -) -> Tuple[paddle.Tensor, Any]: - """All-gather the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - output_shape[axis] = output_shape[axis] * parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op) - if sync_op: - wait_handle.wait() - return output, None - return output, wait_handle - - -def reduce_scatter( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> [paddle.Tensor, Any]: - """Reduce-scatter the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - assert input_.shape[0] % parallelism == 0, ( - f"Input sequence length {input_.shape[0]} can't be divided " - f"exactly by sequence parallelism {parallelism}" - ) - output_shape[0] = output_shape[0] // parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = paddle.distributed.stream.reduce_scatter( - output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op - ) - if sync_op: - return output, None - return output, wait_handle - - -def identity( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, -) -> paddle.Tensor: - """ - Identity when forward. - Allreduce across model parallel group when backward. - """ - output = mp_ops._c_identity(input_, group=tp_group) - - return output - - -def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor): - """ - Set sequence_parallel attribute to input tensor. It is used for registering allreduce - hooks in PaddleNLP sequence parallel training. - """ - setattr(parameter, "sequence_parallel", True) diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py deleted file mode 100644 index 7313a81975..0000000000 --- a/transformer_engine/paddle/fp8.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""FP8 utilities for TransformerEngine""" - -from contextlib import contextmanager -from typing import Tuple, Optional, Dict, Any, Union - -import numpy as np - -import paddle -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.common.recipe import DelayedScaling, Format - -from .constants import dist_group_type -from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer - -__all__ = ["fp8_autocast"] - -# FP8 support -_is_fp8_available = None -_reason_for_no_fp8 = "" - - -def _check_fp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - - # Check GPU arch - arch = paddle.device.cuda.get_device_capability() - if arch >= (9, 0): # hopper and above - return True, "" - if arch < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - - # Special handling for Ada - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if not paddle.version.cuda(): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - if tuple(int(v) for v in paddle.version.cuda().split(".")) < (12, 1): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - return True, "" - - -def is_fp8_available() -> Tuple[bool, str]: - """Return if fp8 support is available""" - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support() - return _is_fp8_available, _reason_for_no_fp8 - - -class FP8State: - """Stores FP8 state""" - - def __init__(self): - self._fp8_enabled = False - self._fp8_calibration = False - self._fp8_recipe = None - self._fp8_distributed_group = None - self._is_first_fp8_module = False - self._fp8_autocast_counter = 0 - self._fp8_autocast_depth = 0 - self._fp8_recompute_enabled = False - self._use_cudagraph = False - self._fp8_fwd_buffer = FP8MetaFwdBuffer() - self._fp8_bwd_buffer = FP8MetaBwdBuffer() - self._fp8_recompute_buffer = FP8RecomputeBuffer() - - def is_fp8_enabled(self) -> bool: - """Is FP8 enabled""" - return self._fp8_enabled - - def is_fp8_calibration(self) -> bool: - """Is FP8 calibration""" - return self._fp8_calibration - - def get_fp8_recipe(self) -> DelayedScaling: - """Return the fp8 recipe""" - return self._fp8_recipe - - @staticmethod - def get_default_fp8_recipe() -> DelayedScaling: - """FP8 recipe with default args.""" - return DelayedScaling() - - def get_autocast_id(self) -> int: - """Returns the number of times of entering the `fp8_autocast` context. - as a unique ID for different training steps.""" - return self._fp8_autocast_counter - - def is_first_fp8_module(self): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. 
- """ - tmp = self._is_first_fp8_module - self._is_first_fp8_module = False - return tmp - - def get_fp8_group(self) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return self._fp8_distributed_group - - def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer: - """Returns global fp8 forward buffer.""" - return self._fp8_fwd_buffer - - def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer: - """Returns global fp8 backward buffer.""" - return self._fp8_bwd_buffer - - def is_fp8_recompute_enabled(self) -> bool: - """Is FP8 recompute enabled""" - return self._fp8_recompute_enabled - - def get_fp8_recompute_buffer(self) -> FP8RecomputeBuffer: - """Returns global fp8 recompute buffer.""" - return self._fp8_recompute_buffer - - def is_cudagraph_enabled(self) -> bool: - """Is CUDAGraph enabled""" - return self._use_cudagraph - - def enable_cudagraph(self): - """Enable CUDA Graphs. Once CUDA Graphs are enabled, they cannot be disabled within the same execution context at current implementation.""" - self._use_cudagraph = True - self._fp8_fwd_buffer.enable_cudagraph() - self._fp8_bwd_buffer.enable_cudagraph() - if self._fp8_recompute_enabled: - raise RuntimeError("Currently, We do not allow recompute with cudagraph") - - def enter( - self, - enabled: bool, - calibrating: bool, - fp8_recipe: Optional[DelayedScaling], - fp8_group: Optional[dist_group_type], - ) -> None: - """Called when entering 'fp8_autocast'""" - self.saved_states = ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) - - self._fp8_enabled = enabled - self._fp8_calibration = calibrating - self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - self._fp8_distributed_group = fp8_group - - if self._fp8_autocast_depth == 0: - self._is_first_fp8_module = True - self._fp8_autocast_counter += 1 - self._fp8_autocast_depth += 1 - - def exit(self): - """Called when exiting 'fp8_autocast'""" - # Restore saved states - ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) = self.saved_states - - self._fp8_autocast_depth -= 1 - - if self._fp8_autocast_depth == 0: - self._fp8_fwd_buffer.finalize() - - -_global_fp8_state = FP8State() - - -def get_global_fp8_state() -> FP8State: - """Get global fp8 state""" - return _global_fp8_state - - -@contextmanager -def fp8_autocast( - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, - fp8_group: Optional[dist_group_type] = None, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. 
- - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` - recipe used for FP8 training. - fp8_group: paddle.distributed.collective.Group, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - try: - _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) - - if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available() - assert fp8_available, reason_for_no_fp8 - yield - finally: - _global_fp8_state.exit() - - -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - update_weight_scale_inv: bool = True, - current_step_id_tensor: Optional[paddle.Tensor] = None, - use_cudagraph: bool = False, -) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask - - if use_cudagraph: - tex.amax_and_scale_update_inplace_legacy( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - current_step_id_tensor=current_step_id_tensor, - update_weight_scale_inv=update_weight_scale_inv, - fwd_update=fwd_update, - fp8_max=fp8_meta[fp8_max_key], - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - else: - if update_weight_scale_inv: - # we pass nullptr into kernel when we need to update_weight_scale_inv - non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)), - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - - else: - raise ValueError( - "We only support the fp8 recipe with 'max' or 'most_recent' " - "amax_compute_algo and default scaling_factor_compute_algo at this " - "moment." 
- ) - - -class FP8TensorMeta: - """Holds FP8 scaling and amax history for FP8 layers""" - - def __init__(self, is_forward: bool): - self.scale = paddle.Tensor() - self.scale_inv = paddle.Tensor() - self.amax_history = paddle.Tensor() - self.non_weight_mask = paddle.Tensor() - self.is_initialized = False - self.is_forward = is_forward - - def get_non_weight_mask(self, num_gemms: int): - """Needed for calculation of scale inverses to - preserve scale_inv when caching FP8 weights""" - if self.is_forward: - # [True, False, True]: -> [input, weight, output] - return paddle.to_tensor([True, False, True] * num_gemms) - # [True, True]: -> [grad_output, grad_input] - return paddle.to_tensor([True, True] * num_gemms) - - def prepare(self, num_gemms: int, amax_history_len: int) -> None: - """Prepare scales and amax tensors. It is called during fprop in each iteration. - If the meta tensors are not initialized yet, initialization is performed. If already - initialized, resize the meta tensors if amax_history_len has changed.""" - - if self.is_initialized: - # Handle changed amax history size. - curr_len = self.amax_history.shape[0] - num_fp8_tensors = self.amax_history.shape[1] - if amax_history_len < curr_len: - self.amax_history = self.amax_history[:amax_history_len] - elif amax_history_len > curr_len: - extra_rows = amax_history_len - curr_len - self.amax_history = paddle.concat( - [ - self.amax_history, - paddle.zeros((extra_rows, num_fp8_tensors), dtype="float32"), - ], - axis=0, - ) - return - - # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and - # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = num_gemms * 3 if self.is_forward else num_gemms * 2 - - self.scale = paddle.ones(num_fp8_tensors, dtype="float32") - self.scale_inv = paddle.ones(num_fp8_tensors, dtype="float32") - self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32") - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True - - def to_numpy(self): - """Convert FP8 meta tensors to numpy.""" - assert self.is_initialized, "FP8TensorMeta is not initialized yet." - return { - "scale": self.scale.numpy(), - "scale_inv": self.scale_inv.numpy(), - "amax_history": self.amax_history.numpy(), - } - - def from_numpy(self, data: Dict[str, np.array]): - """Set FP8 meta tensors from numpy""" - self.scale = paddle.to_tensor(data["scale"]) - self.scale_inv = paddle.to_tensor(data["scale_inv"]) - self.amax_history = paddle.to_tensor(data["amax_history"]) - - num_fp8_tensors = self.scale.shape[0] - num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2 - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py deleted file mode 100644 index 06a9355e72..0000000000 --- a/transformer_engine/paddle/fp8_buffer.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
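Before moving on to the amax buffers below, here is a small plain-Python illustration (no Paddle required, values hypothetical) of the per-GEMM meta layout that `FP8TensorMeta` above encodes: three tensors per GEMM in the forward direction, two in the backward direction, with the weight position masked out so a cached FP8 weight keeps its `scale_inv`.

.. code-block:: python

    def meta_layout(num_gemms: int, is_forward: bool):
        """Mirror of the forward/backward FP8 meta layout described above."""
        if is_forward:
            # Per GEMM: input, weight, output; the weight slot is False so its
            # scale_inv is preserved when FP8 weights are cached.
            return 3 * num_gemms, [True, False, True] * num_gemms
        # Per GEMM: grad_output, grad_input.
        return 2 * num_gemms, [True, True] * num_gemms

    num_tensors, non_weight_mask = meta_layout(num_gemms=2, is_forward=True)
    assert num_tensors == 6
    assert non_weight_mask == [True, False, True, True, False, True]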
-"""FP8 meta buffer for FP8 amax reduction""" - -from abc import ABC, abstractmethod -from collections import deque -from functools import partial -import os -from typing import Dict, Any, List, Union - -import numpy as np -import paddle -from transformer_engine import transformer_engine_paddle as tex - -from .constants import dist_group_type, RecomputeFunctionNames - - -class FP8MetaBufferBase(ABC): - """ - A global buffer that holds FP8 meta for reduction across trainers. - """ - - def __init__(self): - self._global_amax = {} - self._buffer_delete_key = None - self._amax_reduce_wait_func = None - self._dp_amax_reduce_interval = None - self._contiguous_amax = None - self._use_cudagraph = False - self._dp_amax_reduce_idx = 0 - - @staticmethod - @abstractmethod - def _get_meta_tensor_key(): - """Returns scaling key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_buffer_position_key(): - """Returns module position key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_autocast_key(): - """Returns autocast id key in `fp8_meta`.""" - - def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str: - """Return a key in `_global_amax` for the AMAX storage.""" - return f"AMAX_{fp8_meta[self._get_autocast_key()]}" - - def _execute_deletion(self) -> None: - """Delete the key from global amax buffer.""" - if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax: - del self._global_amax[self._buffer_delete_key] - - def _wait_handle_and_split( - self, - contiguous_amax: paddle.Tensor, - chunk_sizes: List[int], - amax_buffer_key: str, - wait_handle: Union[bool, None], - ) -> None: - """Wait for amax reduction to finish and then copy reduced amax to buffer""" - if wait_handle is not None: - wait_handle.wait() - if self._use_cudagraph: - splited_list = list(contiguous_amax.split(chunk_sizes)) - for amax, split in zip(self._global_amax[amax_buffer_key], splited_list): - amax.copy_(split, False) - else: - self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) - - def _global_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" - - def _reduce_tensor_across_group_op_max(tensor, group, sync_op): - if paddle.distributed.is_initialized(): - wait_handle = paddle.distributed.all_reduce( - tensor, - op=paddle.distributed.ReduceOp.MAX, - group=group, - sync_op=sync_op, - ) - return wait_handle - return None - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - # Key already deleted. - if amax_buffer_key not in self._global_amax: - return None - - # Reduce AMAX in DP-domain at an interval. - if self._dp_amax_reduce_interval is None: - self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - tp_amax_reduce = False - reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group. 
- if self._dp_amax_reduce_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval - - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None - - chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]] - if self._use_cudagraph: - # we need to ensure the _contiguous_amax is address-stable under cudagraph - if self._contiguous_amax is None: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - else: - self._contiguous_amax.copy_( - paddle.concat(self._global_amax[amax_buffer_key]), False - ) - else: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - - wait_handle = _reduce_tensor_across_group_op_max( - self._contiguous_amax, - reduce_group, - not fp8_meta["async_amax_reduction"], - ) - - if wait_handle is not None and self._use_cudagraph: - # we need to ensure record/wait does not cross the boundary of the graph - wait_handle.wait() - wait_handle = None - - return partial( - self._wait_handle_and_split, - self._contiguous_amax, - chunk_sizes, - amax_buffer_key, - wait_handle, - ) - - def add_amax(self, fp8_meta: Dict[str, Any]) -> None: - """Append `amax_history` to global buffer.""" - buffer_key = self._get_amax_buffer_key(fp8_meta) - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - - if buffer_key not in self._global_amax: - self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, ( - "Same module is being invoked more than once inside an `fp8_autocast` " - "region when using FP8 with amax reduction. This behavior is currently " - "unsupported. For more details and correct usage, please see " - "https://github.com/NVIDIA/TransformerEngine/pull/93." - ) - - def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - assert amax_buffer_key in self._global_amax, "TE internal error." 
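The reduction above amounts to a concat / all-reduce(MAX) / split round trip over the per-module amax tensors. A condensed, standalone sketch of that pattern (the helper name is hypothetical):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    def reduce_amaxes(amax_tensors, group=None):
        """Concatenate, MAX-all-reduce, and re-split a list of amax tensors."""
        chunk_sizes = [t.shape[0] for t in amax_tensors]
        contiguous = paddle.concat(amax_tensors)
        if dist.is_initialized():
            dist.all_reduce(contiguous, op=dist.ReduceOp.MAX, group=group)
        return list(contiguous.split(chunk_sizes))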
- - # Copy amax to amax_history[0] - tex.update_latest_amax_history_inplace( - _history=fp8_meta[fp8_meta_tensor_key].amax_history, - amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]], - ) - - def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: - """Delete this amax key from global buffer during autocast end.""" - if self._get_autocast_key() not in fp8_meta: - return - self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta) - - def get_amax_reduce_handle(self) -> Union[bool, None]: - """Return AMAX reduction wait handle.""" - return self._amax_reduce_handle - - def wait(self) -> None: - """Wait for reduced amax to be available in buffer.""" - if self._amax_reduce_wait_func is not None: - self._amax_reduce_wait_func() # pylint: disable=not-callable - self._amax_reduce_wait_func = None - - def to_numpy(self) -> Dict[str, List[np.array]]: - """Convert to numpy arrays""" - out = {} - for k, v in self._global_amax.items(): - out[k] = [tensor.numpy() for tensor in v] - return out - - def from_numpy(self, buffer: Dict[str, np.array]) -> None: - """Set buffer values from numpy arrays""" - for k, v in buffer.items(): - self._global_amax[k] = [paddle.to_tensor(arr) for arr in v] - - def enable_cudagraph(self): - """Enable CUDA Graphs.""" - self._use_cudagraph = True - - -class FP8MetaFwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for forward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_fwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_fwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_fwd" - - def set_for_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Sets up the function to call during autocast exit.""" - self._amax_global_reduce_func = partial( - self._global_amax_reduction, - fp8_meta, - tp_group, - tp_size, - ) - - def finalize(self) -> None: - """ - Called at FP8 autocast end. - Performs AMAX reduction and delete unused buffer entries. - """ - if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func): - self._amax_reduce_wait_func = self._amax_global_reduce_func() - self._execute_deletion() - - -class FP8MetaBwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for backward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_bwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_bwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_bwd" - - def finalize( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """ - Called at FP8 autocast end in backward. - Performs AMAX reduction and delete unused buffer entries. 
- """ - self._amax_reduce_wait_func = self._global_amax_reduction( - fp8_meta, tp_group, tp_size - ) # _wait_handle_and_split - self._execute_deletion() - - -class FP8RecomputeBuffer: - """Buffer used to hold FP8 meta tensors for recompute""" - - def __init__(self): - self._global_amax = [] - - @staticmethod - def get_buffer_position_key(): - """Returns the key (in fp8_meta) for recompute buffer position""" - return "recompute_buffer_pos" - - def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Stash the scaling factors and amaxes for recompute""" - buffer_position_key = self.get_buffer_position_key() - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ] - - if buffer_position_key in fp8_meta: - self._global_amax[fp8_meta[buffer_position_key]].append(to_copy) - else: - self._global_amax.append(deque()) - self._global_amax[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(self._global_amax) - 1 - - def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Switch to the previously saved scaling factors and amaxes""" - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv - - # Retrieve stashed amaxes and scales from phase 1 pre forward. - buffer_position_key = self.get_buffer_position_key() - stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - assert "updated_amax_history_fwd" in fp8_meta, ( - "Recompute internal error." - " If you are not using recompute, please check if" - " the forward function is called from one of these functions: " - f"{RecomputeFunctionNames}. If so, consider change the function name " - "or set NVTE_DISABLE_RECOMPUTE=1." - ) - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] diff --git a/transformer_engine/paddle/layer/__init__.py b/transformer_engine/paddle/layer/__init__.py deleted file mode 100644 index 4d81ca231a..0000000000 --- a/transformer_engine/paddle/layer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""Layer level Paddle APIs""" - -from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding -from .layernorm import LayerNorm -from .layernorm_linear import LayerNormLinear -from .layernorm_mlp import LayerNormMLP -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from .transformer import TransformerLayer diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py deleted file mode 100644 index d3b0950dee..0000000000 --- a/transformer_engine/paddle/layer/attention.py +++ /dev/null @@ -1,1161 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Attntion API""" - -import math -import os -import warnings -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None -from transformer_engine import transformer_engine_paddle as tex - -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from ..constants import ( - AttnTypes, - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, - dist_group_type, -) -from ..cpp_extensions import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - mask_to_cu_seqlens, -) -from ..distributed import get_tp_group_and_world_size, track_rng_state -from ..utils import attention_mask_func, divide -from ..recompute import recompute - -__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"] - - -def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: - """ - Used to repeat the key and value states for GQA. - The hidden states go from (batch, seqlen, num_gqa_groups, head_size) - to (batch, seqlen, num_heads, head_size) - """ - batch, seqlen, num_gqa_groups, head_size = hidden_states.shape - if n_rep == 1: - return hidden_states - - hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) - return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size]) - - -class RotaryPositionEmbedding(paddle.nn.Layer): - """ - Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. 
- """ - - def __init__( - self, - dim: int, - max_position_embeddings: int, - ): - """ - Parameters - ---------- - dim: int - rotary embedding dimension - max_position_embeddings: int - max_position_embeddings before position interpolation - """ - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.inv_freq = 1.0 / ( - 10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim) - ) - self._set_cos_sin_cache(seq_len=max_position_embeddings) - - def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - # [seq_len] - t = paddle.arange(seq_len, dtype="float32") - # [seq_len, dim/2] - freqs = paddle.einsum("i,j->ij", t, self.inv_freq) - # [seq_len, dim] - emb = paddle.concat([freqs, freqs], axis=-1) - # [1, seqlen, 1, dim] - self.cos_cached = emb.cos()[None, :, None, :] - self.sin_cached = emb.sin()[None, :, None, :] - - def forward(self, max_seq_len: int): - """ - Create rotary position embedding frequencies - - Parameters - ---------- - max_seq_len: int - sequence length of a sample - """ - cos = self.cos_cached[:, :, :max_seq_len, ...] - sin = self.sin_cached[:, :, :max_seq_len, ...] - return (cos, sin) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return paddle.concat([-x2, x1], axis=-1) # shape is the same as x - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): - """Applies rotary positional embedding to the input.""" - - if position_ids is None: - # Note: Only for LlamaForCausalLMPipe model pretraining - cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - else: - cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] - sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed QKV input""" - - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - attn_bias, - max_seqlen, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed QKV input""" - out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - is_training, - max_seqlen, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux) - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with packed QKV input""" - qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor() - dqkv, *rest = fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen, 
- ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dqkv - if ctx.attn_bias_type == "no_bias": - return (dqkv, None) - # else, return (dqkv, dbias) - return (dqkv, None, rest[0]) - - -class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed KV input""" - - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed KV input""" - out, softmax_aux, rng_state = fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with packed KV input""" - q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dkv, *rest = fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dq, dkv - if ctx.attn_bias_type == "no_bias": - return (dq, dkv, None, None) - # else, return (dq, dkv, dbias) - return (dq, dkv, None, None, rest[0]) - - -class FusedAttnFunc(paddle.autograd.PyLayer): - """Function for FusedAttention with separate Q, K, V tensors""" - - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with separate Q, K, V tensors""" - out, softmax_aux, rng_state = fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - 
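Stepping back briefly to the rotary-embedding helpers defined earlier in this file (`RotaryPositionEmbedding`, `rotate_half`, `apply_rotary_pos_emb`), the transformation they implement is `q_rot = q * cos + rotate_half(q) * sin` (and likewise for `k`). A toy-shaped sketch, with all sizes hypothetical:

.. code-block:: python

    import paddle

    b, s, h, d = 1, 4, 2, 8                  # hypothetical batch/seq/heads/head_dim
    rope = RotaryPositionEmbedding(dim=d, max_position_embeddings=16)
    cos, sin = rope(max_seq_len=s)           # broadcastable [1, *, 1, d] caches
    q = paddle.randn([b, s, h, d])
    k = paddle.randn([b, s, h, d])
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)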
- return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with separate Q, K, V tensors""" - q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dk, dv, *rest = fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - # if no_bias, return dq, dk, dv - if ctx.attn_bias_type == "no_bias": - return (dq, dk, dv, None, None) - # else, return (dq, dk, dv, dbias) - return (dq, dk, dv, None, None, rest[0]) - - -class DotProductAttention(paddle.nn.Layer): - """ - Allows the model to jointly attend to information from different - representation subspaces as described in the paper: - `Attention Is All You Need `_. - - .. note:: - - Argument :attr:`attention_mask` will be ignored in the `forward` call when - :attr:`attn_mask_type` is set to `"causal"`. - - .. warning:: - - Fused attention backward uses a non-deterministic algorithm when workspace - optimization is not enabled. To use a deterministic algorithm, set the - environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` - - Parameters - ---------- - num_attention_heads: int - number of attention heads in the transformer layer. - kv_channels: int - number of channels in the key and value tensors. - num_gqa_groups : Optional[int] = None - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. 
- """ - - def __init__( - self, - num_attention_heads: int, - kv_channels: int, - num_gqa_groups: Optional[int] = None, - attention_dropout: float = 0.1, - attn_mask_type: str = "causal", - attention_type: str = "self", - tp_size: int = 1, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.attn_mask_type = attn_mask_type - self.attention_dropout = attention_dropout - self.attention_type = attention_type - self.qkv_layout = "bshd_bshd_bshd" - self.hidden_size_per_attention_head = kv_channels - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - self.tp_size = tp_size - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) - self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups - - self.backend = backend - - self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) - - self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) - - # To use the workspace optimization path for determinism, please - # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, - # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. - cudnn_version = paddle.get_cudnn_version() - if 8905 <= cudnn_version < 9000: - if self.deterministic: - # workspace optimization path is deterministic - os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" - - # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT - # - unset: enables workspace optimization when required workspace is <= 256MB - # or when bias gradient needs to be computed - # - n: enables workspace optimization when required workspace is <= n bytes - # - -1: enables workspace optimization always - # - 0: disables workspace optimization always - if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - - if not self.use_fused_attention and backend == "transformer_engine": - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - self.backend = "paddle" - - if self.backend != "transformer_engine": - self.scale_mask_softmax = FusedScaleMaskSoftmax( - attn_mask_type, attention_mask_func, backend=self.backend - ) - - def forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - """ - Dot Product Attention Layer. - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` - is set to `"causal"`. - - - Parameters - ---------- - query_layer : paddle.Tensor - Query tensor. - key_layer : paddle.Tensor - Key tensor. - value_layer : paddle.Tensor - Value tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. - core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. 
- """ - - backend = self.backend - - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" - - if backend == "transformer_engine": - max_s_q = query_layer.shape[1] - max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1] - self.fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype], - TE_DType[query_layer.dtype], - tex.get_nvte_qkv_layout(self.qkv_layout), - AttnBiasType[core_attention_bias_type], - AttnMaskType[self.attn_mask_type], - self.attention_dropout, - query_layer.shape[-2], - key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], - max_s_q, - max_s_kv, - query_layer.shape[-1], - ) - - is_backend_avail = self.fused_attention_backend in [ - FusedAttnBackend["F16_max512_seqlen"], - FusedAttnBackend["F16_arbitrary_seqlen"], - ] - if is_backend_avail and self.use_fused_attention: - return self._te_forward( - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - ) - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - backend = "paddle" - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.attn_mask_type, attention_mask_func, backend=backend - ) - if backend == "paddle": - if core_attention_bias_type != "no_bias": - warnings.warn( - "Paddle backend dot product attention does not support bias yet. " - "Bias will be ignored." - ) - return self._pd_forward(query_layer, key_layer, value_layer, attention_mask) - raise AttributeError(f"Backend {backend} is not supported.") - - def _te_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - - if self.attention_type == "self": - # self attention - q: [b, s, h, d] kv: None - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), "q,k,v shape must be [b, s, h, d] for dot product self attention" - max_seqlen = query_layer.shape[1] - if self.attn_mask_type == "causal" or attention_mask is None: - cu_seqlens = paddle.arange( - 0, - (query_layer.shape[0] + 1) * query_layer.shape[1], - step=query_layer.shape[1], - dtype="int32", - ) - else: - cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False) - qkv_dtype = TE_DType[query_layer.dtype] - - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens, - cu_seqlens, - core_attention_bias, - max_seqlen, - max_seqlen, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - elif self.attention_type == "cross": - # cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d] - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), ( - "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" - "for dot product cross attention" - ) - assert attention_mask is not None, "attention_mask must be provided for cross attention" - max_seqlen_q = query_layer.shape[1] - 
max_seqlen_kv = key_layer.shape[1] - cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) - qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens_q, - cu_seqlens_kv, - core_attention_bias, - max_seqlen_q, - max_seqlen_kv, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - else: - raise ValueError("attention_type must be one of ['self', 'cross']") - return output - - def _pd_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - ) -> paddle.Tensor: - - q = query_layer - k = repeat_kv(key_layer, self.num_queries_per_key_value) - v = repeat_kv(value_layer, self.num_queries_per_key_value) - - q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) - k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) - v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) - - product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True) - attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None) - - if self.attention_dropout > 0: - attention_probs = F.dropout( - attention_probs, - self.attention_dropout, - training=self.training, - ) - - out = paddle.matmul(attention_probs, v) - out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] - # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) - return out - - -class MultiHeadAttention(paddle.nn.Layer): - """ - Multi-head Attention (MHA), including Query, - Key, Value and Output projection. - - Parameters - ---------- - hidden_size: int - hidden size of the model. - num_attention_heads: int - number of attention heads. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - layernorm_epsilon: float, default = 1e-5 - epsilon to use in the layer norm operations. - weight_attr: Union[paddle.ParamAttr, None], default = `None` - paddle.ParamAttr object for the weight parameter. - bias_attr: Union[paddle.ParamAttr, None, bool], default = `None` - paddle.ParamAttr object for the bias parameter. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - params_dtype: Optional[paddle.dtype], default = `None` - data type for the weights and biases. - return_layernorm_output: bool, default = `False` - whether to return the output of the layernorm operation. - input_layernorm: bool, default = `False` - whether to apply layernorm to the input. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - zero_centered_gamma: bool, default = `False` - whether to zero initialize the gamma of the layernorm operation. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. If set to 'paddle', a framework - only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. 
- sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - num_gqa_groups : int, default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the querys. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. The specified - name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - attention_dropout: float = 0.1, - layernorm_epsilon: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - return_layernorm_output: bool = False, - input_layernorm: bool = False, - attention_type: str = "self", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - num_gqa_groups: Optional[int] = None, - fuse_wgrad_accumulation: bool = False, - rng_state_name: str = "local_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.input_layernorm = input_layernorm - self.attention_type = attention_type - self.return_layernorm_output = return_layernorm_output - self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.max_sequence_length = max_sequence_length - self.weight_attr = weight_attr - self.bias_attr = bias_attr - self.attn_mask_type = attn_mask_type - - assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" - - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - self.num_attention_heads = num_attention_heads - self.set_parallel_mode = set_parallel_mode - self.rng_state_name = rng_state_name - self.backend = backend - - self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - assert ( - self.num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" 
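As a concrete, hypothetical example of the head-partitioning arithmetic used above and by the projection widths below:

.. code-block:: python

    # Hypothetical configuration: 4096 hidden, 32 heads, 8 GQA groups, tensor-parallel size 2.
    hidden_size, num_attention_heads, num_gqa_groups, tp_size = 4096, 32, 8, 2

    hidden_size_per_attention_head = hidden_size // num_attention_heads   # 128
    num_attention_heads_per_partition = num_attention_heads // tp_size    # 16
    num_gqa_groups_per_partition = num_gqa_groups // tp_size              # 4
    hidden_size_kv = hidden_size * num_gqa_groups // num_attention_heads  # 1024

    # Fused QKV projection width for self-attention:
    assert hidden_size + 2 * hidden_size_kv == 6144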
- assert ( - self.num_gqa_groups % self.tp_size == 0 - ), "The number of GQA groups must be divisible by tensor parallel size!" - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) - self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads) - qkv_parallel_mode = "column" if set_parallel_mode else None - - if self.attention_type == "self": - if self.input_layernorm: - self.layernorm_qkv = LayerNormLinear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.qkv = Linear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - else: # cross attention - if self.input_layernorm: - self.layernorm_query = LayerNormLinear( - hidden_size, - hidden_size, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.query_layer = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - self.key_value = Linear( - hidden_size, - 2 * self.hidden_size_kv, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - # Attention. - self.core_attention = DotProductAttention( - self.num_attention_heads, - self.hidden_size_per_attention_head, - self.num_gqa_groups, - attention_dropout, - attn_mask_type=attn_mask_type, - attention_type=self.attention_type, - tp_size=self.tp_size, - backend=self.backend, - ) - - # Linear - self.proj = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode="row" if set_parallel_mode else None, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """ - MultiHeadAttention Layer. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. 
- attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder layer. - rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied. - core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - input_dim = len(hidden_states.shape) - if input_dim == 2: - # hidden_states: [b * s_q, hidden_size] - # need to get max_seq_len from attention_mask - assert self.max_sequence_length is not None, "max_sequence_length must be provided" - max_seq_len = self.max_sequence_length - elif input_dim == 3: - # hidden_states: [b, s_q, hidden_size] - max_seq_len = hidden_states.shape[1] - else: - raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.") - - layernorm_output = None - if self.attention_type == "self": - if self.input_layernorm: - layernorm_qkv_outputs = self.layernorm_qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs - else: - mixed_qkv_layer = layernorm_qkv_outputs - else: - mixed_qkv_layer = self.qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - num_queries_per_key_value = ( - self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition - ) - - # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d] - mixed_qkv_layer = mixed_qkv_layer.reshape( - shape=[ - -1, - max_seq_len, - (num_queries_per_key_value + 2), - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_q, (h/ng+2), ng, d] - # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d] - query_layer, key_layer, value_layer = paddle.split( - mixed_qkv_layer, - num_or_sections=(num_queries_per_key_value, 1, 1), - axis=2, - ) - - # query: -> [b, s, h, d] - # key, value: -> [b, s, ng, d] - query_layer, key_layer, value_layer = ( - x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) - for x in (query_layer, key_layer, value_layer) - ) - - else: # cross attention - mixed_kv_layer = self.key_value( - encoder_output, - 
is_first_microbatch=is_first_microbatch, - ) - # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape( - shape=[ - 0, - 0, - 2 * self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_kv, 2 * ng, head_size] - # --> 2 [b, s_kv, ng, head_size] - key_layer, value_layer = paddle.split( - mixed_kv_layer, - num_or_sections=2, - axis=2, - ) - - if self.input_layernorm: - layernorm_query_outputs = self.layernorm_query( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - query_layer, layernorm_output = layernorm_query_outputs - else: - query_layer = layernorm_query_outputs - else: - query_layer = self.query_layer( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - # [b, s, hidden_size] --> [b, s, h, d] - query_layer = query_layer.reshape( - shape=[ - -1, - max_seq_len, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - if fused_rotary_position_embedding is None: - query_layer, key_layer = apply_rotary_pos_emb( - query_layer, key_layer, q_pos_emb, k_pos_emb - ) - else: - query_layer, key_layer, _ = fused_rotary_position_embedding( - query_layer, - key_layer, - v=None, - sin=k_pos_emb, - cos=q_pos_emb, - position_ids=None, - use_neox_rotary_style=False, - ) - - with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): - if recompute_core_attention: - context_layer = recompute( - self.core_attention, - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - use_reentrant=False, - ) - else: - context_layer = self.core_attention( - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) - - if input_dim == 3: - context_layer = paddle.reshape( - context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]] - ) - else: # input_dim == 2 - context_layer = paddle.reshape( - context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]] - ) - - # Output. [b, s, hidden] - attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch) - - if self.input_layernorm and self.return_layernorm_output: - return attention_output, layernorm_output - return attention_output diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py deleted file mode 100644 index a854bb70db..0000000000 --- a/transformer_engine/paddle/layer/base.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
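Wrapping up the attention module above, a minimal usage sketch of the removed `DotProductAttention` layer; shapes are hypothetical and running it requires a pre-2.0 build with the Paddle extension and a CUDA device.

.. code-block:: python

    import paddle
    from transformer_engine.paddle import DotProductAttention  # removed in this PR

    b, s, h, d = 2, 128, 16, 64              # hypothetical batch/seq/heads/head_dim
    q = paddle.randn([b, s, h, d]).astype("float16")
    k = paddle.randn([b, s, h, d]).astype("float16")
    v = paddle.randn([b, s, h, d]).astype("float16")

    attn = DotProductAttention(num_attention_heads=h, kv_channels=d, attn_mask_type="causal")
    out = attn(q, k, v)  # [b, s, h, d]; the output projection lives in MultiHeadAttention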
-"""Base modules and utilities for TransformerEngine Paddle API""" - -from abc import ABC, abstractmethod -from contextlib import contextmanager -import os -import pickle -from typing import Generator, Dict, Tuple, Union, Any, List, Optional - -import numpy as np - -import paddle - -try: - from paddle.base import core - from paddle.base.framework import _dygraph_tracer -except ImportError: - from paddle.fluid import core - from paddle.fluid.framework import _dygraph_tracer - -from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose -from ..fp8 import ( - FP8State, - FP8TensorMeta, - amax_and_scale_update, - get_global_fp8_state, - get_fp8_te_dtype, -) -from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled -from ..profile import nvtx_range -from ..recompute import is_in_recompute_phase -from ..fp8_buffer import FP8RecomputeBuffer - -_2X_ACC_FPROP = False -_2X_ACC_DGRAD = True -_2X_ACC_WGRAD = True -_cublas_workspace = None - - -def get_cublas_workspace_size_bytes() -> None: - """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if paddle.device.cuda.get_device_capability()[0] >= 9: - return 33_554_432 - return 4_194_304 - - -def get_workspace() -> paddle.Tensor: - """Returns workspace for cublas.""" - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = paddle.empty( - [get_cublas_workspace_size_bytes()], - dtype="uint8", - ) - return _cublas_workspace - - -class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): - """Base TE Layer.""" - - def __init__(self) -> None: - super().__init__() - assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA." - self.fp8_initialized = False - self.fp8_enabled = False - self.fp8_calibration = False - self.fp8_meta = {} - self.fp8_meta["fp8_checkpoint"] = False - self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe() - self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) - self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.fp8_meta["autocast_id_fwd_stack"] = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) - # weights that stored in fp16 would be cast into fp8 every first microstep - self.fp8_weights = [] - self.fp8_weight_cache = {} - self.registered_pp_start_callback = False - self.current_step_id = None - - def set_activation_dtype(self, inp: paddle.Tensor) -> None: - """Get activation data type for AMP.""" - tracer = _dygraph_tracer() - if tracer and tracer._amp_level != core.AmpLevel.O0: - # Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context - if tracer._amp_dtype == "float32": - self.activation_dtype = paddle.float32 - elif tracer._amp_dtype == "bfloat16": - self.activation_dtype = paddle.bfloat16 - elif tracer._amp_dtype == "float16": - self.activation_dtype = paddle.float16 - else: - raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.") - else: - # If not under paddle.amp.auto_cast, set activation_dtype to the input dtype. - # Also, make sure the parameters match the input dtype. - - # Skip the check if activation_dtype is already set and if activation_dtype - # matches input dtype. If they do not match, e.g, when user switch from AMP - # training to normal training, activation_dtype will still be updated. 
- if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: - return - - dtype = inp.dtype - - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - - self.activation_dtype = dtype - - # This routine is shared across FP8 and FP8_calibration paths so should not actually - # assume FP8 execution. - def fp8_init(self, num_gemms: int = 1) -> None: - """Initialize fp8 related metadata and tensors during fprop.""" - global_fp8_state = get_global_fp8_state() - self.fp8_enabled = global_fp8_state.is_fp8_enabled() - self.fp8_calibration = global_fp8_state.is_fp8_calibration() - self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration - - if self.fp8_enabled or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. - if ( - self.fp8_initialized - and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"] - ): - return - - # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe() - self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group() - - # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd - - # Allocate scales and amaxes - amax_history_len = self.fp8_meta["recipe"].amax_history_len - self.fp8_meta["scaling_fwd"].prepare(num_gemms, amax_history_len) - self.fp8_meta["scaling_bwd"].prepare(num_gemms, amax_history_len) - self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False - return - - def set_fp8_weights(self) -> None: - """Initializes FP8 weights for the module""" - if not self.fp8_enabled: - return - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - if ( - weight_cast_key in self.fp8_weight_cache - and self.fp8_weight_cache[weight_cast_key].shape == weight.shape - ): - return - - self.fp8_weight_cache[weight_cast_key] = paddle.empty( - shape=weight.shape, - dtype=paddle.uint8, - ) - - self.fp8_weight_cache[weight_transpose_key] = paddle.empty( - shape=[weight.shape[1], weight.shape[0]], - dtype=paddle.uint8, - ) - - def _get_fp8_state(self) -> paddle.Tensor: - """Dump FP8 state to paddle.Tensor.""" - state = None - if self.fp8_meta["fp8_checkpoint"]: - state = {} - state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy() - state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy() - state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy() - state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy() - # Store other pickelable values. 
- extra = {} - for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str)): - extra[k] = v - state["extra_fp8_variables"] = extra - - state_serialized = pickle.dumps(state) - state_tensor = paddle.to_tensor(np.frombuffer(state_serialized, dtype=np.uint8)) - - return state_tensor - - @paddle.no_grad() - def state_dict( - self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - use_hook=True, - ): - """Save FP8 State when checkpointing.""" - st = super().state_dict( - destination=destination, - include_sublayers=include_sublayers, - structured_name_prefix=structured_name_prefix, - use_hook=use_hook, - ) - st["fp8_state"] = self._get_fp8_state() - return st - - def _set_fp8_state(self, state: paddle.Tensor) -> None: - """Load previous state.""" - if state is None: - return - - state = pickle.loads(state.numpy().tobytes()) - if state is None: - return - - # Load fp8 meta tensors. - self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"]) - self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"]) - - # Restore global FP8 buffer states. - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer() - global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"]) - global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"]) - - # Load extra items. - self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ - 0 - ] - recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key() - if recompute_buffer_pos_key in self.fp8_meta: - del self.fp8_meta[recompute_buffer_pos_key] - - @paddle.no_grad() - def set_state_dict(self, state_dict, use_structured_name=True): - """Restore FP8 State from checkpoint.""" - fp8_state_tensor = state_dict.pop("fp8_state") - self._set_fp8_state(fp8_state_tensor) - - return super().set_state_dict(state_dict) - - @contextmanager - def prepare_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Union[bool, None], - num_gemms: int = 1, - ) -> Generator[paddle.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - - if self.fp8_enabled and is_in_recompute_phase(): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.retrieve_fp8_meta_tensors(self.fp8_meta) - else: - self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) - - # Create persistent tensors for fp8 weights and their transposes - # only when fp8 weight caching is used. - if is_first_microbatch is not None: - self.set_fp8_weights() - - if self.fp8_enabled and self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) - - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch - - # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_fwd_buffer.wait() - # Register PP forward begin hook when CUDAGraph is enabled. 
- # NOTE(tizheng): register_pp_fwd_begin_hook prevents layer parameters from being freed - # when the layer object is deleted. Need to find a better way. - if get_global_fp8_state().is_cudagraph_enabled() and self.current_step_id is None: - self.current_step_id = paddle.to_tensor( - [1], dtype=paddle.int32, place=paddle.CPUPlace() - ) - - def current_step_id_callback( - step_id=None, **kwargs - ): # pylint: disable=unused-argument - self.current_step_id.copy_( - paddle.to_tensor( - [step_id], dtype=paddle.int32, place=paddle.CPUPlace() - ), - True, - ) - - if is_pp_enabled(): - register_pp_fwd_begin_hook(current_step_id_callback) - - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) - else: - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - if self.fp8_enabled and self.training: - # Setup for amax reduction - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module() - self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id() - self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"]) - self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False - - # Activation recomputation is used and this is the first forward phase. - if ( - self.fp8_enabled - and self.training - and get_global_fp8_state().is_fp8_recompute_enabled() - ): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta) - - with nvtx_range(self.__class__.__name__ + " forward"): - yield inp - - if self.fp8_enabled and is_in_recompute_phase(): - FP8RecomputeBuffer.restore_fp8_meta_tensors(self.fp8_meta) - return - - if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer() - global_fp8_fwd_buffer.add_amax(self.fp8_meta) - global_fp8_fwd_buffer.set_for_amax_reduction( - self.fp8_meta, - self.tp_group, - self.tp_size, - ) - - @staticmethod - @contextmanager - def prepare_backward( - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "", - ) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8_enabled: - global_fp8_state = get_global_fp8_state() - global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer() - global_fp8_bwd_buffer.wait() - - if fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta) - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_bwd_buffer.set_for_deletion(fp8_meta) - - # Get new backward key. 
- fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - else: - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - with nvtx_range(name + " backward"): - yield - - if fp8_enabled and fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.add_amax(fp8_meta) - if fp8_meta["first_module"]: - global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) - - @staticmethod - def grad_output_preprocess( - ctx, grad_output: paddle.Tensor, row_parallel_mode: bool - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """Utility function for backward. - Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. - """ - grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1])) - gather_grad_output = row_parallel_mode and ctx.sequence_parallel - - # No-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8_enabled: - if gather_grad_output: - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - if gather_grad_output: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - if ctx.use_bias: - bgrad = grad_output_mat.sum(axis=0) - else: - bgrad = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - grad_output_c, _ = allgather(grad_output_c, ctx.tp_group) - grad_output_t = transpose(grad_output_c, fp8_dtype_backward) - - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - # FP8 case with gather and non-FP8 wgrad - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - - # FP8 case without gather: cast, transpose, bgrad fused - if ctx.use_bias: - bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - grad_output_c, grad_output_t = cast_transpose( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_t = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - bgrad = None - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - @abstractmethod - def forward(self): - """Needs override.""" - - def get_fp8_weights_scratchpad_and_cast( - self, - is_first_microbatch: Union[bool, None], - ) -> List[Optional[paddle.Tensor]]: - """ - Fetch the fp8 weight tensor placeholders if they exist (when - `is_first_microbatch` is not `None`) - """ - if not self.fp8_enabled or is_first_microbatch is None: - return [None, None] * len(self.fp8_weights) - - out_list = [] - for i, _ in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - # Disable fp8 weight cache - # is_first_microbatch is 
None -> we cast the weights into fp8 every micro step - # Enalbe fp8 weight cache - # is_first_microbatch == true -> we cast the weights into fp8 every micro step - - out_list.extend([weight_fp8, weight_t_fp8]) - - # is cudagraph is enabled we cast the weight before the pp pipe - # we only register the callback once - if get_global_fp8_state().is_cudagraph_enabled() and ( - not self.registered_pp_start_callback and is_pp_enabled() - ): - - fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) - - def cast_callback(step_id=None, **kwargs): # pylint: disable=unused-argument - update_fp8_weights = step_id == 0 - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - if paddle.is_grad_enabled(): - if update_fp8_weights: - cast_transpose( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - if update_fp8_weights: - cast_to_fp8( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - fp8_dtype_forward, - out=weight_fp8, - ) - - cast_callback(0 if is_first_microbatch else 1) - register_pp_fwd_begin_hook(cast_callback) - self.registered_pp_start_callback = True - return out_list diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py deleted file mode 100644 index be12b6534f..0000000000 --- a/transformer_engine/paddle/layer/layernorm.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
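A minimal sketch of how the FP8 weight cache in the deleted base layer keys and sizes its buffers: one uint8 cast buffer per weight plus one uint8 transpose buffer with the two dimensions swapped. The shapes and the plain dict standing in for the layer state are assumptions for illustration only:

import paddle

fp8_weights = [paddle.randn([256, 128])]  # hypothetical layer weights
fp8_weight_cache = {}
for i, weight in enumerate(fp8_weights, start=1):
    # Cast buffer keeps the weight's shape; transpose buffer swaps its dimensions.
    fp8_weight_cache[f"weight{i}_fp8"] = paddle.empty(shape=weight.shape, dtype=paddle.uint8)
    fp8_weight_cache[f"weight{i}_t_fp8"] = paddle.empty(
        shape=[weight.shape[1], weight.shape[0]], dtype=paddle.uint8
    )
assert list(fp8_weight_cache["weight1_t_fp8"].shape) == [128, 256]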
-"""Linear API""" - -import os -from typing import Union, Tuple - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import layernorm_fwd, layernorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["LayerNorm"] - - -class _LayerNorm(paddle.autograd.PyLayer): - """TE Non-FP8 LayerNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: paddle.Tensor, - eps: float, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "LayerNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - ln_out, mu, rsigma = layernorm_fwd( - inputmat, - ln_weight, - ln_bias, - eps, - TE_DType[inp.dtype], - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not ln_weight.stop_gradient - ctx.requires_dbias = not ln_bias.stop_gradient - return ln_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - inputmat, ln_weight, mu, rsigma = ctx.saved_tensor() - d_ln_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma, dbeta = layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - dbeta if ctx.requires_dbias else None, - ) - - -class LayerNorm(paddle.nn.Layer): - r""" - Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta - - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for softmax operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. 
- """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ) - - self._bias_attr = bias_attr - if self._bias_attr is False: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0), trainable=False) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - self.bias = self.create_parameter( - shape=[hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - mark_as_sequence_parallel_parameter(self.bias) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - """LayerNorm FWD""" - return _LayerNorm.apply( - inp, - self.weight, - self.bias, - self.eps, - self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - return F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.weight, - bias=self.bias, - epsilon=self.eps, - ) - - def forward(self, *args, **kwargs): - """forward""" - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py deleted file mode 100644 index 57c91238e6..0000000000 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ /dev/null @@ -1,721 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
-"""LayerNormLinear API""" - -import warnings -import os -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - layernorm_fwd, - layernorm_fwd_fp8, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, -) - -from .base import TransformerEngineBaseLayer -from .linear import _linear_fwd, _linear_bwd -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormLinear"] - - -def _apply_normalization_fwd( - normalization: str, - inputmat: paddle.Tensor, - norm_weight: paddle.Tensor, - norm_bias: Union[paddle.Tensor, None], - out_fp8_index: FP8FwdTensors, - eps: float, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_norm_output: bool, - fwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - """Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path""" - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert norm_bias is None, "RMSNorm does not support bias!" - norm_weight = cast_if_needed_inplace(norm_weight, activation_dtype) - if norm_bias is not None: - norm_bias = cast_if_needed_inplace(norm_bias, activation_dtype) - - norm_kwargs = { - "inp": inputmat, - "weight": norm_weight, - "eps": eps, - "otype": TE_DType[activation_dtype], - "sm_margin": fwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - - fwd_normalization_funcs = { - ("LayerNorm", True, True): layernorm_fwd, - ("LayerNorm", True, False): layernorm_fwd_fp8, - ("LayerNorm", False, True): layernorm_fwd, - ("LayerNorm", False, False): layernorm_fwd, - ("RMSNorm", True, True): rmsnorm_fwd, - ("RMSNorm", True, False): rmsnorm_fwd_fp8, - ("RMSNorm", False, True): rmsnorm_fwd, - ("RMSNorm", False, False): rmsnorm_fwd, - } - - if normalization == "LayerNorm": - norm_kwargs["bias"] = norm_bias - norm_fwd_func = fwd_normalization_funcs[(normalization, fp8_enabled, return_norm_output)] - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if not return_norm_output: - fp8_kwargs = { - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "fp8_tensor": out_fp8_index, - "otype": fp8_dtype_forward, - } - norm_kwargs.update(fp8_kwargs) - - out_tuple = norm_fwd_func(**norm_kwargs) - - if normalization == "LayerNorm": - norm_out_return, mu, rsigma = out_tuple - else: # RMSNorm - norm_out_return, rsigma = out_tuple - mu = None - - if fp8_enabled and return_norm_output: - norm_out = cast_to_fp8( - norm_out_return, - fp8_meta["scaling_fwd"], - out_fp8_index, - fp8_dtype_forward, - ) - else: - norm_out = norm_out_return - - return ( - norm_out_return, - norm_out, - mu, - rsigma, - ) - - -def _apply_normalization_bwd( - normalization: str, - inputmat: paddle.Tensor, - dgrad: paddle.Tensor, - norm_weight: paddle.Tensor, - mu: Union[paddle.Tensor, None], - rsigma: paddle.Tensor, - grad_norm_out_return: paddle.Tensor, - 
return_norm_output: bool, - bwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert mu is None, "RMSNorm does not support bias!" - # LayerNorm gradient - d_norm_out = dgrad.reshape(inputmat.shape) - # Residual gradient - if return_norm_output: - d_norm_out = d_norm_out + grad_norm_out_return.reshape(d_norm_out.shape) - - norm_bwd_func = layernorm_bwd if normalization == "LayerNorm" else rmsnorm_bwd - norm_bwd_kwargs = { - "dz": d_norm_out, - "x": inputmat, - "rsigma": rsigma, - "gamma": norm_weight, - "sm_margin": bwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - if normalization == "LayerNorm": - norm_bwd_kwargs["mu"] = mu - - out_tuple = norm_bwd_func(**norm_bwd_kwargs) - if normalization == "LayerNorm": - dxmat, dgamma, dbeta = out_tuple - else: # RMSNorm - dxmat, dgamma = out_tuple - dbeta = None - - return dxmat, dgamma, dbeta - - -class _LayerNormLinear(paddle.autograd.PyLayer): - """TE implementation of LayerNormLinear""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: Union[paddle.Tensor, None], - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - bias: Union[paddle.Tensor, None], - use_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" 
- # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - # Linear Fwd - out, weight_t_fp8 = _linear_fwd( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8 if fp8_enabled else None, - ln_out, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - if return_layernorm_output: - return out, ln_out_return.reshape(inp.shape) - return out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, ...] 
- ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8, - ln_out, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" - ) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - # Prepare ln_out for Linear bwd - linear_inputmat = ln_out - if ctx.fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - if ctx.requires_wgrad and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - linear_inputmat = cast_from_fp8( - ln_out, - ctx.fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - - # Linear Bwd - dgrad, wgrad, bgrad_ = _linear_bwd( - linear_inputmat, - None, # inputmat_t will be automatically computed if not provided - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - True, # Always compute dgrad to feed into LayerNorm bwd - ctx.requires_wgrad, - ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - bgrad = bgrad if ctx.requires_bgrad else None - bgrad_out = (bgrad,) if ctx.use_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - *bgrad_out, - ) - - -class LayerNormLinear(TransformerEngineBaseLayer): - r""" - Applies layer normalization followed by linear transformation to the incoming data. - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. 
- normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module is - taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - """ - - def __init__( - self, - in_features: int, - out_features: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" 
- self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.in_features], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - shape=[self.in_features], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # Initialize Linear weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - self.fp8_weights.append(self.weight) - - # Initialize Linear bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. 
- self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _LayerNormLinear.apply( - inp, - self.ln_weight, - self.ln_bias, - self.weight, - weight_fp8, - weight_t_fp8, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - - if self.parallel_mode == "column" and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a linear transformation. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. 
Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py deleted file mode 100644 index 069fb82c69..0000000000 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ /dev/null @@ -1,1010 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""LayerNormMLP API""" - -import os -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import TransformerEngineBaseLayer -from .layernorm_linear import _apply_normalization_fwd, _apply_normalization_bwd -from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import ( - cast_from_fp8, - gelu_fp8, - swiglu_fp8, - swiglu, - dswiglu, - cast_transpose_bgrad, - dgelu_cast_transpose_bgrad_fp8, -) -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_paddle_act_func, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormMLP"] - - -def _mlp_forward( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - fc1_weight: paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_weight_fp8_index: FP8FwdTensors, - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_weight_fp8_index: FP8FwdTensors, - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - activation: str, - is_grad_enabled: bool, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_first_microbatch: bool, -): - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fc1_out, fc1_weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - if activation == "gelu": - gelu_out = gelu_fp8( - fc1_out, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - elif activation == "swiglu": - gelu_out = swiglu_fp8( - fc1_out, - 
fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - fc1_outputs = _linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - activation=activation, - ) - - if activation == "gelu": - fc1_out, gelu_out = fc1_outputs - elif activation == "swiglu": - fc1_out = fc1_outputs - gelu_out = swiglu(fc1_out, TE_DType[activation_dtype]) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out = _linear_fwd_non_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - ) - return ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8 if fp8_enabled else None, - fc2_weight_t_fp8 if fp8_enabled else None, - ) - - -def _mlp_backward( - fc1_input: paddle.Tensor, # ln_out, BF16 / FP8 - fc1_input_fp8_index: FP8FwdTensors, - fc1_weight: paddle.Tensor, - fc1_weight_t_fp8: paddle.Tensor, - fc1_weight_fp8_index: FP8FwdTensors, - fc1_grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT2 - requires_fc1_wgrad: bool, - requires_fc1_bgrad: bool, - fc1_out: paddle.Tensor, - fc2_input: paddle.Tensor, # gelu_out - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_t_fp8: paddle.Tensor, - fc2_weight_fp8_index: FP8FwdTensors, - requires_fc2_wgrad: bool, - requires_fc2_bgrad: bool, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT1 - fwd_scale_inverses: paddle.Tensor, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - activation_dtype: paddle.dtype, - activation: str, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) = ( - None, - None, - None, - None, - None, - ) - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - # FC2 Bwd - fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad - if requires_fc2_wgrad and not fp8_wgrad: - fc2_input = cast_from_fp8( - fc2_input, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc2_dgrad, fc2_wgrad = _linear_bwd_fp8( - fc2_input, - None, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - True, - requires_fc2_wgrad, - 
activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - dgelu_t = None - fc1_bgrad_ = None - if activation == "gelu": - # GELU Bwd - dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( - fc2_dgrad, - fc1_out, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - elif activation == "swiglu": - dgelu = dswiglu(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - fc1_bgrad_, dgelu, dgelu_t = cast_transpose_bgrad( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - - if requires_fc1_bgrad: - fc1_bgrad = fc1_bgrad_ - - # FC1 Bwd - dgelu_no_fp8 = None - if requires_fc1_wgrad and not fp8_wgrad: - # TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead. - dgelu_no_fp8 = cast_from_fp8( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - TE_DType[activation_dtype], - ) - fc1_input = cast_from_fp8( - fc1_input, - fp8_meta["scaling_fwd"], - fc1_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc1_dgrad, fc1_wgrad = _linear_bwd_fp8( - fc1_input, - None, - fc1_input_fp8_index, - fc1_weight, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - dgelu_no_fp8, - dgelu, - dgelu_t, - fc1_grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - else: - dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( - fc2_input, - fc2_weight, - grad_output, - requires_fc2_bgrad, - True, - requires_fc2_wgrad, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - gelu_input=fc1_out, - activation=activation, - ) - - if activation == "swiglu": - dgelu = dswiglu(dgelu, fc1_out, TE_DType[dgelu.dtype]) - - fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8( - fc1_input, - fc1_weight, - dgelu, - requires_fc1_bgrad, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) - - -class _LayerNormMLP(paddle.autograd.PyLayer): - """TE implementation of LayerNormMLP""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: Union[paddle.Tensor, None], - fc1_weight: paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - activation: 
str, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(fc1_weight) - assert_dim_for_fp8_forward_exec(fc2_weight) - - # only support gelu for now - assert activation in ["gelu", "swiglu"], "Only gelu and swiglu are supported for now" - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8, - fc2_weight_t_fp8, - ) = _mlp_forward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - fc1_bias, - use_fc1_bias, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - fc2_bias, - use_fc2_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - activation, - is_grad_enabled, - set_parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.activation = activation - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_fc1_bias = use_fc1_bias - ctx.use_fc2_bias = use_fc2_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.set_parallel_mode = set_parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient - ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient - ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient - ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1])) - - if return_layernorm_output: - return fc2_out, ln_out_return.reshape(inp.shape) - return fc2_out - - @staticmethod - def backward( - ctx, *grad_outputs: 
Tuple[paddle.Tensor, ...] - ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess - ( - grad_output, - grad_output_c, - grad_output_t, - fc2_bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad_, - ) = _mlp_backward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - FP8BwdTensors.GRAD_OUTPUT2, - ctx.requires_fc1_wgrad, - ctx.requires_fc1_bgrad, - fc1_out, - gelu_out, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - ctx.requires_fc2_wgrad, - ctx.requires_fc2_bgrad, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.fp8_enabled, - ctx.fp8_meta, - True, - ctx.activation_dtype, - ctx.activation, - ctx.set_parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - if not ctx.fp8_enabled: - # fc2_bias is fused with gemm for non-FP8 path - fc2_bgrad = fc2_bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - fc1_dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - fc1_bgrad = fc1_bgrad if ctx.requires_fc1_bgrad else None - fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None - fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else () - fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - fc1_weight_cache_grad = () - fc2_weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - fc1_weight_cache_grad = (None, None) - fc2_weight_cache_grad = (None, None) - - if ctx.requires_fc1_wgrad and ctx.fuse_wgrad_accumulation: - fc1_wgrad = None - if ctx.requires_fc2_wgrad and ctx.fuse_wgrad_accumulation: - fc2_wgrad = None - - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - fc1_wgrad if ctx.requires_fc1_wgrad else None, - *fc1_weight_cache_grad, - *fc1_bgrad_out, - fc2_wgrad if ctx.requires_fc2_wgrad else None, - *fc2_weight_cache_grad, - *fc2_bgrad_out, - ) - - -class LayerNormMLP(TransformerEngineBaseLayer): - r""" - Applies layer normalization on the input followed by the MLP module, consisting of - 2 successive linear transformations, separated by the GeLU activation. - - Parameters - ---------- - hidden_size : int - size of each input sample. 
- ffn_hidden_size : int - intermediate size to which input samples are projected. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - activation : str, default = 'gelu' - activation function used. - Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module - is taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row - Parallel as described `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : paddle.distributed.collective.Group, default = `None` - tensor parallel process group. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - activation: str = "gelu", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Normalization type not supported" - self.activation = activation - self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.set_parallel_mode = set_parallel_mode - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - if self.set_parallel_mode: - self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) - else: - self.size_per_partition = self.ffn_hidden_size - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # FC1 weights - if self.activation in ["swiglu"]: - fc1_output_features = self.size_per_partition * 2 - else: - fc1_output_features = self.size_per_partition - - with track_rng_state(enable=self.tensor_parallel): - self.fc1_weight = self.create_parameter( - shape=( - [fc1_output_features, self.hidden_size] - if self.backend == "transformer_engine" - else [self.hidden_size, fc1_output_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend - ) - self.fp8_weights.append(self.fc1_weight) - - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if use_default_bias: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0)) - - if self.has_bias: - self.fc1_bias = self.create_parameter( - shape=[fc1_output_features], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0) - else: - self.fc1_bias = None - - # FC2 weights - self.fc2_weight = self.create_parameter( - shape=( - [self.hidden_size, self.size_per_partition] - if self.backend == "transformer_engine" - else [self.size_per_partition, self.hidden_size] - ), - 
attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend - ) - self.fp8_weights.append(self.fc2_weight) - - if self.has_bias: - self.fc2_bias = self.create_parameter( - shape=[self.hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - if self.set_parallel_mode and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.fc2_bias) - else: - self.fc2_bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.set_parallel_mode and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, num_gemms=2, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = ( - self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - ) - - out = _LayerNormMLP.apply( - inp, - self.ln_weight, - self.ln_bias, - self.fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - self.fc1_bias, - self.has_bias, - self.fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - self.fc2_bias, - self.has_bias, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.activation, - self.set_parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." 
- ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - if self.set_parallel_mode and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias) - act_func = get_paddle_act_func(self.activation) - act_out = act_func(fc1_out) - out = F.linear( - act_out, self.fc2_weight, self.fc2_bias if self.gemm_bias_fused_add else None - ) - if self.set_parallel_mode and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.fc2_bias if self.fc2_bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a feedforward network (MLP Block). - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py deleted file mode 100644 index 78b22ac7e4..0000000000 --- a/transformer_engine/paddle/layer/linear.py +++ /dev/null @@ -1,919 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
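For reference, a minimal usage sketch of the LayerNormMLP layer above. The sizes are illustrative, and it assumes LayerNormMLP is re-exported at the top level of transformer_engine.paddle; treat it as a sketch of the documented API rather than a definitive example.

import paddle
import transformer_engine.paddle as te  # assumed top-level re-export

# Fused LayerNorm + two-layer MLP (FC1 -> GeLU -> FC2); sizes are hypothetical.
layer = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="gelu")
x = paddle.randn([8, 128, 1024])  # [batch, seq, hidden]
y = layer(x)                      # leading dims preserved, last dim = hidden_size
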
-"""Linear API""" - -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import ( - TransformerEngineBaseLayer, - get_workspace, - _2X_ACC_FPROP, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, -) - -from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose -from ..distributed import ( - allgather, - allreduce, - get_tp_group_and_world_size, - identity, - reduce_scatter, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype, get_global_fp8_state -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_bias_dtype, - save_for_backward_allow_none, - saved_tensor_allow_none, - clear_tensor_data, -) - -__all__ = ["Linear"] - - -def _linear_fwd_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, -): - """FP8 path of Linear Fwd""" - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - bias_dtype = get_bias_dtype(activation_dtype) - bias = cast_if_needed(bias, bias_dtype) - - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - if not get_global_fp8_state().is_cudagraph_enabled(): - # if cuda graph is not enabled, we cast the weight here - update_fp8_weights = is_first_microbatch is None or is_first_microbatch - if is_grad_enabled: - if update_fp8_weights: - weight_fp8, weight_t_fp8 = cast_transpose( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - weight_t_fp8 = None - if update_fp8_weights: - weight_fp8 = cast_to_fp8( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - out=weight_fp8, - ) - - out, _ = fp8_gemm( - weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - weight_fp8_index, - fp8_dtype_forward, - inputmat_total, - fp8_meta["scaling_fwd"].scale_inv, - inputmat_fp8_index, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - - return out, weight_t_fp8 - - -def _linear_fwd_non_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - activation: str = "", -): - """Non-FP8 path of Linear Fwd""" - - if parallel_mode == "column" and sequence_parallel: - 
inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - # Layer parameters are initialized as float32 dtype by default. - # Cast the parameters to activation_dtype if the current dtype - # does not match activation_dtype. The casting is inplace, so it - # only needs to be performed once throughout the training process. - weight = cast_if_needed_inplace(weight, activation_dtype) - bias = cast_if_needed_inplace(bias, activation_dtype) - - if fp8_calibration: - # amax of input - fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max( - paddle.abs(inputmat_total) - ).item() - # amax of weight - fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max( - paddle.abs(weight) - ).item() - fp8_meta["update_amax_and_scale_fwd"] = True - - outputs = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - gelu=(activation == "gelu"), - ) - - if activation == "gelu": - gelu_out, _, out = outputs - return out, gelu_out - - out, _, _ = outputs - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - return out - - -def _linear_fwd( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, - gather_output: bool = False, -): - if fp8_enabled: - out, weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8, - weight_t_fp8, - weight_fp8_index, - bias, - use_bias, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - out = _linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8_index, - bias, - use_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - ) - if gather_output and tensor_parallel and parallel_mode == "column": - out, _ = allgather(out, tp_group, axis=-1) - - return ( - out, - weight_t_fp8 if fp8_enabled else None, - ) - - -def _linear_bwd_fp8( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, handle = None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - inputmat_t_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = 
allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - inputmat_t_total = inputmat_t - - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - if requires_dgrad: - dgrad, _ = fp8_gemm( - weight_t_fp8, - fwd_scale_inverses, - weight_fp8_index, - fp8_dtype_forward, - grad_output_c, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ) - clear_tensor_data(grad_output_c) - - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - if not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t_total is None: - inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward) - clear_tensor_data(inputmat_total) - - wgrad, _ = fp8_gemm( - inputmat_t_total, - fwd_scale_inverses, - inputmat_fp8_index, - fp8_dtype_forward, - grad_output_t, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - "float32" if fuse_wgrad_accumulation else activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(inputmat_t_total, grad_output_t) - else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - ) - clear_tensor_data(inputmat_total) - - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - if parallel_mode == "column" and tensor_parallel and handle is not None: - handle.wait() - if parallel_mode == "column" and sequence_parallel: - handle.wait() - - return dgrad, wgrad - - -def _linear_bwd_non_fp8( - inputmat: paddle.Tensor, - weight: paddle.Tensor, - grad_output: paddle.Tensor, - requires_bgrad: bool, - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, - gelu_input: Union[paddle.Tensor, None] = None, - activation: str = "", -): - """ - Performs Linear Backward. Optionally, fuses GELU backward and dbias. 
- """ - dgrad, wgrad, bgrad, handle = None, None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - - if requires_dgrad: - dgrad, _, _ = gemm( - weight, - grad_output, - activation_dtype, - get_workspace(), - layout="NN", - gelu=(activation == "gelu"), - gelu_input=gelu_input, - grad=True, - ) - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - wgrad, bgrad, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - use_bias=requires_bgrad, - ) - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - elif requires_bgrad: - bgrad = grad_output.sum(axis=0) - if parallel_mode == "column" and tensor_parallel and handle is not None: - handle.wait() - if parallel_mode == "column" and sequence_parallel and handle is not None: - handle.wait() - - return dgrad, wgrad, bgrad - - -def _linear_bwd( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - requires_bgrad: bool, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, bgrad = None, None, None - if fp8_enabled: - dgrad, wgrad = _linear_bwd_fp8( - inputmat, - inputmat_t, - inputmat_fp8_index, - weight, - weight_t_fp8, - weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - else: - dgrad, wgrad, bgrad = _linear_bwd_non_fp8( - inputmat, - weight, - grad_output, - requires_bgrad, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return dgrad, wgrad, bgrad - - -class _Linear(paddle.autograd.PyLayer): - """TE implementation of Linear""" - - @staticmethod - def forward( - ctx, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - inp: paddle.Tensor, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, 
Any], - activation_dtype: paddle.dtype, - is_grad_enabled: bool, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - gather_output: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = weight.shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - inputmat_no_fp8 = inputmat - - # FP8 casting - inputmat_t = None - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and not sequence_parallel - ): - inputmat, inputmat_t = cast_transpose( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - else: - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - - # GEMM Fwd - out, weight_t_fp8 = _linear_fwd( - inputmat, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - gather_output, - ) - - if is_grad_enabled: - saved_inputmat = None - if fp8_enabled and sequence_parallel: - saved_inputmat = inputmat - else: - saved_inputmat = inputmat_no_fp8 - save_for_backward_allow_none( - ctx, - saved_inputmat, - inputmat_t, - weight, - weight_t_fp8 if fp8_enabled else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.reduce_scatter_output = gather_output - - return out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): - - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - inputmat_t, - weight, - weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" - ) - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - dgrad, wgrad, bgrad_ = _linear_bwd( - inputmat, - inputmat_t, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, 
- grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - ctx.requires_dgrad, - ctx.requires_wgrad, - ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - if ctx.reduce_scatter_output: - wgrad, _ = reduce_scatter(wgrad, ctx.tp_group) - bgrad, _ = reduce_scatter(bgrad, ctx.tp_group) - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None - if not ctx.use_bias: - bgrad_return = () - elif ctx.requires_bgrad: - bgrad_return = (bgrad,) - else: - bgrad_return = (None,) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - - return ( - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - dgrad_return, - *bgrad_return, - ) - - -class Linear(TransformerEngineBaseLayer): - """ - Applies a linear transformation to the incoming data :math:`y = xA^T + b` - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- - """ - - def __init__( - self, - in_features: int, - out_features: int, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - gather_output: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.backend = backend - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - self.gather_output = gather_output - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # Initialize weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - - # Initialize bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - self.fp8_weights.append(self.weight) - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Apply the linear transformation to the input. - """ - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. 
- weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _Linear.apply( - self.weight, - weight_fp8, - weight_t_fp8, - inp, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - paddle.is_grad_enabled(), - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - self.gather_output, - ) - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - if self.parallel_mode == "column" and self.tensor_parallel: - inp = identity(inp, self.tp_group) - out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - return out - - def forward(self, *args, **kwargs): - """ - Apply the linear transformation to the input. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/rmsnorm.py b/transformer_engine/paddle/layer/rmsnorm.py deleted file mode 100644 index 23e406e3fb..0000000000 --- a/transformer_engine/paddle/layer/rmsnorm.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
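For reference, a minimal usage sketch of the Linear layer above (illustrative sizes; assumes Linear is re-exported at the top level of transformer_engine.paddle). With the transformer_engine backend the weight is stored as [out_features, in_features], so the layer computes y = x @ W.T + b.

import paddle
import transformer_engine.paddle as te  # assumed top-level re-export

linear = te.Linear(in_features=1024, out_features=4096)
x = paddle.randn([32, 1024])  # [tokens, in_features]
y = linear(x)                 # [tokens, out_features]
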
-"""RMSNorm API""" -import os -from typing import Union, Tuple - -import paddle -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import rmsnorm_fwd, rmsnorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["RMSNorm"] - - -class _RMSNorm(paddle.autograd.PyLayer): - """functional RMSNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - rmsnorm_weight: paddle.Tensor, - eps: float, - fwd_rmsnorm_sm_margin: int, - bwd_rmsnorm_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = rmsnorm_weight.shape[0] - assert inp.shape[-1] == in_features, "RMSNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - rmsnorm_out, rsigma = rmsnorm_fwd( - inputmat, - rmsnorm_weight, - eps, - TE_DType[inp.dtype], - fwd_rmsnorm_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not rmsnorm_weight.stop_gradient - - return rmsnorm_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor() - d_rmsnorm_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma = rmsnorm_bwd( - d_rmsnorm_out, - inputmat, - rsigma, - rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, - ctx.zero_centered_gamma, - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - ) - - -class RMSNorm(paddle.nn.Layer): - r""" - Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in - the paper `Root Mean Square Layer Normalization `__ - - .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * \gamma - - where - - .. math:: - RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} - - :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in RMSNorm is initialized to 0 and - the RMSNorm formula changes to - - .. math:: - y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - backend to use for rmsnorm operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. 
- """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr(initializer=Constant(1.0)) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward RMSNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with RMSNorm. - self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - return _RMSNorm.apply( - inp, - self.weight, - self.eps, - self.fwd_rmsnorm_sm_margin, - self.bwd_rmsnorm_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support RMSNorm with zero_centered_gamma." - ) - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - y = inp * norm * self.weight - return y - - def forward(self, *args, **kwargs): - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} not supported.") diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py deleted file mode 100644 index 971be68167..0000000000 --- a/transformer_engine/paddle/layer/softmax.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Fused scaled masked softmax functions""" - -import os -import warnings -from typing import Callable, Tuple, Union, Optional - -import paddle - -from transformer_engine.paddle.cpp_extensions import ( - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_softmax_forward, - scaled_softmax_backward, -) - - -__all__ = ["FusedScaleMaskSoftmax"] - - -THREADS_PER_WARP = 32 -THREADS_PER_BLOCK = 128 - - -_default_causal_mask = {} - - -def _get_default_causal_mask(seqlen: int) -> paddle.Tensor: - """Return the causal upper triangular mask for softmax input""" - if seqlen not in _default_causal_mask: - _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), diagonal=1).cast( - "bool" - ) - return _default_causal_mask[seqlen] - - -class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledUpperTriangMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledUpperTriangMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -class ScaledMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, mask: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class ScaledSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(paddle.nn.Layer): - """ - Scaled and masked softmax module for paddle with fused optimizations. - - Parameters - ---------- - attn_mask_type : str, default = `causal` - type of attention mask, can be 'causal', 'padding', or 'no_mask'. - mask_func : callable - custom callable for applying the mask to the softmax input. - `masked_input=mask_func(inp, mask)`. - softmax_in_fp32 : bool, default = True - perform softmax computation in fp32. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for operation. 
- """ - - def __init__( - self, - attn_mask_type: str, - mask_func: Callable, - softmax_in_fp32: bool = True, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))) - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.backend = backend - - def forward( - self, - inp: paddle.Tensor, - mask: paddle.Tensor, - scale: Optional[float] = None, - ) -> paddle.Tensor: - """FusedScaleMaskSoftmax fprop""" - # [batch_size, num_heads, s_q, s_kv] - assert inp.dim() == 4 - self.input_is_fp16 = inp.dtype == paddle.float16 - self.input_is_bf16 = inp.dtype == paddle.bfloat16 - self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16 - - assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - - if self.backend == "transformer_engine" and not self.is_kernel_available(*inp.shape): - warnings.warn( - "fused kernel is not available for this input shape, fall back to paddle backend" - ) - self.backend = "paddle" - - if self.backend == "transformer_engine": - return self._te_forward(inp, mask, scale) - if self.backend == "paddle": - return self._pd_forward(inp, mask, scale) - raise AttributeError(f"Backend {self.backend} is not supported.") - - def is_kernel_available(self, b: int, h: int, s_q: int, s_kv: int) -> bool: - """Check FusedScaleMaskSoftmax kernel availability based on size""" - attn_batches = b * h - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_16bit_float # input must be fp16 - and 16 < s_kv <= 4096 # s_kv must be 16 ~ 2048 - and s_q % 4 == 0 # s_q must be a multiple of 4 - and attn_batches % 4 == 0 # b * h must be a multiple of 4 - ): - if 0 <= s_kv <= 4096: - batch_per_block = self.get_batch_per_block(int(s_kv)) - - if self.attn_mask_type == "causal": - if attn_batches % batch_per_block == 0: - return True - else: - if s_q % batch_per_block == 0: - return True - return False - - def _te_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Fused masked softmax kernel""" - b, h, s_q, s_kv = inp.size() - scale = 1.0 if scale is None else scale - - if self.attn_mask_type == "causal": - assert s_q == s_kv, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, s_q, s_kv) - inp = inp.reshape((-1, s_q, s_kv)) - probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale) - return probs.reshape((b, h, s_q, s_kv)) - # input is 4D tensor (b, h, s_q, s_kv) - if mask is not None: - return ScaledMaskedSoftmax.apply(inp, mask, scale) - return ScaledSoftmax.apply(inp, scale) - - def _pd_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Call Paddle OP""" - if self.input_in_16bit_float and self.softmax_in_fp32: - inp = paddle.cast(inp, "float32") - - if scale is not None: - inp = inp * scale - - if self.attn_mask_type == "causal": - mask = _get_default_causal_mask(inp.shape[2]) - - mask_output = self.mask_func(inp, mask) if mask is not None else inp - probs = paddle.nn.functional.softmax(mask_output, axis=-1) - - if self.input_in_16bit_float and self.softmax_in_fp32: - if self.input_is_fp16: - probs = paddle.cast(probs, "float16") - else: - probs = paddle.cast(probs, "bfloat16") - - return probs - - @staticmethod - def get_batch_per_block(key_seq_len: int) -> int: - """Softmax utility""" - pow2 = 1 << (key_seq_len - 
1).bit_length() - warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP - batches_per_warp = 2 if pow2 <= 128 else 1 - warps_per_block = THREADS_PER_BLOCK // warp_size - batches_per_block = warps_per_block * batches_per_warp - return batches_per_block diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py deleted file mode 100644 index feb79c0caa..0000000000 --- a/transformer_engine/paddle/layer/transformer.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Transformer""" - -from typing import Optional, Tuple, Union -import warnings - -import paddle -from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd - -from .layernorm_mlp import LayerNormMLP -from .layernorm import LayerNorm -from .attention import MultiHeadAttention -from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -from ..distributed import get_tp_group_and_world_size, track_rng_state - - -class TransformerLayer(paddle.nn.Layer): - r""" - TransformerLayer is made up of an attention block and a feedforward network (MLP). - This standard layer is based on the paper "Attention Is All You Need". - - Parameters - ---------- - hidden_size : int - size of each input sample. - ffn_hidden_size : int - intermediate size to which input samples are projected. - num_attention_heads : int - number of attention heads in the transformer layer. - num_gqa_groups : Optional[int], default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - hidden_dropout: float, default = 0.1 - dropout probability for the dropout op after FC2 layer. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - self_attn_mask_type: {'causal', 'padding'}, default = `causal` - type of attention mask passed into softmax operation. - apply_residual_connection_post_layernorm : bool, default = `False` - if set to `True`, residual connections are taken - from the output of layer norm (default is taken - from input of layer norm) - output_layernorm: bool, default = `False` - if set to `True`, layer normalization is applied on the output side, - after the final dropout-add. default behavior is to apply layer - normalization on the input side, before the QKV transformation. - layer_type: {'encoder', 'decoder'}, default = `encoder` - if set to `decoder`, an additional cross-attn block is added after self-attn. - This can be used for structures like `T5` Transformer in conjunction with the - `encoder` option. - normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm` - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. 
math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - activation : str, default = 'gelu' - Type of activation used in MLP block. - Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'. - - params_dtype : paddle.dtype, default = `paddle.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - attention_dropout_rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. - hidden_dropout_rng_state_name : str, default = `global_seed` - Controls the rng state used for dropout on hidden states. The - specified rng should be given the same seeds for different TP - ranks. It will be ignored if `set_parallel_mode` is False. The - specified name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. 
- - """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - num_attention_heads: int, - num_gqa_groups: Optional[int] = None, - layernorm_epsilon: float = 1e-5, - hidden_dropout: float = 0.1, - attention_dropout: float = 0.1, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - self_attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - apply_residual_connection_post_layernorm: bool = False, - output_layernorm: bool = False, - layer_type: str = "encoder", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - activation: str = "gelu", - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - attention_dropout_rng_state_name: str = "local_seed", - hidden_dropout_rng_state_name: str = "global_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.output_layernorm = output_layernorm - self.layer_type = layer_type - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.self_attn_mask_type = self_attn_mask_type - self.set_parallel_mode = set_parallel_mode - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name - # SP needs local seed for hidden dropout - if self.sequence_parallel and self.hidden_dropout_rng_state_name == "global_seed": - warnings.warn( - "RNG state for hidden dropout needs to be different across TP ranks. 
" - "Forcing hidden_dropout_rng_state_name to 'local_seed'" - ) - self.hidden_dropout_rng_state_name = "local_seed" - - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" - - attention_args = ( - hidden_size, - num_attention_heads, - attention_dropout, - layernorm_epsilon, - weight_attr, - bias_attr, - ) - common_attention_kwargs = { - "params_dtype": params_dtype, - "return_layernorm_output": apply_residual_connection_post_layernorm, - "normalization": normalization, - "zero_centered_gamma": zero_centered_gamma, - "set_parallel_mode": set_parallel_mode, - "sequence_parallel": self.sequence_parallel, - "max_sequence_length": max_sequence_length, - "tp_group": tp_group, - "num_gqa_groups": num_gqa_groups, - "fuse_wgrad_accumulation": fuse_wgrad_accumulation, - "rng_state_name": attention_dropout_rng_state_name, - "backend": backend, - } - - self.self_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type=self_attn_mask_type, - input_layernorm=not output_layernorm, - attention_type="self", - ) - - if layer_type == "decoder": - self.inter_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type="padding", - input_layernorm=True, - attention_type="cross", - ) - - self.layernorm_mlp = LayerNormMLP( - hidden_size, - ffn_hidden_size, - eps=layernorm_epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr, - normalization=normalization, - activation=activation, - return_layernorm_output=apply_residual_connection_post_layernorm, - zero_centered_gamma=zero_centered_gamma, - set_parallel_mode=set_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=backend, - ) - - self.hidden_dropout = hidden_dropout - - if self.output_layernorm: - self.layernorm = LayerNorm( - hidden_size, - layernorm_epsilon, - weight_attr, - bias_attr, - zero_centered_gamma=zero_centered_gamma, - sequence_parallel=self.sequence_parallel, - backend=backend, - ) - - self.fused_dropout_add1 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - if self.layer_type == "decoder": - self.fused_dropout_add2 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - self.fused_dropout_add3 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - enc_dec_attn_mask: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Transformer Layer: attention block and a feedforward network (MLP) - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` - is set to `"causal"`. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out self-attention softmax input. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. 
- enc_dec_attn_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out inter-attention softmax input if using - `layer_type="decoder"`. - rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied. - core_attention_bias_type: str, default = `no_bias` - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to set output tensors to 0 or not before use. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.self_attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - assert core_attention_bias_type in ["no_bias"], ( - "Only no_bias is supported currently, " - f"but received core_attention_bias_type = {core_attention_bias_type}" - ) - - # Self attention. - self_attention_outputs = self.self_attention( - hidden_states, - attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - rotary_pos_emb=rotary_pos_emb, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - - if self.apply_residual_connection_post_layernorm and not self.output_layernorm: - attention_output, residual = self_attention_outputs - else: - attention_output = self_attention_outputs - residual = hidden_states - - # dropout add. - with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - bda_output = self.fused_dropout_add1(attention_output, residual) - - # Cross attention. - if self.layer_type == "decoder": - inter_attention_outputs = self.inter_attention( - bda_output, - enc_dec_attn_mask, - encoder_output=encoder_output, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - if self.apply_residual_connection_post_layernorm: - attention_output, residual = inter_attention_outputs - else: - attention_output = inter_attention_outputs - residual = bda_output - - with track_rng_state( - enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name - ): - bda_output = self.fused_dropout_add2(attention_output, residual) - - # MLP. - mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch) - if self.apply_residual_connection_post_layernorm: - mlp_output, residual = mlp_outputs - else: - mlp_output = mlp_outputs - residual = bda_output - - # dropout add.
- with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - output = self.fused_dropout_add3(mlp_output, residual) - - # For BERT like architectures. - if self.output_layernorm: - output = self.layernorm(output) - - # output: [b, s, hidden] - return output diff --git a/transformer_engine/paddle/profile.py b/transformer_engine/paddle/profile.py deleted file mode 100644 index d58679aea1..0000000000 --- a/transformer_engine/paddle/profile.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utils for profiling""" - -from contextlib import contextmanager - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core - - -@contextmanager -def nvtx_range(msg): - """Context to insert NVTX""" - core.nvprof_nvtx_push(msg) - yield - core.nvprof_nvtx_pop() diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py deleted file mode 100644 index 5551583736..0000000000 --- a/transformer_engine/paddle/recompute.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Methods needed for recompute.""" - -import os -import inspect - -from paddle.distributed import fleet - -from .constants import RecomputeFunctionNames -from .fp8 import get_global_fp8_state - - -__all__ = ["recompute"] - - -_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) - - -def is_in_recompute_phase(): - """Inspect call stack to determine if this is called from - backward phase. Paddle has two recompute methods: - (1) Use RecomputeFunction. The recomputed function is called from `RecomputeFunction.backward`; - (2) Use paddle.autograd.saved_tensors_hooks. The recompute function is called from `unpack`.""" - if _DISABLE_RECOMPUTE: - return False - frame = inspect.currentframe().f_back - while frame: - if frame.f_code.co_name in RecomputeFunctionNames: - return True - frame = frame.f_back - return False - - -def recompute(function, *args, **kwargs): - """ - This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary - state information for fp8 layers. - - Parameters - ---------- - function: Callable - paddle module used to run the forward and backward passes using - the specified :attr:`args` and :attr:`kwargs`. - args : tuple - tuple of torch tensors for inputs to :attr:`function`. - kwargs : dict - dictionary of string keys for keyword arguments to :attr:`function`. - """ - assert ( - not _DISABLE_RECOMPUTE - ), f"Recompute is disabled. Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." - - global_fp8_state = get_global_fp8_state() - - try: - global_fp8_state._fp8_recompute_enabled = True - outputs = fleet.utils.recompute(function, *args, **kwargs) - finally: - global_fp8_state._fp8_recompute_enabled = False - - return outputs diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py deleted file mode 100644 index c80f21a01d..0000000000 --- a/transformer_engine/paddle/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Installation script for TE paddle-paddle extensions.""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import sys -import os -import shutil -from pathlib import Path - -import setuptools -from paddle.utils.cpp_extension import BuildExtension - -try: - import paddle # pylint: disable=unused-import -except ImportError as e: - raise RuntimeError("This package needs Paddle Paddle to build.") from e - - -current_file_path = Path(__file__).parent.resolve() -build_tools_dir = current_file_path.parent.parent / "build_tools" -if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir): - build_tools_copy = current_file_path / "build_tools" - if build_tools_copy.exists(): - shutil.rmtree(build_tools_copy) - shutil.copytree(build_tools_dir, build_tools_copy) - - -from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers -from build_tools.te_version import te_version -from build_tools.paddle import setup_paddle_extension - - -os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) - - -if __name__ == "__main__": - # Extensions - common_headers_dir = "common_headers" - copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) - ext_modules = [ - setup_paddle_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir - ) - ] - - # Configure package - setuptools.setup( - name="transformer_engine_paddle", - version=te_version(), - description="Transformer acceleration library - Paddle Paddle Lib", - ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["paddlepaddle-gpu>=2.6.1"], - tests_require=["numpy"], - ) - if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): - shutil.rmtree(common_headers_dir) - shutil.rmtree("build_tools") diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py deleted file mode 100644 index 4a801495ab..0000000000 --- a/transformer_engine/paddle/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utility functions for Transformer Engine modules""" - -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F -from .cpp_extensions import swiglu_pd - - -def cast_if_needed( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype""" - return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype) - - -def cast_if_needed_inplace( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype (inplace), not to be used on layer inputs""" - return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype) - - -def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. - """ - return not tensor.shape[0] % 8 and not tensor.shape[1] % 16 - - -def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. 
- """ - # single tensor check so it's clear which tensor is triggering the assertion - assert check_dim_for_fp8_forward_exec(tensor), ( - "Tensor dimensions are not compatible for FP8 execution: " - f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" - ) - - - def get_bias_dtype(activation_dtype: paddle.dtype): - """Get bias dtype given activation_dtype""" - return paddle.bfloat16 if activation_dtype == paddle.float32 else activation_dtype - - - def get_paddle_act_func(activation): - """Get paddle activation function""" - funcs = { - "gelu": F.gelu, - "relu": F.relu, - "silu": F.silu, - "swiglu": swiglu_pd, - } - if activation not in funcs: - raise ValueError("Activation type " + activation + " is not supported.") - return funcs[activation] - - - def attention_mask_func( - attention_scores: paddle.Tensor, attention_mask: paddle.Tensor - ) -> paddle.Tensor: - """Get attention mask""" - - def _masked_fill(x, mask, value): - y = paddle.full(x.shape, value, x.dtype) - return paddle.where(mask, y, x) - - attention_scores = _masked_fill(attention_scores, attention_mask, -10000.0) - return attention_scores - - - def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - assert "bool" in str(mask.dtype), "mask must be bool dtype" - assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" - q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32") - q_cu_seqlens = paddle.cumsum(q_actual_seqlens) - q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) - if not need_kv: - return q_cu_seqlens, None - kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype="int32") - kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) - kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) - return q_cu_seqlens, kv_cu_seqlens - - - def divide(numerator: int, denominator: int) -> int: - """Ensure that numerator is divisible by the denominator and return - the division value.""" - assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" - return numerator // denominator - - - def save_for_backward_allow_none(ctx, *args) -> None: - """Save tensors for backward. Args could be None""" - indices_mapping = [] - tensors_to_save = [] - for x in args: - if isinstance(x, paddle.Tensor): - indices_mapping.append(len(tensors_to_save)) - tensors_to_save.append(x) - elif x is None: - indices_mapping.append(-1) - else: - raise ValueError(f"Type {type(x)} is not allowed.") - - ctx._indices_mapping = indices_mapping - ctx.save_for_backward(*tensors_to_save) - - - def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: - """Used together with `save_for_backward_allow_none`. Get saved tensors from ctx.""" - assert hasattr( - ctx, "_indices_mapping" - ), "`saved_tensor_allow_none` must be used together with `save_for_backward_allow_none`."
- - indices_mapping = ctx._indices_mapping - outputs = [] - saved_tensors = ctx.saved_tensor() - - for index in indices_mapping: - if index < 0: - outputs.append(None) - else: - outputs.append(saved_tensors[index]) - - return tuple(outputs) - - -def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None: - """ - Free tensor buffer - """ - - def can_free(t): - return ( - t is not None - and isinstance(t, paddle.Tensor) - and t._is_initialized() - and t.inplace_version == 0 - ) - - for t in tensors: - if can_free(t): - t._clear_dataptr() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 91d3772fd7..57addca3b9 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -82,27 +82,12 @@ def _load_library(): from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables -from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers -# Register custom op symbolic ONNX functions -from transformer_engine.pytorch.te_onnx_extensions import ( - onnx_cast_to_fp8, - onnx_cast_to_fp8_noalloc, - onnx_cast_from_fp8, - onnx_fp8_gelu, - onnx_fp8_relu, - onnx_te_gemm, - onnx_layernorm_fwd_fp8, - onnx_layernorm_fwd, - onnx_rmsnorm_fwd, - onnx_rmsnorm_fwd_fp8, -) - try: torch._dynamo.config.error_on_nested_jit_trace = False except AttributeError: diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ccceacff85..bf6adc309c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -24,15 +24,7 @@ import transformer_engine_torch as tex import transformer_engine as te from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, -) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, fused_attn_fwd, fused_attn_bwd, QKVLayout, @@ -54,6 +46,7 @@ get_fp8_torch_dtype, ) from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( @@ -82,9 +75,13 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + prepare_for_saving, + restore_from_saved, +) # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 @@ -116,7 +113,8 @@ def _get_supported_versions(version_min, version_max): _flash_attn_is_installed = False _flash_attn_version = PkgVersion("0") _flash_attn_version_required = PkgVersion("2.1.1") -_flash_attn_max_version = 
PkgVersion("2.6.3") +_flash_attn_version_required_blackwell = PkgVersion("2.7.3") +_flash_attn_max_version = PkgVersion("2.7.3") _flash_attn_2_plus = False _flash_attn_2_1_plus = False _flash_attn_2_3_plus = False @@ -124,6 +122,7 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_4_1_plus = False _flash_attn_2_5_7_plus = False _flash_attn_2_6_0_plus = False +_flash_attn_2_7_0_plus = False flash_attn_cuda_bwd = None flash_attn_func = None @@ -142,7 +141,13 @@ def _get_supported_versions(version_min, version_max): """ "pip install flash-attn".""", ) else: - if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): + if _flash_attn_version_required_blackwell <= _flash_attn_version <= _flash_attn_max_version: + _flash_attn_is_installed = True + elif _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + _flash_attn_is_installed = True + + if _flash_attn_is_installed: from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd @@ -154,7 +159,6 @@ def _get_supported_versions(version_min, version_max): _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - _flash_attn_is_installed = True _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") @@ -162,13 +166,18 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") + _flash_attn_2_7_0_plus = _flash_attn_version >= PkgVersion("2.7.0") elif ( torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN ): fa_logger.warning( "Supported flash-attn versions are %s. 
Found flash-attn %s.", _get_supported_versions( - _flash_attn_version_required, + ( + _flash_attn_version_required + if get_device_compute_capability() < (10, 0) + else _flash_attn_version_required_blackwell + ), _flash_attn_max_version, ), _flash_attn_version, @@ -181,11 +190,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_version = PkgVersion("0") _flash_attn_3_0_0_beta = False _use_flash_attn_3 = False +# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved +# https://github.com/Dao-AILab/flash-attention/issues/1452 _flash_attn_3_installation_steps = """\ -(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" +(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" (2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" +(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py""" try: _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) except PackageNotFoundError: @@ -317,7 +328,7 @@ def __eq__(self, other): if fname != "fp8_meta": if sf != of: return False - elif sf["recipe"] != of["recipe"]: + elif sf.get("recipe", None) != of.get("recipe", None): return False return True @@ -434,15 +445,6 @@ def get_attention_backend( if not use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") - # Filter: ONNX mode - if is_in_onnx_export_mode(): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to ONNX mode") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention due to ONNX mode") - use_fused_attention = False - # Filter: Compute capability if device_compute_capability < (8, 0): if use_flash_attention and _flash_attn_is_installed: @@ -937,7 +939,7 @@ def get_attention_backend( and use_fused_attention and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] ): - if device_compute_capability == (9, 0): + if device_compute_capability >= (9, 0): logger.debug( "Disabling FlashAttention to give FusedAttention preference on Hopper+ " "for performance reasons" @@ -1390,8 +1392,9 @@ def pack_tensor( indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) if isinstance(tensor, Float8Tensor): tensor_data = torch.cat((tensor._data, padding_indice), dim=0) + gathered_data = torch.gather(tensor_data, 0, indices) - packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices)) + packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape) else: tensor = torch.cat((tensor, padding_indice), dim=0) @@ -1444,7 +1447,8 @@ def unpack_tensor( ) if isinstance(tensor, Float8Tensor): unpacked.scatter_(0, indices, tensor._data) - unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :]) + unpacked_data = unpacked[0:-1, :, :] + unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape) else: unpacked.scatter_(0, indices, tensor) unpacked = unpacked[0:-1, :, :] @@ -1746,6 +1750,49 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def get_attention_quantizers(fp8, 
quantizers, cp_specific_quantizers=False): + """Get the list of quantizers used in attention from the quantizers list.""" + if not fp8: + num_of_nones = 8 if cp_specific_quantizers else 6 + return [None] * num_of_nones + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.internal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.internal = True + dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] + dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_CP_quantizer.internal = True + O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] + O_CP_quantizer.set_usage(rowwise=True, columnwise=False) + + if cp_specific_quantizers: + return ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks @@ -1784,6 +1831,7 @@ def forward( cp_group, cp_global_ranks, cp_stream, + quantizers, ): # pylint: disable=missing-function-docstring if softmax_scale is None: @@ -1839,56 +1887,58 @@ def forward( cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - fused_attn_qkv_dtype = None fused_attn_backend = None - amax_per_step = None qkv_dtype = q.dtype # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False + if fp8: + is_output_fp8 = fp8_meta["recipe"].fp8_mha + + ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) = get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + if fp8: if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] + assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: + if not is_input_fp8: q_f16, k_f16, v_f16 = q, k, v if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16) if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [k_f16, v_f16] - ] + k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: q_f16 = q if use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if fp8: + q = q._data + k = k._data + v = v._data + if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) @@ -1896,7 +1946,7 @@ def forward( q_f16 = q elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16)._data assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -1953,12 +2003,17 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus) or _use_flash_attn_3: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs q_inputs = [None, None] @@ -2007,17 +2062,7 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = cast_to_fp8( - p2p_comm_buffers[i], - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - ) - if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step - fp8_meta_kwargs["amax_s_offset"] = i - fp8_meta_kwargs["amax_o"] = amax_per_step - fp8_meta_kwargs["amax_o_offset"] = cp_size + i + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i]) if causal: if i == 0: if pad_between_seqs_q: @@ -2058,25 +2103,40 @@ def forward( ), 
dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, - fused_attn_backend, + q_part, + k_part, + v_part, + fake_dtype=qkv_dtype, + fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, @@ -2117,10 +2177,16 @@ def forward( causal=True, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] elif i <= rank: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2160,24 +2226,38 @@ def forward( if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv // 2, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2207,8 +2287,13 @@ def forward( max_seqlen_q, max_seqlen_kv // 2, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2225,10 +2310,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - 
softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2277,24 +2368,38 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q // 2, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2324,8 +2429,13 @@ def forward( max_seqlen_q // 2, max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -2342,10 +2452,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2374,24 +2490,38 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 
2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2433,10 +2563,16 @@ def forward( causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] if i > 0: # wait until fwd restuls correction of last step is done @@ -2454,13 +2590,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: - out_per_step[i - 1] = cast_from_fp8( - out_per_step[i - 1], - fp8_meta["scaling_fwd"], - META_O_CP, - fp8_dtype_forward, - TE_DType[torch.float32], - ) + out_per_step[i - 1] = out_per_step[i - 1].dequantize() if i == 1: out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) @@ -2562,70 +2692,48 @@ def forward( elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) - if fp8 and use_fused_attention: - amax_cp_fwd = amax_per_step.amax(dim=1) - fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] - fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] - out_fp8 = None out_f16 = out.to(qkv_dtype) + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) - - if fp8 and is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, - ) - else: - out_ret = out_f16 + out_fp8 = O_quantizer(out_f16) # final result + + out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8 - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + q_save, kv_save, out_save = q, kv, out_fp8._data elif fp8 and is_input_fp8: - q_fp8 = Float8Tensor( - data=q, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, - ) - kv_fp8 = Float8Tensor( - data=kv, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=k_fp8.dtype, - ) - q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None + q_save, kv_save, out_save = q, k, out_f16 else: q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( q_save, kv_save, out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, *attn_biases, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.O_CP_quantizer = O_CP_quantizer + ctx.S_quantizer = 
S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dQKV_CP_quantizer = dQKV_CP_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.qkv_dtype = qkv_dtype + ctx.cp_group_a2a = cp_group_a2a ctx.cp_size_a2a = cp_size_a2a ctx.rank_a2a = rank_a2a @@ -2648,6 +2756,7 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + return out_ret @staticmethod @@ -2662,13 +2771,15 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (*saved_tensors,) = ctx.saved_tensors - (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8] - cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] - attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + saved_tensors = ctx.saved_tensors + + q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + cu_seqlens_q_per_step = other_tensors[:cp_size] + cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] + rng_states = other_tensors[cp_size * 2 : cp_size * 3] + attn_biases = other_tensors[cp_size * 3 : cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2724,50 +2835,40 @@ def backward(ctx, dout): dq = None dout_dtype = dout.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None fused_attn_dqkv_dtype = None - amax_per_step = None - dout_fp8_dtype = None if ctx.fp8: if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
- ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + fused_attn_dqkv_dtype = dout._fp8_dtype dout = dout._data else: - dout = cast_to_fp8( - dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: if ctx.fp8_meta is not None and ctx.is_input_fp8: - q, kv = [x.from_float8(x.dtype) for x in [q, kv]] + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, kv = q.dequantize(), kv.dequantize() if cp_size_a2a == 1: - dout = dout.from_float8(dout_dtype) - else: - dout_fp8_dtype = dout._fp8_dtype - dout_scale_inv = dout._scale_inv - dout = dout._data + dout = dout.dequantize() dq = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2776,7 +2877,6 @@ def backward(ctx, dout): p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] @@ -2795,14 +2895,9 @@ def backward(ctx, dout): True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = cast_from_fp8( - dout, - None, - None, - dout_fp8_dtype, - TE_DType[dout_dtype], - scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment - ) + dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True) + dout = dout.dequantize() + dout = dout._data out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -2827,6 +2922,8 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): # wait until KV is received @@ -2868,9 +2965,6 @@ def backward(ctx, dout): kv = p2p_comm_buffers[i % 2][0] q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.fp8 and ctx.use_fused_attention: - fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -2899,17 +2993,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if 
ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -2923,6 +3039,10 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -2934,8 +3054,13 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, 0) + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -2981,17 +3106,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3007,6 +3154,10 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: dq_ 
= torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3018,8 +3169,13 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) + if _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3066,17 +3222,40 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3092,6 +3271,11 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3103,8 +3287,13 @@ def backward(ctx, dout): ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3129,17 +3318,39 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype + ) + dout_part = 
ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype + ) dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3153,6 +3364,12 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: dq_ = torch.empty_like(q) dkv_ = torch.empty_like(kv) @@ -3164,8 +3381,11 @@ def backward(ctx, dout): ctx.max_seqlen_q, ctx.max_seqlen_kv, ] - if _use_flash_attn_3 or _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( @@ -3328,23 +3548,13 @@ def backward(ctx, dout): dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] if ctx.qkv_format in ["bshd", "sbhd"]: # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dq, dkv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV_CP, - fp8_dtype_backward, - TE_DType[torch.float32], - ) - for x in [dq_fp8, dkv_fp8] - ] + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8) + dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8) + dq, dkv = [x.dequantize() for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: @@ -3364,10 +3574,8 @@ def backward(ctx, dout): dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dkv = [ - cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) - for x in [dq, dkv] - ] + assert torch.uint8 not in [dq.dtype, dkv.dtype] + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] dk, dv = dkv[0], dkv[1] if cp_size_a2a > 1: @@ -3386,22 +3594,14 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + # converting torch.uint8 to float8tensor + if ctx.fp8 and ctx.is_input_fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) return ( None, @@ -3427,6 +3627,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3493,6 +3694,8 @@ 
def forward( cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_dtype = q.dtype + causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type assert not padding, f"{attn_mask_type} mask type is not supported!" @@ -3521,8 +3724,10 @@ def forward( fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3610,7 +3815,7 @@ def forward( q_, k_, v_, - TE_DType[q.dtype], + qkv_dtype, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, @@ -3631,19 +3836,31 @@ def forward( max_seqlen_q, max_seqlen_kv_, ] + if _use_flash_attn_3 or ( + _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus + ): + fa_forward_kwargs["window_size"] = window_size_per_step[i] + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( q_, k_, v_, *fa_forward_args_thd, causal=causal, - window_size=window_size_per_step[i], **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not _flash_attn_2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[3] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3673,6 +3890,8 @@ def forward( *softmax_lse_per_step, *rng_states, ) + + ctx.qkv_dtype = qkv_dtype ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3754,6 +3973,8 @@ def backward(ctx, dout): fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -3783,7 +4004,7 @@ def backward(ctx, dout): v_, out_, dout_, - TE_DType[q.dtype], + ctx.qkv_dtype, TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, @@ -3811,6 +4032,11 @@ def backward(ctx, dout): ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] + if _flash_attn_2_3_plus and not _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size"] = window_size_per_step[i] + if _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( dout_, q_, @@ -3823,7 +4049,6 @@ def backward(ctx, dout): dv_per_step[i], *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, - window_size=window_size_per_step[i], **fa_backward_kwargs, ) @@ -3923,12 +4148,14 @@ def forward( fp8_meta, cp_group, cp_stream, + quantizers, ): # pylint: disable=missing-function-docstring if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) cp_size = 
get_distributed_world_size(cp_group) + qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3958,12 +4185,17 @@ def forward( flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_forward_kwargs["window_size"] = window_size + elif _flash_attn_2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size[0] + fa_forward_kwargs["window_size_right"] = window_size[1] if _flash_attn_2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if _flash_attn_2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if _flash_attn_2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert ( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 @@ -3978,50 +4210,38 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" - qkv_dtype = q.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False if fp8: - if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward + is_output_fp8 = fp8_meta["recipe"].fp8_mha + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) + if fp8: + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O - fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_s_offset"] = META_S - fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_o_offset"] = META_O + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer else: assert False, "FP8 is only supported with Fused Attention!" 
else: if use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) @@ -4031,23 +4251,31 @@ def forward( if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] batch_size = q.shape[batch_dim] if use_fused_attention: + q_part, k_part, v_part = q, k, v + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=True + ) out, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -4060,6 +4288,8 @@ def forward( window_size=window_size, **fp8_meta_kwargs, ) + if fp8: + out = out._data else: fa_forward_args_thd = [] if qkv_format == "thd": @@ -4077,8 +4307,12 @@ def forward( causal=causal, **fa_forward_kwargs, ) - out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + if not _flash_attn_2_7_0_plus: + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + else: + out, softmax_lse = fa_outputs[0], fa_outputs[1] + rng_state = fa_outputs[3] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) @@ -4096,24 +4330,16 @@ def forward( if fp8: if is_output_fp8: - out_fp8 = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False ) - out = out_fp8._data out_ret = out_fp8 + out = out_fp8._data else: - out_f16 = cast_from_fp8( - out, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - TE_DType[q_f16.dtype], + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False ) + out_f16 = out_fp8.dequantize() out_ret = out_f16 else: out_ret = out @@ -4122,30 +4348,22 @@ def forward( if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out elif is_input_fp8: - q_fp8, k_fp8, v_fp8 = [ - Float8Tensor( - data=x, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=out_fp8.dtype, - ) - for x in [q, k, v] - ] - q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 + q_fp8 = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=False + ) + k_fp8 = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=False + ) + v_fp8 = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=False + ) + q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out else: q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 else: q_save, k_save, v_save, out_save = q, k, v, out - if fp8 and 
int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() - else: - fp8_fwd_scales, fp8_fwd_scale_invs = None, None - - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( q_save, k_save, v_save, @@ -4154,10 +4372,20 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.qkv_dtype = qkv_dtype + ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream @@ -4182,11 +4410,18 @@ def backward(ctx, dout): # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - q, k, v, out = saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8] - fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10] - aux_ctx_tensors = saved_tensors[10:] + ( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + dout_dtype = dout.dtype qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type @@ -4194,47 +4429,32 @@ def backward(ctx, dout): fused_attn_backend = None fused_attn_dqkv_dtype = None - fused_attn_qkv_dtype = None - dout_dtype = dout.dtype if ctx.fp8: + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_dqkv_dtype = fp8_dtype_backward + if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
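The save/restore pattern used in this forward/backward pair replaces cloning FP8 scale tensors with prepare_for_saving() and restore_from_saved(): quantized tensors are flattened into plain torch.Tensors for ctx.save_for_backward plus lightweight objects kept on ctx, and the backward reattaches them in the original order. A minimal sketch of the same pattern in a toy autograd function (the class name and the quantizer argument are illustrative, and the two helpers are assumed to be the ones this module already imports):

import torch

class _QuantizedSaveSketch(torch.autograd.Function):
    """Illustrative only: mirrors the prepare_for_saving/restore_from_saved flow."""

    @staticmethod
    def forward(ctx, inp, quantizer):
        inp_q = quantizer(inp)  # quantized tensor wrapping a raw byte buffer
        # prepare_for_saving / restore_from_saved: assumed to be the helpers
        # used elsewhere in this module. They split quantized tensors into raw
        # torch.Tensors (safe for save_for_backward) and Python-side objects.
        tensors_to_save, tensor_objects = prepare_for_saving(inp_q)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return inp_q.dequantize()

    @staticmethod
    def backward(ctx, grad_out):
        # Reattach the saved raw tensors to their owning objects, in order.
        (inp_q,) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        # inp_q is a full quantized tensor again; the forward is approximately
        # an identity, so the incoming gradient passes straight through.
        return grad_out, None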
- ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout_fp8 = dout dout = dout_fp8._data else: dout_f16 = dout - dout = cast_to_fp8( - dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout_f16)._data fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] - fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] - fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ - META_DQKV - ] + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + else: assert False, "FP8 is only supported with Fused Attention!" else: if ctx.fp8_meta is not None and ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]] if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] @@ -4263,25 +4483,53 @@ def backward(ctx, dout): else: flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_3_plus: + if _use_flash_attn_3 or (_flash_attn_2_3_plus and not _flash_attn_2_7_0_plus): fa_backward_kwargs["window_size"] = ctx.window_size + elif _flash_attn_2_7_0_plus: + fa_backward_kwargs["window_size_left"] = ctx.window_size[0] + fa_backward_kwargs["window_size_right"] = ctx.window_size[1] if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None if _flash_attn_2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if _flash_attn_2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: + q_part = q + k_part = k + v_part = v + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dq, dk, dv, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -4296,6 +4544,10 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq = dq._data + dk = dk._data + dv = dv._data else: softmax_lse, rng_state = 
aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -4335,29 +4587,11 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - if ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - else: - dq, dk, dv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - TE_DType[dout_dtype], - ) - for x in [dq, dk, dv] - ] + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + if not ctx.is_input_fp8: + dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] return ( None, @@ -4384,6 +4618,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4413,6 +4648,7 @@ def attn_forward_func_with_cp( window_size=None, fp8=False, fp8_meta=None, + quantizers=None, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -4480,7 +4716,7 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) @@ -4488,7 +4724,7 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -4720,15 +4956,34 @@ def forward( mixed_x_layer: torch.Tensor, split_dim: int, split_size_or_sections: Union[int, List[int], Tuple[int]], + squeeze=False, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections + if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + mixed_x_layer, Float8Tensor + ): + return tuple( + Float8TensorBase( + fp8_scale_inv=mixed_x_layer._scale_inv, + fp8_dtype=mixed_x_layer._fp8_dtype, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, + quantizer=mixed_x_layer._quantizer, + ) + for x in torch.split( + mixed_x_layer._data, + split_size_or_sections=split_size_or_sections, + dim=split_dim, + ) + ) if isinstance(mixed_x_layer, Float8Tensor): return tuple( Float8Tensor.make_like( mixed_x_layer, - data=x, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, ) for x in torch.split( mixed_x_layer._data, @@ -4736,7 +4991,10 @@ def forward( dim=split_dim, ) ) - return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) + out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) + if squeeze: + out_list = [x.squeeze(split_dim) for x in out_list] + return out_list @staticmethod def backward(ctx, *grad_outputs): @@ -4782,13 +5040,17 @@ def backward(ctx, *grad_outputs): new_shape, strides, ) - return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None + return ( + Float8Tensor.make_like(grad_outputs[0], 
data=ret, shape=ret.shape), + None, + None, + ) grad_outputs_data = [x._data for x in grad_outputs] + data = torch.cat(grad_outputs_data, dim=split_dim) return ( - Float8Tensor.make_like( - grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim) - ), + Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape), + None, None, None, ) @@ -4923,19 +5185,14 @@ def forward( key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] - # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator - is_bf16 = query_layer.dtype == torch.bfloat16 matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], - dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype, + dtype=query_layer.dtype, device=torch.cuda.current_device(), ) - if is_in_onnx_export_mode() and is_bf16: - matmul_result = matmul_result.bfloat16() - scale = self.softmax_scale if apply_qk_layer_scaling: scale /= self.layer_number @@ -5323,6 +5580,7 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, ) -> torch.Tensor: """flash-attn fprop""" @@ -5373,7 +5631,7 @@ def forward( for x in (query_layer._data, key_layer._data, value_layer._data) ] query_layer, key_layer, value_layer = [ - Float8Tensor.make_like(x, data=x._data) + Float8Tensor.make_like(x, data=x._data, shape=x._data.shape) for x in (query_layer, key_layer, value_layer) ] if context_parallel: @@ -5476,6 +5734,7 @@ def forward( attn_mask_type=attn_mask_type, deterministic=self.deterministic, window_size=window_size, + quantizers=quantizers, ) else: @@ -5514,10 +5773,10 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - activation_dtype = query_layer.dtype if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + torch_orig_dtype = query_layer.dtype def convert_to_torch_float8(tensor, dtype): out = torch.Tensor().to(device=tensor.device, dtype=dtype) @@ -5534,960 +5793,118 @@ def convert_to_torch_float8(tensor, dtype): assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." 
- if isinstance(query_layer, Float8Tensor): - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv - else: + if not isinstance(query_layer, Float8Tensor): query_layer, key_layer, value_layer = ( - Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) - for x in [query_layer, key_layer, value_layer] + QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) - fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv - fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv - query_layer, key_layer, value_layer = ( - convert_to_torch_float8(x, torch_dtype) - for x in [query_layer, key_layer, value_layer] - ) - try: - output, _ = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_3_optional_forward_kwargs, + fa_3_optional_forward_kwargs["descale_q"] = ( + query_layer._scale_inv.unsqueeze(0) ) - except TypeError as e: - if _flash_attn_3_0_0_beta: - e.args = ( - e.args[0] - + ". Please update your flash-attn v3 (beta) installation as it " - + "may have added more supported arguments to its API. \n" - + _flash_attn_3_installation_steps, - ) + e.args[1:] - raise - - if fp8 and fp8_meta["recipe"].fp8_mha: - output = cast_to_fp8( - output, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, + fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze( + 0 ) - output = Float8Tensor( - data=output, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, + fa_3_optional_forward_kwargs["descale_v"] = ( + value_layer._scale_inv.unsqueeze(0) ) - else: - output = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) - - if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: - output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) - - if qkv_format == "sbhd": - # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: - output = Float8Tensor.make_like( - output, - data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) - .transpose(0, 1) - .contiguous(), - ) - else: - output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) - elif qkv_format == "bshd": - # (bs)hd -> bs(hd) - output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) - elif qkv_format == "thd": - # thd -> t(hd) - output = output.reshape(output.shape[0], -1) - - return output.contiguous() - - -def _combine_tensors( - tensors: List[torch.Tensor], - dim: int, -) -> torch.Tensor: - """Combine tensors along a particular dimension""" - - num_tensors = len(tensors) - new_shape = list(tensors[0].shape) - new_shape.insert(dim, num_tensors) - new_stride = list(tensors[0].stride()) - new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) - if isinstance(tensors[0], Float8Tensor): - combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) - combined_tensor.set_( - tensors[0]._data.untyped_storage(), - tensors[0]._data.storage_offset(), - new_shape, - new_stride, - ) - combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor) - else: - combined_tensor = torch.Tensor().to(device=tensors[0].device, 
dtype=tensors[0].dtype) - combined_tensor.set_( - tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride - ) - - return combined_tensor - - -class FusedAttnFunc_qkvpacked(torch.autograd.Function): - """Function for FusedAttention with packed QKV input""" - - @staticmethod - def forward( - ctx, - is_training, - max_seqlen, - cu_seqlens, - cu_seqlens_padded, - qkv, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - is_input_fp8 = isinstance(qkv, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert ( - qkv_group == 1 - ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." - if is_input_fp8: - qkv_fp8 = qkv._data - else: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, - max_seqlen, - cu_seqlens, - qkv_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - if is_input_fp8: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - qkv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, - max_seqlen, - cu_seqlens, - qkv, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # 
d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - fp8_tensors = (None, None, None, None) - out_save = out_ret - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward( - *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( - qkv, - out, - cu_seqlens, - cu_seqlens_padded, - qkv_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dqkv = torch.empty_like(qkv) - d_out, q, k, v, out = [ - maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) - ] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dqkv = dqkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # 
amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dqkv = Float8Tensor( - data=dqkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] - ) - dqkv = cast_from_fp8( - dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(qkv.dtype) - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - dqkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - dqkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class FusedAttnFunc_kvpacked(torch.autograd.Function): - """Function for FusedAttention with packed KV input""" - - @staticmethod - def forward( - ctx, - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q, - kv, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - assert isinstance(kv, q.__class__), "q and kv must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if is_input_fp8: - q_fp8, kv_fp8 = q._data, kv._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f"but found {qkv_layout}." 
- ) - q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( - q.shape - ) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - if is_input_fp8: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - q_fp8, - kv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - out_save = out_ret - fp8_tensors = (None, None, None, None, None) - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) - ctx.save_for_backward( - *qkvo_tensors, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *fp8_tensors, - *aux_ctx_tensors, - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - 
ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( - q, - kv, - out, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q_fp8, - kv_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dq = dq[..., : d_out.shape[-1]] - dkv = dkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dkv = Float8Tensor( - data=dkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * 
dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] + query_layer, key_layer, value_layer = ( + convert_to_torch_float8(x, torch_dtype) + for x in [query_layer, key_layer, value_layer] ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) + try: + output, _ = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_3_optional_forward_kwargs, + ) + except TypeError as e: + if _flash_attn_3_0_0_beta: + e.args = ( + e.args[0] + + ". Please update your flash-attn v3 (beta) installation as it " + + "may have added more supported arguments to its API. \n" + + _flash_attn_3_installation_steps, + ) + e.args[1:] + raise + + if fp8: + output = output.to(dtype=torch_orig_dtype) + if fp8 and fp8_meta["recipe"].fp8_mha: + O_quantizer = quantizers["scaling_fwd"][META_O] + output = O_quantizer(output) else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) - dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: + output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + + if qkv_format == "sbhd": + # (bs)hd -> bs(hd) -> sb(hd) + if fp8 and fp8_meta["recipe"].fp8_mha: + output_data = ( + output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous() + ) + output = Float8Tensor.make_like( + output, + data=output_data, + shape=output_data.shape, + ) + else: + output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) + elif qkv_format == "bshd": + # (bs)hd -> bs(hd) + output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) + elif qkv_format == "thd": + # thd -> t(hd) + output = output.reshape(output.shape[0], -1) + + return output.contiguous() + + +def _combine_tensors( + tensors: List[torch.Tensor], + dim: int, +) -> torch.Tensor: + """Combine tensors along a particular dimension""" + + num_tensors = len(tensors) + new_shape = 
list(tensors[0].shape) + new_shape.insert(dim, num_tensors) + if isinstance(tensors[0], Float8Tensor): + new_stride = list(tensors[0]._data.stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) + combined_tensor.set_( + tensors[0]._data.untyped_storage(), + tensors[0]._data.storage_offset(), + new_shape, + new_stride, + ) + combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape) + else: + new_stride = list(tensors[0].stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) + combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype) + combined_tensor.set_( + tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride ) + return combined_tensor + class FusedAttnFunc(torch.autograd.Function): """Function for FusedAttention with separate Q, K, V tensors""" @@ -6519,56 +5936,51 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, deterministic, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha + is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False + fake_dtype = q.dtype + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." 
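The FP8 path in this function, like the ones above, is built from three quantizer operations instead of explicit cast_to_fp8/cast_from_fp8 calls with scaling-meta offsets: calling the quantizer to produce a quantized tensor (whose ._data holds the raw bytes), create_tensor_from_data() to rewrap a raw buffer with a remembered high-precision ("fake") dtype, and dequantize() to return to that dtype. A short sketch of the round trip, where quantizer stands for one of the objects returned by get_attention_quantizers above:

import torch

def fp8_roundtrip_sketch(x: torch.Tensor, quantizer):
    """Illustrative only: the three quantizer operations this patch relies on."""
    # 1) Quantize; kernels and collectives move the raw buffer around.
    x_q = quantizer(x)
    raw = x_q._data

    # 2) Rewrap a raw buffer, remembering which high-precision dtype it
    #    represents so later dequantization knows what to produce.
    x_q_again = quantizer.create_tensor_from_data(raw, fake_dtype=x.dtype)

    # 3) Back to high precision for consumers that need real values.
    return x_q_again.dequantize()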
+ is_input_fp8 = isinstance(q, Float8Tensor) + q_fp8, k_fp8, v_fp8 = None, None, None if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data + q_fp8, k_fp8, v_fp8 = q, k, v else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1]) - q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] - if qkv_group == 2: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1]) - k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] - if qkv_group == 3: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - k_fp8 = cast_to_fp8( - k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(k.shape) - v_fp8 = cast_to_fp8( - v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(v.shape) + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = QKV_quantizer(qkv) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) + case 2: + q_fp8 = QKV_quantizer(q) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = QKV_quantizer(kv_c) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) + case 3: + q_fp8 = QKV_quantizer(q) + k_fp8 = QKV_quantizer(k) + v_fp8 = QKV_quantizer(v) + case _: + raise "Invalid qkv_layout " + qkv_layout out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6578,23 +5990,13 @@ def forward( q_fp8, k_fp8, v_fp8, - fp8_dtype_forward, + fake_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + S_quantizer, + O_quantizer, attn_scale, dropout_p, fast_zero_fill, @@ -6605,22 +6007,9 @@ def forward( rng_gen, ) if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) + out_ret = out_fp8 else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + out_ret = out_fp8.dequantize().view(out_fp8.shape) out_save = out_ret if not 
int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -6631,75 +6020,25 @@ def forward( dim = qkv_layout.find("3") qkv = _combine_tensors([q, k, v], dim) qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] + qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) + q = q.dequantize() dim = qkv_layout.split("_")[1].find("2") kv = _combine_tensors([k, v], dim) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] + kv_no_fp8 = kv.dequantize() + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) + q = q.dequantize() + k = k.dequantize() + v = v.dequantize() if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - - fp8_tensors = ( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) + out_save = out_fp8.dequantize() + + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: + out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -6709,23 +6048,13 @@ def forward( q, k, v, - qkv_dtype, + fake_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + None, # s_quantizer + None, # o_quantizer attn_scale, dropout_p, fast_zero_fill, @@ -6736,7 +6065,7 @@ def forward( rng_gen, ) out_save = out_ret - fp8_tensors = (None, None, None, None, None, None) + fp8_tensors = (None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) @@ -6758,18 +6087,27 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, *qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - *fp8_tensors, *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = 
S_quantizer + ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv + ctx.fake_dtype = fake_dtype ctx.qkv_dtype = qkv_dtype ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p @@ -6793,11 +6131,13 @@ def backward(ctx, d_out): assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data d_out = d_out.contiguous() ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -6806,14 +6146,11 @@ def backward(ctx, d_out): cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + + aux_ctx_tensors = other_tensors + if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() rest = [None] @@ -6850,20 +6187,10 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) if ctx.is_output_fp8: d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) + d_out_fp8 = ctx.dO_quantizer(d_out) dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6874,22 +6201,15 @@ def backward(ctx, d_out): v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, + ctx.fake_dtype, + ctx.qkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -6900,95 +6220,36 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dk = Float8Tensor( - data=dk_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dv = Float8Tensor( - data=dv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: + if not ctx.is_input_fp8: qkv_group = len(ctx.qkv_layout.split("_")) if qkv_group == 1: dim = ctx.qkv_layout.find("3") - dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] + dqkv_fp8_data = _combine_tensors( + [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim ) - dqkv = cast_from_fp8( - 
dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1]) - dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] + dqkv_fp8 = dq_fp8.make_like( + tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape + ) + dqkv = dqkv_fp8.dequantize() + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) if qkv_group == 2: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) + dq = dq_fp8.dequantize() dim = ctx.qkv_layout.split("_")[1].find("2") dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim) dkv_c_fp8 = dkv_fp8.view( -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) - dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1]) - dk, dv = [x.squeeze(dim) for x in [dk, dv]] + dkv = dkv_c_fp8.dequantize() + dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True) if qkv_group == 3: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dk = cast_from_fp8( - dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dk_fp8.shape) - dv = cast_from_fp8( - dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dv_fp8.shape) + dq = dq_fp8.dequantize() + dk = dk_fp8.dequantize() + dv = dv_fp8.dequantize() + else: + dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) + if isinstance(d_out, QuantizedTensor): + d_out = d_out.dequantize() dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6999,7 +6260,7 @@ def backward(ctx, d_out): v, out, d_out, - ctx.qkv_dtype, + ctx.fake_dtype, ctx.qkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, @@ -7008,13 +6269,6 @@ def backward(ctx, d_out): None, None, None, - None, - None, - None, - None, - None, - None, - None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -7055,6 +6309,7 @@ def backward(ctx, d_out): None, None, None, + None, ) # else, return (dqkv, dbias) return ( @@ -7085,6 +6340,7 @@ def backward(ctx, d_out): None, None, None, + None, ) @@ -7184,6 +6440,7 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -7321,6 +6578,7 @@ def forward( window_size=window_size, fp8=fp8, fp8_meta=fp8_meta, + quantizers=quantizers, ) else: with self.attention_dropout_ctx(): @@ -7349,6 +6607,7 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, self.deterministic, ) @@ -7736,7 +6995,6 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, - is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -7906,27 +7164,13 @@ def forward( Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. 
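The FP8 backward above now keeps quantizer objects on `ctx` (`ctx.dO_quantizer`, `ctx.S_quantizer`, `ctx.dP_quantizer`, `ctx.dQKV_quantizer`): the incoming gradient is quantized by calling the quantizer, and high-precision gradients are recovered with `.dequantize()`, instead of threading `scaling_bwd` meta tensors and `META_*` indices through `cast_to_fp8`/`cast_from_fp8`. Below is a self-contained toy sketch of that calling convention only; `ToyQuantizer` and `ToyQuantizedTensor` are illustrative stand-ins, not Transformer Engine classes, and the cast is simulated rather than a real FP8 conversion.

import torch


class ToyQuantizedTensor:
    """Minimal stand-in for a quantized tensor: payload + scale + nominal dtype."""

    def __init__(self, data: torch.Tensor, scale_inv: torch.Tensor, dtype: torch.dtype):
        self.data = data              # low-precision payload
        self.scale_inv = scale_inv    # per-tensor dequantization factor
        self.dtype = dtype            # high-precision dtype dequantize() returns

    def dequantize(self) -> torch.Tensor:
        return (self.data.float() * self.scale_inv).to(self.dtype)


class ToyQuantizer:
    """Owns its scale, so callers just do quantizer(tensor)."""

    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor) -> ToyQuantizedTensor:
        data = torch.clamp((x * self.scale).round(), -127, 127)  # simulated low-precision cast
        return ToyQuantizedTensor(data, torch.tensor(1.0 / self.scale), x.dtype)


d_out = torch.randn(4, 8)
dO_quantizer = ToyQuantizer(scale=16.0)
d_out_q = dO_quantizer(d_out)        # analogous to d_out_fp8 = ctx.dO_quantizer(d_out)
d_out_hp = d_out_q.dequantize()      # analogous to dq = dq_fp8.dequantize()
print(torch.allclose(d_out, d_out_hp, atol=0.5 / 16.0))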
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) """ + with self.prepare_forward( query_layer, - is_first_microbatch, num_gemms=3, allow_non_contiguous=True, ) as query_layer: - if self.fp8: if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: @@ -8290,6 +7534,7 @@ def forward( max_seqlen_kv=max_seqlen_kv, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, ) if use_fused_attention: @@ -8358,6 +7603,7 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, ) from .cpu_offload import CPUOffloadEnabled @@ -8569,11 +7815,11 @@ def __init__( fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -9035,16 +8281,9 @@ def forward( # not qkv_weight_interleaved: # [sq, b, (np/ng + 2), ng, hn] # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] - if not is_in_onnx_export_mode(): - query_layer, key_layer, value_layer = _SplitAlongDim.apply( - mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) - ) - else: - query_layer, key_layer, value_layer = torch.split( - mixed_x_layer, - (num_queries_per_key_value, 1, 1), - dim=split_dim, - ) + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) + ) if self.qkv_format == "thd": query_layer, key_layer, value_layer = ( @@ -9086,18 +8325,11 @@ def forward( mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # mixed_kv_layer --> 2 [sk, b, ng, hn] - if not is_in_onnx_export_mode(): - key_layer, value_layer = _SplitAlongDim.apply( - mixed_kv_layer, - split_dim, - mixed_kv_layer.shape[split_dim] // 2, - ) - else: - key_layer, value_layer = torch.split( - mixed_kv_layer, - mixed_kv_layer.shape[split_dim] // 2, - dim=split_dim, - ) + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, + ) key_layer, value_layer = ( x.reshape( x.size(0), @@ -9208,10 +8440,10 @@ def forward( # =================== # Output. 
[sq, b, h] # =================== - projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, + fp8_grad=isinstance(context_layer, QuantizedTensor), ) if self.return_bias: diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index c1790313ac..ff475caf21 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -16,6 +16,8 @@ """ TE_DType = { torch.uint8: tex.DType.kByte, + torch.float8_e4m3fn: tex.DType.kFloat8E4M3, + torch.float8_e5m2: tex.DType.kFloat8E5M2, torch.int32: tex.DType.kInt32, torch.float32: tex.DType.kFloat32, torch.half: tex.DType.kFloat16, @@ -59,3 +61,5 @@ GemmParallelModes = ("row", "column", None) dist_group_type = torch.distributed.ProcessGroup + +MXFP8_BLOCK_SCALING_SIZE = 32 diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index be911fcd95..944d1849bf 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,8 +7,3 @@ from .fused_attn import * from .gemm import * -from .transpose import * -from .activation import * -from .normalization import * -from .cast import * -from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py deleted file mode 100644 index aec972994a..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Helper functions for C++ extensions""" -import functools -from typing import Dict, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex - - -@functools.lru_cache(maxsize=None) -def empty_tensor() -> torch.Tensor: - """Get tensor with no entries and no data""" - return torch.Tensor() - - -def canonicalize_fp8_scales( - *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - fp8_meta: Optional[tex.FP8TensorMeta] = None, - fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, - allow_multiple_offsets: bool = True, -) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: - """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) - - If a scaling factor is not provided, try to access it within the - FP8 meta tensors. Returns dict with tensors and dict with tensor - offsets. 
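The constants.py hunk above maps the native torch.float8_e4m3fn and torch.float8_e5m2 dtypes into TE_DType and adds MXFP8_BLOCK_SCALING_SIZE = 32. Under the MX block-scaling scheme this constant reflects, each group of 32 consecutive elements along the scaled dimension shares one scale factor; the helper below is an illustrative sketch of what that implies for scale-factor counts (assumes PyTorch >= 2.1 for the float8 dtypes; it is not part of the module).

import math

import torch

MXFP8_BLOCK_SCALING_SIZE = 32  # mirrors the new constant in constants.py


def num_rowwise_scales(rows: int, cols: int) -> int:
    """Scale factors needed when every 32 consecutive elements in a row share one scale."""
    return rows * math.ceil(cols / MXFP8_BLOCK_SCALING_SIZE)


print(num_rowwise_scales(128, 4096))             # 128 * 128 = 16384
print(torch.float8_e4m3fn, torch.float8_e5m2)    # dtypes newly mapped in TE_DType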
- - """ - - # Default: use provided scales with no offsets - scale_offset = 0 - amax_offset = 0 - scale_inv_offset = 0 - - # Get scales from FP8 meta tensors if needed - if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): - if fp8_meta_index is None: - raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") - fp8_meta_index = int(fp8_meta_index) - if scale is None: - scale = fp8_meta.scale - scale_offset = fp8_meta_index - if amax is None: - amax = fp8_meta.amax_history - amax_offset = fp8_meta_index - if scale_inv is None: - scale_inv = fp8_meta.scale_inv - scale_inv_offset = fp8_meta_index - - # Construct empty tensors if needed - if scale is None: - scale = empty_tensor() - scale_offset = 0 - if amax is None: - amax = empty_tensor() - amax_offset = 0 - if scale_inv is None: - scale_inv = empty_tensor() - scale_inv_offset = 0 - - # Force offsets to be the same if needed - if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: - if scale_offset != 0: - scale = scale[scale_offset:] - scale_offset = 0 - if amax_offset != 0: - amax = amax[:, amax_offset:] - amax_offset = 0 - if scale_inv_offset != 0: - scale_inv = scale_inv[scale_inv_offset:] - scale_inv_offset = 0 - - # Pack tensors and offsets into dicts - tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv} - offsets = { - "scale_offset": scale_offset, - "amax_offset": amax_offset, - "scale_inv_offset": scale_inv_offset, - } - return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py deleted file mode 100644 index 534e71d134..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
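One subtle branch of the helper removed above: with allow_multiple_offsets=False it sliced each meta tensor at its own offset so that a single offset of 0 addressed the intended entry everywhere. A two-line illustration of why that slicing preserves the addressed element (plain tensors, not the TE meta structures):

import torch

scale = torch.arange(10, dtype=torch.float32)  # stand-in for fp8_meta.scale
fp8_meta_index = 3

# (scale, offset=3) and (scale[3:], offset=0) address the same element:
assert scale[fp8_meta_index] == scale[fp8_meta_index:][0]
print("offset canonicalization preserves the addressed entry")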
- -"""Python interface for activation extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] - - -def gelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.gelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def relu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.relu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def geglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.geglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def reglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.reglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def swiglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: 
Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """SwiGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.swiglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def qgelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """QuickGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.qgelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def srelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.srelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py deleted file mode 100644 index 9c21edccec..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
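The wrappers removed above exposed FP8-output versions of GELU/ReLU and their gated variants (GeGLU, ReGLU, SwiGLU). As a plain-PyTorch reference for what the gated family computes under one common convention (this is not the TE kernel, which may differ in which half is gated): the last dimension is split in half and one half gates the activation of the other.

import torch
import torch.nn.functional as F


def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b


def geglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b


x = torch.randn(4, 16)
print(swiglu_ref(x).shape, geglu_ref(x).shape)  # torch.Size([4, 8]) twice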
- -"""Python interface for cast extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["cast_to_fp8", "cast_from_fp8"] - - -def cast_to_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input to FP8""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch FP8 cast kernel - if inp.nelement() == 0: - if out is None: - out = torch.empty_like(inp, dtype=torch.uint8) - elif out is None: - out = torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - else: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_scales["scale"], - out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - return out - - -def cast_from_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - itype: tex.DType, - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input from FP8""" - - # Get scaling factors from FP8 meta tensors if needed - scale_inv_offset = 0 - if (fp8_meta_tensor is not None) and (scale_inv is None): - if fp8_tensor is None: - raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") - scale_inv = fp8_meta_tensor.scale_inv - scale_inv_offset = int(fp8_tensor) - - # Construct empty tensors if needed - if scale_inv is None: - raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") - - # Launch FP8 cast kernel - return torch.ops.tex_ts.cast_from_fp8_ts( - inp, - scale_inv, - scale_inv_offset, - itype, - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 332b4e52ee..b91a6c1751 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -4,7 +4,7 @@ """Python interface for fused attention extensions""" import math -from typing import Tuple, List, Union +from typing import Tuple, List, Union, Optional import torch import transformer_engine_torch as tex from transformer_engine_torch import ( @@ -13,13 +13,10 @@ NVTE_Mask_Type, NVTE_Fused_Attn_Backend, ) +from ..tensor.quantized_tensor import Quantizer __all__ = [ - "fused_attn_fwd_qkvpacked", - "fused_attn_bwd_qkvpacked", - "fused_attn_fwd_kvpacked", - "fused_attn_bwd_kvpacked", "fused_attn_fwd", "fused_attn_bwd", ] @@ -89,803 +86,6 @@ META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 -def fused_attn_fwd_qkvpacked( - is_training: bool, - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: 
int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed QKV input. - - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens), - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv - cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", 
"padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == qkv.dtype, "attn_bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." 
- else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_qkvpacked( - max_seqlen, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens, - qkv, - qkv_dtype, - cu_seqlens_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_qkvpacked( - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed QKV input. - - Parameters - ---------- - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens) - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
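Two conventions repeated throughout these removed docstrings are easy to check concretely: per-sequence lengths are the first difference of the cumulative sequence lengths, and the default attention scale is 1/sqrt(head_dim_qk). A small sketch with illustrative values:

import math

import torch

cu_seqlens = torch.tensor([0, 5, 12, 20])
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
print(seqlens)                      # tensor([5, 7, 8])

head_dim = 64
attn_scale = 1.0 / math.sqrt(head_dim)
print(attn_scale)                   # 0.125, used when attn_scale is None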
- cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_qkv: torch.Tensor - gradient tensor of QKV; same data type and shape as QKV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." 
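The sliding-window description above is easier to read as code. A hedged sketch of one way to evaluate the documented bounds; the clamping to valid key positions is an assumption of this illustration, not spelled out in the docstring.

def window_bounds(i, seqlen_q, seqlen_kv, window_size):
    """Inclusive range of key positions a query at position i may attend to."""
    left, right = window_size
    diag = i + seqlen_kv - seqlen_q
    lo = 0 if left < 0 else diag - left                 # -1 means unbounded on the left
    hi = seqlen_kv - 1 if right < 0 else diag + right   # -1 means unbounded on the right
    return max(lo, 0), min(hi, seqlen_kv - 1)


print(window_bounds(5, 8, 8, (2, 0)))    # (3, 5): two past keys plus the diagonal
print(window_bounds(5, 8, 8, (-1, -1)))  # (0, 7): no sliding window
print(window_bounds(5, 8, 8, (-1, 0)))   # (0, 5): causal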
- assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_qkvpacked( - max_seqlen, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens, - qkv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - -def fused_attn_fwd_kvpacked( - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed KV input. - - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape 2hd or h2d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
- attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv - cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 - 2. 
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." 
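The rng_elts_per_thread expressions in these removed helpers are ceiling divisions written in integer arithmetic: (n + d - 1) // d == ceil(n / d). A quick check; the 128 below is an illustrative stand-in for BACKEND_F16m512_FP8_THREADS_PER_CTA, whose actual value is defined elsewhere in this module.

import math

threads_per_cta = 128   # illustrative; see BACKEND_F16m512_FP8_THREADS_PER_CTA
max_seqlen = 512
n = max_seqlen * max_seqlen

rng_elts_per_thread = (n + threads_per_cta - 1) // threads_per_cta
assert rng_elts_per_thread == math.ceil(n / threads_per_cta)
print(rng_elts_per_thread)   # 2048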
- else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_kvpacked( - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed KV input. - - Parameters - ---------- - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape h2d or 2hd (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQ and dKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. 
- cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_q: torch.Tensor - gradient tensor of Q; same data type and shape as Q - d_kv: torch.Tensor - gradient tensor of KV; same data type and shape as KV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." 
- assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - def fused_attn_fwd( is_training: bool, max_seqlen_q: int, @@ -895,23 +95,13 @@ def fused_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, + s_quantizer: Quantizer = None, + o_quantizer: Quantizer = None, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -946,8 +136,9 @@ def fused_attn_fwd( input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) v: torch.Tensor input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. 
attn_bias: torch.Tensor, default = None @@ -957,30 +148,10 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. + o_quantizer: Quantizer, default = None + Quantizer object for the output of the attention. attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1068,17 +239,16 @@ def fused_attn_fwd( ) // BACKEND_F16m512_FP8_THREADS_PER_CTA assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention." + assert ( + o_quantizer is not None + ), "o_quantizer is required as an input for FP8 fused attention." 
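fake_dtype above is the nominal high-precision dtype of Q/K/V: when the inputs are FP8 tensors the storage dtype is a float8 type, but the values still represent, and dequantize back to, a high-precision dtype, and that is what the function now receives. A minimal illustration with native torch float8 dtypes (assumes PyTorch >= 2.1; the per-tensor scale handling is simplified):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = 16.0

x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)   # storage dtype of the FP8 tensor
fake_dtype = x.dtype                                   # what the values represent

x_back = (x_fp8.float() / scale).to(fake_dtype)
print(x_fp8.dtype, fake_dtype, x_back.dtype)
# torch.float8_e4m3fn torch.bfloat16 torch.bfloat16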
else: raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel + output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -1095,21 +265,11 @@ def fused_attn_fwd( q, k, v, - qkv_dtype, + fake_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, + s_quantizer, + o_quantizer, attn_bias, rng_gen, rng_elts_per_thread, @@ -1129,23 +289,16 @@ def fused_attn_bwd( v: torch.Tensor, o: torch.Tensor, d_o: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, + s_quantizer: Quantizer = None, + dp_quantizer: Quantizer = None, + dqkv_quantizer: Quantizer = None, + attn_scale: Optional[float] = None, dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", @@ -1181,8 +334,9 @@ def fused_attn_bwd( d_o: torch.Tensor input tensor dO (gradient of O); same data type as Q, K and V; same shape as Q - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype dqkv_dtype: tex.DType data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] @@ -1194,30 +348,12 @@ def fused_attn_bwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQ, dK and dV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. 
+ dp_quantizer: Quantizer, default = None + Quantizer object for the intermediate value dP. + dqkv_quantizer: Quantizer, default = None + Quantizer object for the output values of the fused_attn_bwd. dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -1253,7 +389,6 @@ def fused_attn_bwd( gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias """ - if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -1268,21 +403,19 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention backward." + assert ( + dp_quantizer is not None + ), "dp_quantizer is required as an input for FP8 fused attention backward." + assert ( + dqkv_dtype is not None + ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( len(aux_ctx_tensors) == 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." 
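For orientation, a minimal usage sketch (not part of the patch) of the reworked backward wrapper: the per-tensor d_scale_*/q_scale_*/amax_* arguments are replaced by Quantizer objects for S, dP and dQKV, and qkv_dtype becomes the "fake" torch dtype of the FP8 tensors. `fused_attn_bwd` and `FusedAttnBackend` are the names defined in this module; the leading positional order follows the existing wrapper, and the tensors, quantizers and dtype choices below are illustrative assumptions only.

import torch
import transformer_engine_torch as tex

def fp8_attn_bwd_sketch(max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                        q, k, v, o, d_o, aux_ctx_tensors,
                        s_quantizer, dp_quantizer, dqkv_quantizer):
    """Illustrative call into the quantizer-based FP8 backward path."""
    # aux_ctx_tensors must be [M, ZInv, rng_state] for the FP8 backend (asserted above).
    return fused_attn_bwd(
        max_seqlen_q, max_seqlen_kv,
        cu_seqlens_q, cu_seqlens_kv,
        q, k, v, o, d_o,
        torch.bfloat16,                 # fake_dtype: nominal dtype of the FP8 Q/K/V
        tex.DType.kFloat8E5M2,          # dqkv_dtype: dtype of dQ/dK/dV
        aux_ctx_tensors,
        FusedAttnBackend["FP8"],
        s_quantizer=s_quantizer,        # quantizer for S = Softmax(Q * K.T)
        dp_quantizer=dp_quantizer,      # quantizer for dP
        dqkv_quantizer=dqkv_quantizer,  # quantizer for the dQ/dK/dV outputs
        attn_scale=None,                # defaults to 1.0 / sqrt(head_dim_qk)
        dropout=0.0,
        qkv_layout="sbh3d",
    )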
- # execute kernel output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -1301,21 +434,14 @@ def fused_attn_bwd( v, o, d_o, - qkv_dtype, + fake_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, + s_quantizer, + dp_quantizer, + dqkv_quantizer, ) return output_tensors diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index c55f5a9fd4..948a13a03e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -4,499 +4,223 @@ """Python interface for GEMM extensions""" import functools -from typing import Optional, Tuple, Union, List +from typing import Iterable, Optional, Tuple, Union, List +import os import torch import transformer_engine_torch as tex from ..constants import TE_DType -from ..utils import assert_dim_for_fp8_exec +from ..utils import assert_dim_for_fp8_exec, get_sm_count +from ..tensor.quantized_tensor import Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = [ - "gemm", - "fp8_gemm", - "grouped_gemm", - "fp8_grouped_gemm", + "general_gemm", + "general_grouped_gemm", ] @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" - return torch.Tensor() + return torch.Tensor().cuda() -def fp8_gemm( - A: torch.Tensor, - A_scale_inv: torch.Tensor, - A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - A_dtype: tex.DType, - B: torch.Tensor, - B_scale_inv: torch.Tensor, - B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: torch.dtype, - workspace: torch.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[torch.Tensor] = None, - out_index=None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, - ub_algo: tex.CommOverlapAlgo = None, - ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> torch.Tensor: - """TN layout GEMM with fp8 inputs.""" +def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str): + """Swizzle gemm inputs and return original scaling factor inverses.""" + if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase): + return None - empty_tensor = _empty_tensor() - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - assert_dim_for_fp8_exec(A) - assert_dim_for_fp8_exec(B) - assert A.dtype == torch.uint8 - assert B.dtype == torch.uint8 - - if out is None: - out = torch.empty( - B.shape[0], - A.shape[0], - dtype=out_dtype, - device="cuda", - ) + original_scale_inverses = ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) + + if layout[0] == "T": + A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv) else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") + A._columnwise_scale_inv = tex.columnwise_swizzle( + A._columnwise_data, A._columnwise_scale_inv + ) - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 
if bias is None else bias.dtype - if gelu: - gelu_input = torch.empty_like(out, dtype=bias_dtype) + if layout[1] == "N": + B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv) else: - gelu_input = empty_tensor - bias_dtype = TE_DType[bias_dtype] + B._columnwise_scale_inv = tex.columnwise_swizzle( + B._columnwise_data, B._columnwise_scale_inv + ) - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + return original_scale_inverses - args = ( - A, - A_scale_inv, - A_fp8_tensor, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor, - B_dtype, - False, # transb - out, - empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index], - out_dtype, - empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index], - bias if use_bias else empty_tensor, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" - if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.AG, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.RS, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: - fn = ub.atomic_gemm_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: - fn = ub.atomic_gemm_overlap_rs - assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: - fn = ub.atomic_gemm_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, gelu_input - - -def gemm( + +def reset_swizzled_inputs(A, B, scale_inverses): + """Reset the swizzled scale inverses after GEMM.""" + if scale_inverses is not None: + ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) = scale_inverses + + +def general_gemm( A: torch.Tensor, B: torch.Tensor, - dtype: torch.dtype, workspace: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + 
quantization_params: Optional[Quantizer] = None, gelu: bool = False, - gelu_input: Optional[torch.Tensor] = None, - grad: bool = False, + gelu_in: torch.Tensor = None, accumulate: bool = False, layout: str = "TN", out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - ub_algo: tex.CommOverlapAlgo = None, + use_split_accumulator: bool = False, + grad: bool = False, ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Non FP8 GEMM.""" + ub_type: tex.CommOverlapType = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """GEMM supporting fp8 inputs.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - empty_tensor = _empty_tensor() - fp8_index = -1 # dummy index - - if out is None: - out = torch.empty( - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - dtype=dtype, - device="cuda", + # assert quantization_params is None, "FP8 output not supported yet" + + if ub_type is not None: + assert ub is not None, ( + f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" + + "a valid `ub` communicator object." ) - else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") - if gelu and not grad: - gelu_input = torch.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = empty_tensor + if ub is not None: + assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument." + if ub_type == tex.CommOverlapType.RS: + if not (bulk_overlap and not ub.is_fp8_ubuf()): + assert extra_output is not None, "GEMM+RS overlap requires extra output tensor." - if grad and use_bias: - grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda") - else: - grad_bias = empty_tensor - - bias = bias if use_bias else empty_tensor + if out is not None: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype + # Use bfloat16 as default bias_dtype + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] args = ( A, - empty_tensor, - fp8_index, - input_dtype, - transa, + transa, # transa B, - empty_tensor, - fp8_index, - input_dtype, - transb, + transb, # transb out, - empty_tensor, # out_scale - output_dtype, - empty_tensor, # out_amax - grad_bias if grad else bias, + quantization_params, + TE_DType[out_dtype] if out_dtype is not None else None, + bias, bias_dtype, - gelu_input, - grad, + gelu, + gelu_in, + grad, # grad workspace, workspace.shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" 
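A note on the swizzle_inputs / reset_swizzled_inputs pair defined above: MXFP8 scale-inverse tensors are reordered in place into the layout the underlying GEMM expects, so a caller has to capture the original scale-invs and restore them once the GEMM has run, keeping the MXFP8 tensors usable afterwards. A compressed sketch of that bracket follows; the try/finally is an illustrative addition (general_gemm itself simply restores right after the call), and `args`/`kwargs` stand for the tuples built by general_gemm.

import transformer_engine_torch as tex

def gemm_with_swizzled_scales(A, B, layout, args, kwargs):
    """Run tex.generic_gemm with MXFP8 scale-invs swizzled, then restore them."""
    original = swizzle_inputs(A, B, layout)    # None unless both A and B are MXFP8
    try:
        return tex.generic_gemm(*args, **kwargs)
    finally:
        reset_swizzled_inputs(A, B, original)  # put back the un-swizzled scale-invs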
- if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - False, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, grad_bias, gelu_input - - -def grouped_gemm( + kwargs = { + "comm_overlap": ub, + "comm_type": ub_type, + "extra_output": extra_output, + "bulk_overlap": bulk_overlap, + } + + original_scale_inverses = swizzle_inputs(A, B, layout) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + reset_swizzled_inputs(A, B, original_scale_inverses) + + return out, bias_grad, gelu_input, extra_output + + +def general_grouped_gemm( A: List[torch.Tensor], B: List[torch.Tensor], out: List[torch.Tensor], - dtype: torch.dtype, + out_dtype: torch.dtype, workspaces: List[torch.Tensor], + layout: str = "TN", + m_splits: Optional[List[int]] = None, gelu: bool = False, - gelu_input: Optional[List[torch.Tensor]] = None, - grad: bool = False, + grad=False, accumulate: bool = False, - layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] = None, + single_output=False, ) -> Tuple[List[torch.Tensor], ...]: - """Non FP8 Grouped GEMM.""" + """ + TN layout Grouped GEMM with fp8 inputs. + """ + num_gemms = len(A) - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." 
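Stepping back from the grouped path for a moment, a usage sketch (not part of the patch) for the consolidated general_gemm entry point defined above, which replaces the separate gemm / fp8_gemm wrappers: A and B may be plain torch tensors or TE quantized tensors, and an optional Quantizer yields a quantized output D. Shapes, dtypes and the workspace size below are illustrative assumptions.

import torch
from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm

workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # size is an assumption
A = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)  # (N, K) weight-like operand
B = torch.randn(512, 128, device="cuda", dtype=torch.bfloat16)  # (M, K) input-like operand

out, bias_grad, gelu_input, extra_output = general_gemm(
    A, B, workspace,
    out_dtype=torch.bfloat16,
    layout="TN",               # TN: D has shape (B.shape[0], A.shape[0]), i.e. roughly B @ A.T
    quantization_params=None,  # pass a Quantizer here to get a quantized output D
    bias=None,
    use_split_accumulator=False,
)
# out is (512, 256) here; bias_grad / gelu_input / extra_output are None unless requested.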
transa = layout[0] == "T" transb = layout[1] == "T" - num_gemms = len(A) + + # assert [a.is_contiguous() for a in A] + # assert [b.is_contiguous() for b in B] + + if isinstance(A[0], Float8TensorBase): + for a, b in zip(A, B): + assert_dim_for_fp8_exec(a._data) + assert_dim_for_fp8_exec(b._data) + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms - if gelu and not grad: - gelu_input = [ - torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out - ] - elif not gelu: - gelu_input = empty_tensors + # Use bfloat16 as default bias_dtype + gelu_input = empty_tensors + out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype + sm_count = get_sm_count() if grad and use_bias: grad_bias = [ torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms) ] else: grad_bias = empty_tensors - bias = bias if use_bias else empty_tensors - - assert ( - A[0].dtype == dtype and B[0].dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out[0].dtype] if use_bias: bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype] else: - bias_dtype = output_dtype + bias_dtype = TE_DType[torch.bfloat16] - torch.ops.tex_ts.te_grouped_gemm_ts( + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] # this should differ with respect to single output + + bias = tex.te_general_grouped_gemm( A, - empty_tensor, - 0, # A_offset - input_dtype, transa, B, - empty_tensor, - 0, # B_offset - input_dtype, transb, out, - 0, # out_offset - empty_tensor, # out_scale - output_dtype, - empty_tensor, # out_amax + out_dtype, + m_splits, grad_bias if grad else bias, bias_dtype, - gelu_input, # gelu_input - grad, + single_output, + gelu_input, # this is pre_gelu_out + grad, # grad workspaces, workspaces[0].shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, + sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))), ) - return out, grad_bias, gelu_input - - -def fp8_grouped_gemm( - A: List[torch.Tensor], - A_scale_inv: List[torch.Tensor], - A_fp8_tensor_offset: int, - A_dtype: tex.DType, - B: List[torch.Tensor], - B_scale_inv: torch.Tensor, - B_fp8_tensor_offset: int, - B_dtype: tex.DType, - out: List[torch.Tensor], - out_dtype: torch.dtype, - workspaces: List[torch.Tensor], - m_splits: Optional[List[int]] = None, - out_offset: Optional[int] = None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - gelu: bool = False, - accumulate: bool = False, - bias: Optional[List[torch.Tensor]] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> Tuple[List[torch.Tensor], ...]: - """ - TN layout Grouped GEMM with fp8 inputs. - Input requirements: - 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. - This is used for the calculation of output (fwd) and dgrad (bwd). - 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the - calculation of wgrad. 
- """ - num_gemms = len(A) - if num_gemms > 1 and len(A_scale_inv) == num_gemms: - assert len(out) == 1 and m_splits is not None - elif num_gemms > 1 and len(A_scale_inv) == 1: - assert len(out) == num_gemms - elif num_gemms == 1: - assert len(A_scale_inv) == 1 and len(out) == 1 - else: - raise ValueError("Invalid input combinations of A_scale_inv and out.") - - empty_tensor = _empty_tensor() - empty_tensors = [empty_tensor] * num_gemms - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_offset is not None - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a) - assert_dim_for_fp8_exec(b) - assert A[0].dtype == torch.uint8 - assert B[0].dtype == torch.uint8 - - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - bias_dtype = TE_DType[bias_dtype] - gelu_input = empty_tensors - out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - - if len(A_scale_inv) == 1: - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv[0], - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - else: - if gelu: - gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] - - torch.ops.tex_ts.te_grouped_gemm_single_output_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - m_splits, - out[0], - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - - return out, gelu_input + return out, bias, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py deleted file mode 100644 index f997a8a536..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Python interface for normalization extensions""" -from typing import Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - - -__all__ = [ - "layernorm_fwd_fp8", - "layernorm_fwd_fp8_inf", - "layernorm_fwd_inf", - "rmsnorm_fwd_fp8", - "rmsnorm_fwd_fp8_inf", - "rmsnorm_fwd_inf", -] - - -def layernorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - ln_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """LayerNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - if ln_out is not None: - return tex.layernorm_fwd_fp8_noalloc( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - ln_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.layernorm_fwd_fp8( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def layernorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """LayerNorm with FP8 output. - - This version of layernorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. 
- """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def layernorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """LayerNorm with FP8 output""" - return torch.ops.tex_ts.layernorm_fwd_inf_ts( - inp, - weight, - bias, - eps, - sm_margin, - zero_centered_gamma, - ) - - -def rmsnorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - rmsnorm_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """RMSNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - if rmsnorm_out is not None: - return tex.rmsnorm_fwd_fp8_noalloc( - inp, - weight, - eps, - fp8_scales["scale"], - rmsnorm_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.rmsnorm_fwd_fp8( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def rmsnorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """RMSNorm with FP8 output. - - This version of rmsnorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. 
- """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def rmsnorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """RMSNorm with FP8 output""" - return torch.ops.tex_ts.rmsnorm_fwd_inf_ts( - inp, - weight, - eps, - sm_margin, - zero_centered_gamma, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py deleted file mode 100644 index cf704d06ee..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/padding.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for transpose extensions""" -from typing import List, Tuple, Union -import torch -import transformer_engine_torch as tex - - -__all__ = [ - "multi_padding_fused", -] - - -def multi_padding_fused( - inp: torch.Tensor, - row_list: List[int], - padded_row_list: List[int], - out: torch.Tensor, -) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: - """Padding""" - - tex.fused_multi_row_padding( - inp, - out, - row_list, - padded_row_list, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py deleted file mode 100644 index 77bf0019af..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -"""Python interface for transpose extensions""" -from typing import List, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ..constants import TE_DType -from ._common import canonicalize_fp8_scales, empty_tensor - - -__all__ = [ - "fp8_cast_transpose_fused", - "fp8_cast_transpose_bgrad_fused", - "fp8_cast_transpose_bgrad_dgelu_fused", - "fp8_dswiglu_cast_transpose_fused", - "fp8_multi_cast_transpose_fused", - "fp8_transpose_bgrad_fused", -] - - -def fp8_cast_transpose_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - cast_out: Optional[torch.Tensor] = None, - transpose_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - noop_flag: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Cast + Transpose with FP8 output""" - - # Allocate outputs if needed - if transpose_out is None: - transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) - if cast_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Construct no-op flag if needed - if noop_flag is None: - noop_flag = empty_tensor() - - # Launch kernel if needed - if inp.nelement() > 0: - tex.fused_cast_transpose_noop( - inp, - noop_flag, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - cast_out, - transpose_out, - otype, - **fp8_scales_offsets, - ) - - return cast_out, transpose_out - - -def fp8_cast_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - grad_bias_type: torch.dtype, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_fp8_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - TE_DType[grad_bias_type], - **fp8_scales_offsets, - ) - - -def fp8_cast_transpose_bgrad_dgelu_fused( - grad_output: torch.Tensor, - gelu_input: 
torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD + DGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_dswiglu_cast_transpose_fused( - grad_output: torch.Tensor, - inp: torch.Tensor, - *, - grad_input: torch.Tensor, - grad_input_transpose: torch.Tensor, - otype: tex.DType, - fp8_meta: Optional[tex.FP8TensorMeta] = None, - fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> None: - """Fused SwiGLU backward + FP8 cast + FP8 transpose""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta, - fp8_meta_index=fp8_meta_index, - ) - - # Launch kernel - return tex.fused_dswiglu_cast_transpose( - grad_output, - inp, - grad_input, - grad_input_transpose, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_multi_cast_transpose_fused( - input_list: List[torch.Tensor], - fp8_meta_tensor: tex.FP8TensorMeta, - scale_indices: List[int], - amax_indices: List[int], - scale_inv_indices: List[int], - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - """Cast + Transpose with FP8 output""" - - return tex.fused_multi_cast_transpose_alloc( - input_list, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, - scale_indices, - amax_indices, - scale_inv_indices, - otype, - ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 2c8736ee09..33de562a89 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -9,13 +9,27 @@ import torch -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False +def set_offloading_param(tensor, param_name, value): + """Set the type of the offloading needed for a tensor.""" + assert param_name in ["weight_offloading", "activation_offloading"] + if tensor is None: + return + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + setattr(tensor, param_name, value) + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + setattr(tensor, param_name, value) + + def is_cpu_offload_enabled() -> bool: """Check if CPU offloading is currently enabled.""" return CPUOffloadEnabled @@ -258,6 +272,7 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs): else: # will be offloaded together after group commit self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag def tensor_pop(self, tensor_tag, **kwargs): @@ -366,6 
+381,7 @@ def bulk_offload_group(self, group_to_offload): if self.tensor_need_offloading_checker(tensor_on_device): state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state + tensor_on_device.data = torch.Tensor() # Force to release memory def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index eb97dc36eb..5775fe381d 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -6,7 +6,33 @@ #include "common.h" +#include "c10/util/ArrayRef.h" +#include "pybind.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine::pytorch { + +std::vector getTensorShape(at::Tensor t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} + +std::unique_ptr convert_quantizer(py::handle quantizer) { + init_extension(); + if (quantizer.is_none()) { + return std::make_unique(quantizer); + } + for (auto [_check_type, check_quantizer_type, _create_tensor, create_quantizer] : + detail::custom_types_converters) { + if (check_quantizer_type(quantizer.ptr())) { + return create_quantizer(quantizer); + } + } + + NVTE_ERROR("Unexpected type for quantizer"); +} transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe) { @@ -17,6 +43,34 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, return transformer_engine::DType::kFloat8E5M2; } +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { + NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + for (auto [check_type, check_quantizer_type, create_tensor, _] : + detail::custom_types_converters) { + if (check_type(tensor.ptr())) { + NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()), + "Unexpected quantization params type."); + auto x = create_tensor(tensor, my_quantizer.get()); + return x; + } + } + + // Regular pyTorch tensor + at::Tensor torch_tensor = tensor.cast(); + + // #TODO (pgadzinski) - needed in attention for non-contiguous tensors. 
+ //if (!torch_tensor.is_contiguous()) { + // torch_tensor = torch_tensor.contiguous(); + //} + auto ret = TensorWrapper(my_quantizer->get_scaling_mode()); + ret.set_rowwise_data(torch_tensor.data_ptr(), + GetTransformerEngineDType(torch_tensor.scalar_type()), + getTensorShape(torch_tensor)); + my_quantizer->set_quantization_params(&ret); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { return transformer_engine::TensorWrapper(data_ptr, shape, type); @@ -30,48 +84,95 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); std::vector shape; - for (auto s : tensor.sizes()) { shape.push_back(s); } return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr) { - return transformer_engine::TensorWrapper( - data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), reinterpret_cast(scale_inv_ptr)); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape, + const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; } transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, - at::Tensor scale_inv) { + at::Tensor scale_inv, + NVTEScalingMode scaling_mode) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + auto tensor_shape = getTensorShape(tensor); + auto scale_inv_shape = getTensorShape(scale_inv); + NVTE_CHECK(amax.scalar_type() == at::kFloat); NVTE_CHECK(scale.scalar_type() == at::kFloat); NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape, + scaling_mode); } -size_t product(const std::vector& shape) { - size_t ret = 1; +template +T product(const std::vector& shape) { + T ret = 1; for (auto s : shape) { ret *= s; } return ret; } +template size_t product(const std::vector& shape); +template int64_t product(const std::vector& shape); + +size_t product(const NVTEShape& shape, size_t begin, size_t end) { + NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, + " in a shape with ", shape.ndim, " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape.data[i]; + } + return ret; +} + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape) { + std::vector shape; + for (size_t i = 0; i < nvte_shape.ndim; i++) { + shape.push_back(nvte_shape.data[i]); + } + return shape; +} + at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros) { std::vector shape_int64(shape.begin(), shape.end()); @@ -121,3 +222,14 @@ void* getDataPtr(at::Tensor tensor, int offset) { } return dptr; } + +std::vector convertShape(const NVTEShape& shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + +int roundup(const int value, const int multiple) { + assert(multiple > 0); + return ((value + multiple - 1) / multiple) * multiple; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94e1f7569a..40245cf2d9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -33,23 +33,22 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include -#include -#include #include #include +#include "c10/util/ArrayRef.h" #include "common/util/logging.h" -namespace transformer_engine { +namespace transformer_engine::pytorch { // Each tensor here is shape (N, ) holding all scaling // data for a single FP8 block, e.g. 
LayerNormLinear @@ -85,7 +84,76 @@ enum FP8BwdTensors { GRAD_INPUT3 = 5 }; -} // namespace transformer_engine +class Quantizer { + public: + virtual NVTEScalingMode get_scaling_mode() const = 0; + + virtual void set_quantization_params(TensorWrapper* tensor) const = 0; + + virtual std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const = 0; + + virtual ~Quantizer() = default; + + bool rowwise_usage = true; + bool columnwise_usage = true; + bool internal = false; + py::handle quantizer; + + protected: + explicit Quantizer(const py::handle& quantizer); +}; + +class NoneQuantizer : public Quantizer { + public: + explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {} + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override {} + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8Quantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + + explicit Float8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class MXFP8Quantizer : public Quantizer { + public: + DType dtype; + + explicit MXFP8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +std::unique_ptr convert_quantizer(py::handle quantizer); + +std::vector getTensorShape(at::Tensor t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -103,9 +171,11 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { case transformer_engine::DType::kBFloat16: return at::kBFloat16; case transformer_engine::DType::kByte: + return at::kByte; case transformer_engine::DType::kFloat8E4M3: + return at::kFloat8_e4m3fn; case transformer_engine::DType::kFloat8E5M2: - return at::kByte; + return at::kFloat8_e5m2; default: NVTE_ERROR("Invalid type"); } @@ -113,6 +183,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { switch (t) { + case at::kFloat8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case at::kFloat8_e5m2: + return transformer_engine::DType::kFloat8E5M2; case at::kHalf: return transformer_engine::DType::kFloat16; case at::kFloat: @@ -128,6 +202,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { case torch::kInt64: return transformer_engine::DType::kInt64; default: + std::cout << "Type: " << static_cast(t) << std::endl; NVTE_ERROR("Invalid type"); } } @@ -140,11 +215,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const std::vector& shape, const transformer_engine::DType type); -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const 
transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape = {1}, + const std::vector& columnwise_scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, @@ -152,11 +234,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, - const at::Tensor scale, - at::Tensor scale_inv); +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +template +T product(const std::vector& shape); -size_t product(const std::vector& shape); +size_t product(const NVTEShape& shape, size_t begin, size_t end); + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros); @@ -170,4 +259,54 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); +std::vector convertShape(const NVTEShape& shape); + +int roundup(const int value, const int multiple); + +} // namespace transformer_engine::pytorch + +namespace std { +template +string to_string(const vector& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +// Torch shape -> string +template +string to_string(const c10::ArrayRef& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +inline string to_string(const NVTEShape& s) { + string ret = "["; + for (size_t i = 0; i < s.ndim; ++i) { + ret += to_string(s.data[i]) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} +} // namespace std + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 58527ef6d5..e871228b80 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,7 +10,6 @@ #include #include "common.h" -#include "common/common.h" /*************************************************************************************************** * Permutation @@ -45,93 +44,27 @@ NVTE_Fused_Attn_Backend 
get_fused_attn_backend(const transformer_engine::DType q size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV); - -std::vector fused_attn_fwd_kvpacked( +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); -std::vector fused_attn_bwd_kvpacked( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const 
py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); - -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); @@ -140,237 +73,146 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); * GEMM **************************************************************************************************/ -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); +using MaybeTensor = std::optional; void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, 
at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter); -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); - -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); /*************************************************************************************************** * Transpose **************************************************************************************************/ -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype); - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - 
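The refactored declarations above standardize on std::optional for anything that may or may not be supplied: the MaybeTensor alias for optional tensor arguments such as bias, and optional output buffers so callers can either pass preallocated storage or let the extension allocate. The following is only a minimal, self-contained sketch of that idiom; Vec, MaybeVec, and scale_with_optional_bias are hypothetical stand-ins for illustration, not Transformer Engine code.

#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

using Vec = std::vector<float>;       // stand-in for at::Tensor
using MaybeVec = std::optional<Vec>;  // mirrors the MaybeTensor alias above

// y = alpha * x, plus an optional bias when one is supplied.
Vec scale_with_optional_bias(const Vec &x, float alpha, const MaybeVec &bias = std::nullopt) {
  Vec y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = alpha * x[i];
    if (bias.has_value()) y[i] += (*bias)[i];  // only dereferenced when present
  }
  return y;
}

int main() {
  Vec x = {1.f, 2.f, 3.f};
  Vec b = {0.5f, 0.5f, 0.5f};
  Vec y0 = scale_with_optional_bias(x, 2.f);     // no bias passed
  Vec y1 = scale_with_optional_bias(x, 2.f, b);  // bias passed
  std::cout << y0[0] << " " << y1[0] << "\n";    // prints: 2 2.5
  return 0;
}

The same shape appears in the new GEMM, transpose, and quantize declarations, where std::nullopt for an output means "allocate one for me" and a populated optional means "write into this buffer".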
-std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, - at::Tensor grad_input_transpose, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, int scale_offset = 0, - int amax_offset = 0, int scale_inv_offset = 0); - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_output_list, - std::vector scale_inv_output_list, - transformer_engine::DType otype); - -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype); - -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype); + +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output = std::nullopt); + +namespace transformer_engine::pytorch { /*************************************************************************************************** * Activations **************************************************************************************************/ -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object gelu(const at::Tensor &input, py::handle quantizer); + +py::object relu(const at::Tensor &input, py::handle quantizer); -at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object geglu(const at::Tensor &input, py::handle quantizer); -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgeglu(const at::Tensor &input, py::handle quantizer); -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object reglu(const at::Tensor &input, py::handle quantizer); -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object swiglu(const at::Tensor &input, py::handle quantizer); -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgelu(const at::Tensor &input, py::handle quantizer); -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object srelu(const at::Tensor &input, py::handle quantizer); -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor 
drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +} // namespace transformer_engine::pytorch /*************************************************************************************************** * LayerNorm **************************************************************************************************/ -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, const int scale_inv_offset = 0); - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma); - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma); - 
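The LayerNorm hunk above collapses the old fp8 / fp8_noalloc / fp8_inf / noalloc / inf forward variants into a single layernorm_fwd that takes a quantizer handle and an optional preallocated output. As a rough illustration of that consolidation pattern only: one recipe-agnostic entry point that defers output handling to a policy object. The Quantizer, PassThroughQuantizer, CoarseRoundingQuantizer, and norm_fwd names below are made up for this sketch and are not Transformer Engine classes.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

struct Quantizer {
  virtual ~Quantizer() = default;
  // Decides what the produced tensor looks like (plain, quantized + scale, ...).
  virtual std::vector<float> make_output(const std::vector<float> &hi_precision) const = 0;
};

struct PassThroughQuantizer : Quantizer {  // keep values as-is
  std::vector<float> make_output(const std::vector<float> &x) const override { return x; }
};

struct CoarseRoundingQuantizer : Quantizer {  // emulate quantization by rounding to 1/8 steps
  std::vector<float> make_output(const std::vector<float> &x) const override {
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
      y[i] = static_cast<float>(static_cast<int>(x[i] * 8)) / 8;
    return y;
  }
};

// Single forward entry point: the normalization math is recipe-agnostic,
// and the quantizer decides the output representation.
std::vector<float> norm_fwd(const std::vector<float> &x, const Quantizer &q) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = (x[i] - mean) / std::sqrt(var + 1e-5f);
  return q.make_output(y);
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  auto y_plain = norm_fwd(x, PassThroughQuantizer{});
  auto y_quant = norm_fwd(x, CoarseRoundingQuantizer{});
  std::cout << y_plain[0] << " " << y_quant[0] << "\n";
  return 0;
}

Collapsing the variants this way moves the recipe decision out of the function name and into the quantizer argument, which is why the header now needs only one forward declaration per normalization.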
/*************************************************************************************************** * RMSNorm **************************************************************************************************/ -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, float eps, at::Tensor scale, - at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma); + +/*************************************************************************************************** + * Cast + **************************************************************************************************/ -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma); +namespace transformer_engine::pytorch { -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop); + +py::object dequantize(const py::handle &input, transformer_engine::DType otype); + +std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + std::optional comm_type = std::nullopt, + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); /*************************************************************************************************** - * Cast + * Cast fusions **************************************************************************************************/ -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int 
scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset = 0); +} // namespace transformer_engine::pytorch /*************************************************************************************************** * Softmax @@ -405,7 +247,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin); @@ -518,6 +359,16 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * swizzle + **************************************************************************************************/ + +void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); + /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -551,151 +402,44 @@ class CommOverlapHelper : torch::CustomClassHolder { }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, int comm_type); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType 
A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + + ~CommOverlap() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, - bool use_ce = true, bool aggregate = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, bool chunk); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. 
- */ - void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy); - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor B_copy); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); + + ~CommOverlapP2P() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlapP2P #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 48832e6994..7ce33ee77b 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -5,272 +5,114 @@ ************************************************************************/ #include 
"extensions.h" +#include "pybind.h" -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; +namespace transformer_engine::pytorch { - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); - auto output = allocateTorchTensor(M, N, otype); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + input_shape[input_shape.size() - 1] /= shape_divisor; + auto fake_tensor_type = input.scalar_type(); - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; +template +py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + auto grad_tensor = grad.contiguous(); - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = input.scalar_type(); - auto output = allocateTorchTensor(M, N, otype); + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; + return out; } -at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = 
makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object gelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_reglu(input_cu.data(), 
output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object reglu(const at::Tensor& input, py::handle quantizer) { + 
return activation_helper(input, quantizer, 2); } -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dqgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_srelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - auto output = allocateTorchTensor(M, N, otype); +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - nvte_dsrelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - return output; +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index d9977f01b9..c323e7b6c1 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -8,7 +8,7 @@ at::Tensor 
fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(input.size(0) <= freqs.size(0), @@ -66,7 +66,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(output_grads.size(0) <= freqs.size(0), @@ -122,7 +122,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -174,7 +174,7 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9c9ffdb1a7..f2d1ecf3b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common/common.h" #include "common/fused_attn/thd_utils.h" #include "extensions.h" @@ -40,22 +41,27 @@ __global__ void __launch_bounds__(block_size) } // fast zero-fills of tensors -void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { - auto max_tokens = self.size(0); - auto self_2d = self.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - TORCH_CHECK(self.is_contiguous(), "input not contiguous"); +void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { + std::vector shape = transformer_engine::pytorch::convertShape(self.shape()); + + auto max_tokens = shape[0]; + auto fcd_size = 1; + for (int i = 1; i < shape.size(); i++) { + fcd_size *= shape[i]; + } TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); dim3 dim_grid(num_blk_x, num_blk_y); dim3 dim_block(block_size); + // convert the TE DType to an ATen scalar type for the dispatch below + at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_fill", [&]() { + at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() { mha_fill_kernel<<>>( - self_2d.data_ptr(), static_cast(start_index.data_ptr()), - max_tokens); + static_cast(self.get_rowwise_data().data_ptr), + static_cast(start_index.data_ptr()), max_tokens); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -80,735 +86,48 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe return philox_args; } -// fused attention FWD with packed QKV -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - std::vector o_shape{q_shape.begin(), q_shape.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape,
options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens, te_cu_seqlens_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // extract random number generator seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? 
allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed QKV -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - auto h = q_shape[q_shape.size() - 2]; - - // create output tensor dQKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQKV = torch::empty_like(QKV, options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - 
NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQKV.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, nullptr, nullptr, - nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h), static_cast(max_seqlen), - static_cast(max_seqlen)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - TensorWrapper te_cu_seqlens = makeTransformerEngineTensor( - cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); - - TensorWrapper te_cu_seqlens_padded; - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQKV, dBias}; -} - -// fused attention FWD with packed KV -std::vector fused_attn_fwd_kvpacked( +// fused attention FWD with separate Q, K and V tensors +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type 
bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + TensorWrapper te_Q, te_K, te_V, te_O, te_S; - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector o_shape{q_shape.begin(), q_shape.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape, options); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } - - // extract rng seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output 
tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], - workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed KV -std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional 
scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector k_shape; - for (auto i : kv_shape) { - if (i != 2) { - k_shape.push_back(i); - } - } - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - - // create output tensors dQ and dKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQ = torch::empty_like(Q, options); - at::Tensor dKV = torch::empty_like(KV, options); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQ.fill_(0); - dKV.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dKV = - makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + auto none = py::none(); + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); - // convert auxiliary tensors from forward to NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } + // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. 
+ const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQ, dKV, dBias}; -} - -// fused attention FWD with separate Q, K and V tensors -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { - using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), 
v_sizes.end()}; - - // create output tensor O + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto o_shape = std::vector{q_sizes.begin(), q_sizes.end()}; - o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]; - std::vector o_shape_tmp{o_shape.begin(), o_shape.end()}; - auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options); + // create output tensor O + + auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; + o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + py::object o_python, s_python; + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; + TensorWrapper te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { @@ -817,55 +136,30 @@ std::vector fused_attn_fwd( auto d = q_shape[q_shape.size() - 1]; if (set_zero && ((h * d) % block_size == 0) && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + te_cu_seqlens_kv = + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); @@ -875,11 +169,9 @@ std::vector fused_attn_fwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + 
cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } // extract rng seed and offset @@ -913,8 +205,8 @@ std::vector fused_attn_fwd( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); + std::vector output_tensors; + output_tensors.push_back(o_python); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors @@ -936,7 +228,7 @@ std::vector fused_attn_fwd( } else { output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); } - output_tensors.push_back(output_tensor); + output_tensors.push_back(py::cast(output_tensor)); tensor->data.dptr = output_tensor.data_ptr(); } @@ -957,45 +249,55 @@ std::vector fused_attn_fwd( } // fused attention BWD with separate Q, K and V -std::vector fused_attn_bwd( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer) { using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), v_sizes.end()}; + using namespace transformer_engine::pytorch; + auto none = py::none(); + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + te_O = makeTransformerEngineTensor(O, none); + te_dO = makeTransformerEngineTensor(dO, none); + // qkv type from the te_Q + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + py::object s_python, dp_python; + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + 
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; auto d_v = v_shape[v_shape.size() - 1]; auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_sizes.begin(), q_sizes.end()}; + std::vector o_shape{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = d_v; - at::Tensor dQ; - at::Tensor dK; - at::Tensor dV; - at::Tensor dQKV, dKV; + at::Tensor dQ, dK, dV, dQKV, dKV; + py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1012,7 +314,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1026,8 +328,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1040,8 +343,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1052,82 +356,41 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - dQ = torch::empty_like(Q, options); - dK = torch::empty_like(K, options); - dV = torch::empty_like(V, options); + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + dK = torch::empty(tmp_shape, options); + tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } + std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = 
dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dV = - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = 
makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dV = - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -1152,11 +415,9 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -1219,7 +480,7 @@ std::vector fused_attn_bwd( // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {dQ, dK, dV, dBias}; + return {py_dQ, py_dK, py_dV, py::cast(dBias)}; } namespace flash_attention { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp new file mode 100644 index 0000000000..a1fe8bd2b5 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "common.h" +#include "pybind.h" +#include "transformer_engine/cast.h" + +namespace transformer_engine::pytorch { + +std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { + auto quantizer = convert_quantizer(py_quantizer); + + auto input_tensor = makeTransformerEngineTensor(input); + + auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + + std::vector output_shape; + for (auto s : input.sizes()) { + output_shape.emplace_back(static_cast(s)); + } + auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + return {py::cast(dbias.zero_()), out}; + } + + auto dbias_tensor = makeTransformerEngineTensor(dbias); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + void* workspace_data_ptr = nullptr; + if (workspace.shape().ndim > 0) { + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace_data_ptr = workspace_data.data_ptr(); + } + workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); + + // Launch kernel + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + return {py::cast(dbias), out}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 771fa4920a..66dafdaafb 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -4,69 +4,126 @@ * See LICENSE for license information. 
************************************************************************/ +#include "transformer_engine/cast.h" + +#include "common.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, + std::optional noop) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = tensor.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = tensor.scalar_type(); + if (!detail::IsFloatingPointType(fake_tensor_type)) { + fake_tensor_type = at::kFloat; + } + + TensorWrapper te_output; + py::object out; + if (output.is_none()) { + DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); + std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); + } else { + out = output; + te_output = makeTransformerEngineTensor(output, quantizer); + } + + TensorWrapper te_noop; + if (noop.has_value()) { + te_noop = makeTransformerEngineTensor(*noop); + } else { + te_noop = TensorWrapper(); + } + + if (te_output.numel() == 0) return out; + nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), + at::cuda::getCurrentCUDAStream()); + + return out; +} + +py::object dequantize(const py::handle& input, transformer_engine::DType otype) { + init_extension(); -at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; + const auto none = py::none(); - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + const auto& input_tensor = makeTransformerEngineTensor(input, none); - if (input.numel() == 0) return output; + NoneQuantizer q(none); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + const auto& shape = convertShape(input_tensor.shape()); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + auto [out_tensor, out] = q.create_tensor(shape, otype); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); +template +std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + + auto grad_tensor = 
makeTransformerEngineTensor(grad_output); + + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); + auto act_input_tensor = makeTransformerEngineTensor(act_input); + + const auto& shape = convertShape(grad_tensor.shape()); + auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto dbias_tensor = makeTransformerEngineTensor(grad_bias); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + // Launch kernel + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); - return; + return {py::cast(grad_bias), dact}; } -at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; +std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); +std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - getDataPtr(scale_inv, scale_inv_offset)); - auto output_cu = makeTransformerEngineTensor(output); +std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - return output; +std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 6b54f2de69..30126651ce 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,6 +5,7 @@ ************************************************************************/ 
#include "../extensions.h" +#include "transformer_engine/transformer_engine.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 @@ -14,50 +15,6 @@ using namespace std::placeholders; namespace te = transformer_engine; -#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ - B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ - bias_type, pre_gelu_out, workspace) \ - A = A.contiguous(); \ - void *A_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(A_type)) { \ - assert(A_scale_inv.numel()); \ - A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ - } \ - auto A_ = makeTransformerEngineTensor( \ - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ - nullptr, nullptr, A_scale_inv_ptr); \ - B = B.contiguous(); \ - void *B_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(B_type)) { \ - assert(B_scale_inv.numel()); \ - B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ - } \ - auto B_ = makeTransformerEngineTensor( \ - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ - nullptr, nullptr, B_scale_inv_ptr); \ - void *D_amax_ptr = nullptr; \ - void *D_scale_ptr = nullptr; \ - if (te::is_fp8_dtype(D_type)) { \ - assert(D_amax.numel()); \ - D_amax_ptr = D_amax.data_ptr(); \ - assert(D_scale.numel()); \ - D_scale_ptr = D_scale.data_ptr(); \ - } \ - auto D_ = makeTransformerEngineTensor( \ - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ - D_amax_ptr, D_scale_ptr, nullptr); \ - auto bias_ = makeTransformerEngineTensor( \ - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ - const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ - ? std::vector{static_cast(pre_gelu_out.size(0))} \ - : std::vector{static_cast(pre_gelu_out.size(0)), \ - static_cast(pre_gelu_out.size(1))}; \ - auto pre_gelu_out_ = makeTransformerEngineTensor( \ - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ - auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ - te::DType::kByte); - /*************************************************************************************************** * CommOverlapHelper **************************************************************************************************/ @@ -185,145 +142,92 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, - helper->numnodes, tp_size, + int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), + helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, + helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary 
to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Bulk GEMM + COMM -** This function assumes the communication input is pre-copied to _ubuf -*/ -std::vector CommOverlap::bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - te::CommOverlapType comm_type, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, - grad, accumulate, use_split_accumulator, comm_type, rs_out_, - stream_main); - - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - if (comm_type == te::CommOverlapType::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = - (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - auto output_tensor = - torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); - - return {D, output_tensor}; -} // CommOverlap::bulk_overlap - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - te::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs +void CommOverlap::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + _ubuf_scale_inv_initialized = true; +} /* ** Helper function to copy input to _ubuf */ -void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { +void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? 
input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type == te::CommOverlapType::AG) { - if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + if (local_chunk) { + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); } + // Copy either row or columnwise data into the communication buffer's rowwise data + // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with + // the Userbuffers communicator. at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(), + input_tensor.numel() * input_tensor.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } -torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { +py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ?
_ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } /*************************************************************************************************** @@ -333,148 +237,85 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm, bool use_ce, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool aggregate) : te::CommOverlapP2PBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, + tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Split AllGather + AtomicGEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. 
-*/ -void CommOverlapP2P::atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); -} // atomic_gemm_overlap_ag + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm, aggregate) {} -/* -** Split AllGather + GEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. -*/ -void CommOverlapP2P::split_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - B_copy_, stream_main); -} // split_overlap_ag - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, rs_out_, stream_main); -} - -/* -** Split ReduceScatter + GEMM using P2P 
communication -*/ -void CommOverlapP2P::split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - rs_out_, stream_main); +void CommOverlapP2P::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]); } /* ** Copy input to _ubufs[0] */ -void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { +void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { + if (local_chunk) { // Copy input to the target ubuf chunk by rank offset - if (input.numel() != (int64_t)_ubufs[0].numel() || - input.element_size() != (int64_t)_ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } } -torch::Tensor CommOverlapP2P::get_ubuf_output(int 
comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr()); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 250c9993fb..b044c9f604 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -4,74 +4,272 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include + +#include +#include + +#include "../common.h" +#include "common.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - if (A.numel() == 0 || B.numel() == 0) { - if (D.numel() != 0 && !accumulate) D.zero_(); - if (bias.numel() != 0 && grad) { - if (B.numel() == 0) { - bias.zero_(); - } else { - bias.copy_(B.sum(0)); +namespace { + +void* get_data_ptr(MaybeTensor tensor) { + if (tensor.has_value()) return tensor->data_ptr(); + return nullptr; +} + +size_t get_size(MaybeTensor tensor, int dim) { + if (tensor.has_value()) return static_cast(tensor->size(dim)); + return 0; +} + +} // namespace + +namespace transformer_engine::pytorch { + +namespace detail { + +std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { + // Flatten outer dims to get 2D matrices + const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); + const size_t A1 = A_shape.data[A_shape.ndim - 1]; + const size_t B0 = product(B_shape, 0, B_shape.ndim - 1); + const size_t B1 = B_shape.data[B_shape.ndim - 1]; + + // Check matrix dims + NVTE_CHECK((transa ? A1 : A0) == (transb ? 
B0 : B1), "Invalid matrix dimensions for GEMM (A=(", + A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); + + // Construct output dims + std::vector ret; + if (transb) { + ret.emplace_back(B1); + } else { + // Unflatten B0 + for (size_t i = 0; i < B_shape.ndim - 1; ++i) { + ret.emplace_back(B_shape.data[i]); + } + } + if (transa) { + ret.emplace_back(A0); + } else { + ret.emplace_back(A1); + } + return ret; +} + +bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { + if (expected.size() != actual.ndim) return false; + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != actual.data[i]) return false; + } + return true; +} + +} // namespace detail + +std::pair createOutputTensor(const std::vector& shape, + DType dtype, py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore* comm_overlap, + std::optional comm_type, MaybeTensor extra_output, + bool bulk_overlap) { + // Input tensors + NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); + NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); + auto none = py::none(); + TensorWrapper A_tensor = makeTransformerEngineTensor(A, none); + TensorWrapper B_tensor = makeTransformerEngineTensor(B, none); + + // Check tensor dimensions + const auto& A_shape = A_tensor.shape(); + const auto& B_shape = B_tensor.shape(); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); + NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + + // Output tensor + TensorWrapper D_tensor; + if (D.is_none()) { + DType output_dtype = out_dtype ? 
*out_dtype : A_tensor.dtype(); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + } else { + D_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), + "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", + std::to_string(D_tensor.shape()), ")"); + if (out_dtype) { + NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); + } + } + + // Bias tensor + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + if (grad) { + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); + bias_tensor = makeTransformerEngineTensor(*bias_grad); + } else { + if (!bias->is_contiguous()) { + bias = bias->contiguous(); } + bias_tensor = makeTransformerEngineTensor(*bias); } - if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); - return; } - A = A.contiguous(); - B = B.contiguous(); + // Activation input tensor + MaybeTensor pre_gelu_out = std::nullopt; + DType gelu_type = bias_type; + if (gelu) { + if (!grad) { + auto dtype = GetATenDType(gelu_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + std::vector torch_shape; + for (auto v : D_shape) { + torch_shape.push_back(v); + } + pre_gelu_out = at::empty(torch_shape, opts); + } else { + if (gelu_in.has_value()) { + pre_gelu_out = *gelu_in; + } + } + } + const auto gelu_shape = gelu ? D_shape : std::vector{0}; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = - makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + auto te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor( - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + // Workspace auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + // Set an external SM Margin to all the GEMMs. 
+ // This comes in handy when DP is overlapped with GEMMs + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + auto main_stream = at::cuda::getCurrentCUDAStream(); + if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + if (comm_overlap) { + // Prepare extra output tensor + TensorWrapper extra_output_tensor; + if (extra_output.has_value()) { + extra_output_tensor = makeTransformerEngineTensor(*extra_output); + } else { + extra_output_tensor = + makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + } + + // Direct GEMM call to the correct overlap + if (bulk_overlap) { + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, comm_type.value(), extra_output_tensor, + main_stream); + } else if (comm_type.value() == CommOverlapType::AG) { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } else { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } + } else { + // Launch GEMM + nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, num_math_sms, main_stream); + } + } else { + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); + } + if (bias.has_value()) { + if (bias->numel() != 0 && grad) { + bias_grad->zero_(); + } + } + } + + // Pack outputs + std::vector out; + out.emplace_back(std::move(D)); + out.emplace_back(py::cast(bias_grad)); + if (gelu && !grad) { + out.emplace_back(py::cast(*pre_gelu_out)); + } else { + out.emplace_back(py::none()); + } + if (extra_output.has_value()) { + out.emplace_back(py::cast(extra_output)); + } else { + out.emplace_back(py::none()); + } + return out; } +} // namespace transformer_engine::pytorch + void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + 
at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + // TODO: Handle scaling modes + NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); auto te_B = makeTransformerEngineTensor( B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); + // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor( D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr); @@ -95,134 +293,108 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); } -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { using namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; + using namespace transformer_engine::pytorch; + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector, te_workspace_vector; + std::vector wrappers; + std::vector D_vectors; + + auto none = py::none(); + + std::vector single_output_begins; + std::vector single_output_ends; + int slicing_dim; + if (single_output && D == std::nullopt) { + NVTE_ERROR("not implemented, D should be allocated for single output case."); + } + + void* output_data_ptr; + if (single_output) { + output_data_ptr = (*D)[0].data_ptr(); + } + for (size_t i = 0; i < A.size(); i++) { - if (A[i].numel() == 0 || B[i].numel() == 0) { - if 
(D[i].numel() != 0 && !accumulate) D[i].zero_(); + auto te_A = makeTransformerEngineTensor(A[i], none); + auto te_B = makeTransformerEngineTensor(B[i], none); + + // if there is single output + at::Tensor out_tensor; + auto size_t_shape = + pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); + std::vector D_shape; + for (size_t t : size_t_shape) { + D_shape.push_back(t); + } + auto dtype = GetATenDType(D_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + if (single_output) { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + char* char_ptr = reinterpret_cast(output_data_ptr); + char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size(); + output_data_ptr = reinterpret_cast(char_ptr); + D_vectors.emplace_back(out_tensor); + } else { + if (D == std::nullopt) { + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + out_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(out_tensor); + } else { + out_tensor = (*D)[i]; + } + } + + if (te_A.numel() == 0 || te_B.numel() == 0) { + if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); if (bias[i].numel() != 0 && grad) { - if (B[i].numel() == 0) { - bias[i].zero_(); - } else { - bias[i].copy_(B[i].sum(0)); - } + bias[i].zero_(); } if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; } - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - NVTE_CHECK(D[i].is_contiguous(), "D[", i, "] must be contiguous."); - - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - D[i].data_ptr(), {static_cast(D[i].size(0)), static_cast(D[i].size(1))}, - D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + auto te_D = makeTransformerEngineTensor(out_tensor); + auto te_bias = makeTransformerEngineTensor(bias[i]); + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - } - for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); - } + ? std::vector{static_cast(te_pre_gelu_out.size(0))} + : std::vector{static_cast(te_pre_gelu_out.size(0)), + static_cast(te_pre_gelu_out.size(1))}; - // For now, we only have multi-stream cublas backend. 
- nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); -} + DType gelu_type = bias_type; + te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); - void* d_i_ptr = reinterpret_cast(D.data_ptr()); - for (size_t i = 0; i < A.size(); i++) { - if (m_splits[i] == 0) continue; - NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); - NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, - getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + te_A_vector.emplace_back(te_A.data()); + te_B_vector.emplace_back(te_B.data()); + te_D_vector.emplace_back(te_D.data()); + te_bias_vector.emplace_back(te_bias.data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - // Move the D pointer to the next split. 
- char* char_ptr = reinterpret_cast(d_i_ptr); - char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); - d_i_ptr = reinterpret_cast(char_ptr); + wrappers.emplace_back(std::move(te_A)); + wrappers.emplace_back(std::move(te_B)); + wrappers.emplace_back(std::move(te_D)); + wrappers.emplace_back(std::move(te_bias)); + wrappers.emplace_back(std::move(te_pre_gelu_out)); } for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte); + te_workspace_vector.emplace_back(wsp.data()); + wrappers.emplace_back(std::move(wsp)); } - // For now, we only have multi-stream cublas backend. - nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, + nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), + te_A_vector.size(), transa, transb, grad, + te_workspace_vector.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); + return bias; } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 2124b551fd..66ad03381c 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -6,10 +6,29 @@ #include "extensions.h" -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +namespace transformer_engine::pytorch { +std::pair createOutputTensor(const NVTEShape &shape, DType dtype, + py::handle quantizer) { + std::vector shape_vec; + for (int i = 0; i < shape.ndim; i++) { + size_t t = shape.data[i]; + shape_vec.push_back(t); + } + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape_vec, dtype); +} +std::pair createOutputTensor(std::vector &shape, DType dtype, + py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} +} // namespace transformer_engine::pytorch + +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &mu_ = mu.contiguous(); @@ -47,61 +66,57 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {dx, dgamma, dbeta}; + return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; } -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + DType out_dtype, const int 
sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - const auto &input_ = input.contiguous(); + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); -} - -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - const auto &bias_ = bias.contiguous(); + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + bias_tensor = makeTransformerEngineTensor(*bias); + } // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + size_t N = static_cast(input_tensor.size(0)); + size_t H = static_cast(input_tensor.size(1)); + std::vector size = {N, H}; // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input_); - auto gamma_cu = makeTransformerEngineTensor(weight_); - auto beta_cu = makeTransformerEngineTensor(bias_); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype); + } else { + if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + } + TensorWrapper mu_cu = makeTransformerEngineTensor(mu); + TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes transformer_engine::TensorWrapper workspace; - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), 
workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); @@ -111,66 +126,30 @@ std::vector layernorm_fwd_fp8_noalloc( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {ln_out, mu, rsigma}; -} - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset - -) { - // This is a specialized version of layernorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = - layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - const auto &input_ = input.contiguous(); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } - DType itype = GetTransformerEngineDType(input.scalar_type()); + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); + return {ln_out, py::cast(mu), py::cast(rsigma)}; } -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - // This is a specialized version of layernorm_fwd, optimized for inference, - // which only returns the normalized output. 
- std::vector out = - layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma); - return out[0]; -} - -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &rsigma_ = rsigma.contiguous(); @@ -204,57 +183,48 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {dx, dgamma}; + return {py::cast(dx), py::cast(dgamma)}; } -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); -} - -std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor ln_out, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + size_t N = static_cast(input_tensor.shape().data[0]); + size_t H = static_cast(input_tensor.shape().data[1]); // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); + std::vector size = {N, H}; + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype); + } else { 
+ if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + } auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes transformer_engine::TensorWrapper workspace; - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); @@ -264,55 +234,22 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // Launch kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + rsigma_cu.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - return {ln_out, rsigma}; -} + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { - // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = - rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); -} + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - // This is a specialized version of rmsnorm_fwd, optimized for inference, - // which only returns the normalized output. 
- std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); - return out[0]; + return {ln_out, py::none(), py::cast(rsigma)}; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index ca10e4d3c9..b9972af7cb 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -10,6 +10,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), "Number of input row list and padded row list must match."); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index f363e6e7ea..47282da504 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -11,6 +11,7 @@ std::tuple> moe_permute_fwd( at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + using namespace transformer_engine::pytorch; const int num_tokens = input.size(0); int num_cols = input.size(1); const int topK = indices.size(1); @@ -96,6 +97,7 @@ at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dty at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, int64_t topK) { + using namespace transformer_engine::pytorch; int num_cols = input.size(1); // Activations type @@ -129,6 +131,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob) { + using namespace transformer_engine::pytorch; const int topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e5d8744eef..442837d767 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,14 +4,131 @@ * See LICENSE for license information. 
************************************************************************/ +#include "pybind.h" + +#include +#include +#include +#include #include #include +#include + +#include "../common.h" #include "../extensions.h" +#include "common.h" + +namespace transformer_engine::pytorch { + +PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8QuantizerClass = nullptr; + +void init_float8_extension() { + if (Float8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); + Float8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); + Float8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + NVTE_CHECK(Float8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch Float8 extension."); +} + +void init_mxfp8_extension() { + if (MXFP8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); + MXFP8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); + MXFP8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); + MXFP8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + NVTE_CHECK(MXFP8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MXFP8 extension."); +} + +void init_extension() { + init_float8_extension(); + init_mxfp8_extension(); +} + +} // namespace transformer_engine::pytorch + #include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), + py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), + py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, + "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); + m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", + py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), + py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), + py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), + py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.", + py::call_guard()); + m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.", + py::call_guard()); + 
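The init_float8_extension() and init_mxfp8_extension() helpers above follow a common pybind11 pattern: lazily import the owning Python module once and cache its tensor and quantizer classes as raw PyTypeObject pointers, so that later C++ code can dispatch on the Python tensor type cheaply. The following is a minimal sketch of that pattern, not Transformer Engine code; the module path "my_package.my_tensor" and the class name "MyTensor" are hypothetical placeholders.

#include <pybind11/pybind11.h>

namespace py = pybind11;

static PyTypeObject *MyTensorClass = nullptr;  // cached Python type, filled on first use

void init_my_extension() {
  if (MyTensorClass != nullptr) return;  // already initialized, nothing to do
  auto mod = py::module_::import("my_package.my_tensor");  // hypothetical module
  MyTensorClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(mod.ptr(), "MyTensor"));
}

bool is_my_tensor(py::handle obj) {
  // PyObject_TypeCheck also matches subclasses, mirroring Python's isinstance().
  return MyTensorClass != nullptr && PyObject_TypeCheck(obj.ptr(), MyTensorClass);
}

Caching the PyTypeObject avoids re-importing the module and re-resolving the attribute on every call, which is why init_extension() only needs to run once per process and the per-class init functions simply return early afterwards.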
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, + "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer")); // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); @@ -42,116 +159,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Other granular functions - m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", - py::call_guard(), py::arg("input"), 
py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), - py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard()); - m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard()); - m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD", - py::call_guard()); - m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard()); - m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard()); - m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD", - py::call_guard()); - m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose", - py::call_guard()); - m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, - "Cast + Transpose with noop option", py::call_guard(), - py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, - "Fused Cast + Transpose + BGRAD + DGELU", py::call_guard(), - py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, - "Fused SwiGLU backward + FP8 cast + FP8 transpose", - py::call_guard(), py::arg("grad_output"), py::arg("input"), - py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, - "Fused Multi-tensor Cast + Transpose", py::call_guard()); 
- m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, - "Fused Multi-tensor Cast + Transpose with allocating output tensors", - py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), - py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard(), py::arg("input"), py::arg("scale"), - py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), - py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), - py::arg("scale_inv_offset") = 0); - m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think - m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); - m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed QKV", - py::call_guard()); - m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed QKV", - py::call_guard()); - m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed KV", - py::call_guard()); - m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed KV", - py::call_guard()); + m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), + py::arg("sm_margin"), py::arg("zero_centered_gamma")); + m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), + py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma")); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); + m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", + py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + + m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd", &fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V", - py::call_guard()); + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V", - py::call_guard()); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, - "Transpose with FP8 I/O with noop option.", py::call_guard()); - m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard()); - m.def("relu", &relu, "ReLU with FP8 output", py::call_guard()); - m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard()); - m.def("reglu", ®lu, "ReGLU with FP8 output", py::call_guard()); - m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard()); - m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard()); - m.def("srelu", &srelu, "Squared ReLU with FP8 
output", py::call_guard()); - m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard()); - m.def("drelu", &drelu, "Backward of ReLU", py::call_guard()); - m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard()); - m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard()); - m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard()); - m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard()); - m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard()); + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), + py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", py::call_guard()); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", @@ -233,30 +258,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Data structures - py::class_(m, "FP8TensorMeta") + py::class_(m, "FP8TensorMeta") .def(py::init<>()) - .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); + .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history); - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); + py::enum_(m, "FP8FwdTensors") + .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT); - py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + py::enum_(m, 
"FP8BwdTensors") + .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) @@ -265,54 +290,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard(), py::arg("world_group"), py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); - py::class_(m, "CommOverlap") + py::class_, transformer_engine::CommOverlapBase, + transformer_engine::CommOverlapCore>(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, bool, bool>(), + int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, - py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, - py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) - .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) - .def("split_overlap_rs", &CommOverlap::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlap::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) - .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) - .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); + py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlap::set_buffer_params); - py::class_(m, "CommOverlapP2P") + py::class_, + transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( + m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, - transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), + transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, + bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, - py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, - py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, - 
py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, - py::call_guard()) - .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) - .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, - py::call_guard()); + py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, + py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, + py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlapP2P::set_buffer_params); } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp new file mode 100644 index 0000000000..effeb8cb4d --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -0,0 +1,227 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "common.h" +#include "pybind.h" +#include "torch/torch.h" +#include "util.h" + +namespace transformer_engine::pytorch { + +constexpr size_t MXFP8_BLOCK_SIZE = 32; + +Quantizer::Quantizer(const py::handle& quantizer) { + if (quantizer.is_none()) { + this->rowwise_usage = true; + this->columnwise_usage = true; + this->internal = false; + } else { + this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); + this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); + this->internal = quantizer.attr("internal").cast(); + this->quantizer = quantizer; + } +} + +Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; +} + +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + at::TensorOptions opts; + opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + at::Tensor ret; + if (rowwise_data.has_value()) { + ret = std::move(*rowwise_data); + } else { + ret = at::empty(torch_shape, opts); + } + + TensorWrapper tensor; + tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); + return {std::move(tensor), py::cast(ret)}; +} + +void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts = 
opts.dtype(torch::kFloat32).device(torch::kCUDA); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + opts = opts.dtype(torch::kFloat32); + at::Tensor scale_inv = at::reciprocal(scale); + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + +MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); +} + +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + 
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair MXFP8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); + at::TensorOptions opts; + at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, + columnwise_scale_inv; // TODO(pgadzinski) - change + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + auto last_dim = static_cast(torch_shape.back()); + + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", torch_shape, ")"); + + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(torch_shape, opts); + } + auto sinv0 = roundup(numel / last_dim, 128); + auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + auto sinv1 = roundup(last_dim, 128); + columnwise_data = at::empty(torch_shape, opts); + columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); + tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index ec75a2a8c6..e8a31da99a 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -9,20 +9,22 @@ #include +#include "common/common.h" #include "extensions.h" -void fused_amax_and_scale_update_after_reduction( - const at::Tensor &amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { +void fused_amax_and_scale_update_after_reduction(const 
at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; size_t num_tensors = amax_histories.size(); std::vector t_amax_histories(num_tensors); std::vector t_scales(num_tensors); - std::vector t_scale_invs(num_tensors); std::vector te_amax_histories(num_tensors); std::vector te_scales(num_tensors); - std::vector te_scale_invs(num_tensors); for (size_t i = 0; i < num_tensors; i++) { t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); auto amax_sizes = amax_histories[i].sizes().vec(); @@ -36,18 +38,11 @@ void fused_amax_and_scale_update_after_reduction( t_scales[i].data.shape = scale_shape; t_scales[i].data.dtype = DType::kFloat32; - t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); - auto scale_inv_sizes = scale_invs[i].sizes().vec(); - std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; - t_scale_invs[i].data.shape = scale_inv_shape; - t_scale_invs[i].data.dtype = DType::kFloat32; - te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); te_scales[i] = reinterpret_cast(&t_scales[i]); - te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, - te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, + amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index 93be90c9f3..02f8fcbdf6 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -7,7 +7,7 @@ #include "extensions.h" at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -38,7 +38,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -65,7 +65,7 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r } at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -105,7 +105,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -132,7 +132,7 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor 
output_grad_, at::Tensor so } at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -159,7 +159,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -188,7 +188,7 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, } at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -220,7 +220,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp new file mode 100644 index 0000000000..316e6515bf --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" +#include "transformer_engine/transformer_engine.h" + +void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (input.scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + return; + } + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + NVTEBasicTensor scale_inv; + if (rowwise) { + scale_inv = input.get_rowwise_scale_inv(); + } else { + scale_inv = input.get_columnwise_scale_inv(); + } + + auto input_shape = nvte_shape_to_vector(input.shape()); + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + + // Allocate memory for swizzled output. + auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + std::vector scale_inv_shape_int; + for (size_t i = 0; i < scale_inv_shape.size(); ++i) { + scale_inv_shape_int.push_back(static_cast(scale_inv_shape[i])); + } + auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options); + void* scale_inv_dptr = scale_inv.data_ptr; + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. 
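The scale inverses swizzled in this file are the E8M0 block-scaling factors laid out in MXFP8Quantizer::create_tensor earlier in this patch: one factor per 32-element block, with both dimensions padded via roundup (to multiples of 128 and 4 for the rowwise direction). A small sketch of that shape arithmetic, assuming roundup() rounds up to the nearest multiple:

#include <cstddef>
#include <cstdio>
#include <utility>

// Assumed behaviour of the roundup() helper used in MXFP8Quantizer::create_tensor.
static std::size_t roundup(std::size_t x, std::size_t m) { return ((x + m - 1) / m) * m; }

// Rowwise E8M0 scale-inverse shape for a (rows, cols) MXFP8 tensor with
// 32-element scaling blocks, mirroring the sinv0/sinv1 computation in this patch.
std::pair<std::size_t, std::size_t> rowwise_scale_inv_shape(std::size_t rows, std::size_t cols) {
  constexpr std::size_t kBlock = 32;
  return {roundup(rows, 128), roundup(cols / kBlock, 4)};
}

int main() {
  const auto s = rowwise_scale_inv_shape(256, 1024);  // e.g. a 256x1024 activation
  std::printf("%zu x %zu\n", s.first, s.second);      // prints "256 x 32"
}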
+ transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, + scale_inv_shape); + } + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } +} + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input), + DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr, + getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + // Return immediately if tensor is empty + if (scale_inv.numel() == 0) { + return swizzled_scale_inv; + } + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv), + NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} diff --git 
a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 40f76c898c..37fbddcc18 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -4,434 +4,104 @@ * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" - -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input); - auto output_cast_cu = - makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - auto output_transpose_cu = - makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); +#include - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset, int amax_offset, int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(input); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - - // Launch kernel - nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), - output_transpose_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); +#include "ATen/core/TensorBody.h" +#include "extensions.h" - // Allocate output tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto grad_output_cast = - allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype) { + using namespace 
transformer_engine::pytorch; + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + std::vector py_output_objects_list; + std::vector tensor_wrappers; + auto none = py::none(); + + // create TE tensors from input + for (int i = 0; i < input_list.size(); i++) { + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + const NVTEShape input_shape = input_tensor.shape(); + + transformer_engine::TensorWrapper output_tensor; + + if (output_list == std::nullopt) { + std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + py::object o; + std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); + py_output_objects_list.push_back(o); + } else { + output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); + } + if (input_tensor.numel() == 0) continue; - // Return immediately if tensors are empty - if (M == 0 || N == 0) { - return {grad_bias.zero_(), grad_output_cast, grad_output_transpose}; + nvte_tensor_output_list.emplace_back(output_tensor.data()); + nvte_tensor_input_list.emplace_back(input_tensor.data()); + tensor_wrappers.emplace_back(std::move(input_tensor)); + tensor_wrappers.emplace_back(std::move(output_tensor)); } - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), 
grad_output.size(0), DType::kByte); - auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_transpose}; -} - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto dgelu = allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto dgelu_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); - auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - - return {grad_bias, dgelu, dgelu_transpose}; -} - -void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, - at::Tensor grad_input_transpose, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, int scale_offset, - int amax_offset, int scale_inv_offset) { 
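For reference while reading the SwiGLU-related changes (this fused dSwiGLU + cast + transpose binding is removed, and a quantizer-based dswiglu binding is registered earlier in pybind.cpp), here is a scalar sketch of SwiGLU and its gradient. It assumes the gated convention act(a) * b for an input whose inner dimension is split into halves (a, b); the real kernels additionally quantize and transpose the result.

#include <cmath>

static float sigmoid_ref(float z) { return 1.f / (1.f + std::exp(-z)); }
static float silu_ref(float z) { return z * sigmoid_ref(z); }
static float dsilu_ref(float z) {
  const float s = sigmoid_ref(z);
  return s * (1.f + z * (1.f - s));
}

// Forward: y = silu(a) * b. Backward: given dy, produce the gradients (da, db).
void dswiglu_scalar(float a, float b, float dy, float *da, float *db) {
  *da = dy * b * dsilu_ref(a);
  *db = dy * silu_ref(a);
}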
- using namespace transformer_engine; - - // Tensor dimensions - auto outer_dim = [](const at::Tensor& tensor) -> size_t { - return tensor.numel() / tensor.size(-1); - }; - const auto M = outer_dim(grad_output); - const auto N = static_cast(grad_output.size(-1)); - - // Check tensor dims - NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", - grad_output.dim()); - NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); - NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, - ", but found ", outer_dim(input)); - NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, - ", but found ", input.size(-1)); - NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", - grad_input.dim()); - NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", - M, ", but found ", outer_dim(grad_input)); - NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", - 2 * N, ", but found ", grad_input.size(-1)); - NVTE_CHECK(grad_input_transpose.dim() == 2, - "Expected grad input transpose tensor to have 2 dims, but found ", - grad_input_transpose.dim()); - NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, - "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", - grad_input_transpose.size(0)); - NVTE_CHECK(grad_input_transpose.size(1) == M, - "Expected grad input tensor to have outer dimension of ", M, ", but found ", - grad_input_transpose.size(1)); - - // Check tensor format - NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); - NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); - NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); - NVTE_CHECK(grad_input_transpose.is_contiguous(), - "Expected grad input transpose tensor to be contiguous"); - NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), - "Expected grad output tensor and input tensor to have same dtype"); - NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, - "Expected grad input tensor to be uint8 buffer"); - NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, - "Expected grad input transpose tensor to be uint8 buffer"); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto dy_cu = makeTransformerEngineTensor(grad_output); - auto x_cu = makeTransformerEngineTensor(input); - auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - - // Launch kernel - nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_multi_cast_transpose_base(std::vector input_list, - std::vector scale_dptr_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_dptr_list, - std::vector scale_inv_dptr_list, - transformer_engine::DType otype) { - using namespace transformer_engine; - - // Extract properties from 
PyTorch tensors - std::vector input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list; - std::vector> input_shape_list, cast_output_shape_list, - transposed_output_shape_list; - std::vector input_type_list, cast_output_type_list, - transposed_output_type_list; - auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + // Check tensor lists + NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), + "Number of input and output tensors must match"); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + const auto& tensor = nvte_tensor_output_list[i]; + if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { + with_fused_kernel = false; + break; } - }; - auto extract_tensor_props = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list, - std::vector& type_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + if (nvte_tensor_columnwise_data(tensor) == nullptr) { + with_fused_kernel = false; + break; } - type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); - }; - for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { - extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); - extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, - cast_output_shape_list); - cast_output_type_list.push_back(otype); - extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, - transposed_output_shape_list); - transposed_output_type_list.push_back(otype); } - // Construct TE tensors - std::vector nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - for (size_t i = 0; i < input_dptr_list.size(); ++i) { - if (input_dptr_list[i] == nullptr) continue; - nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i], - input_type_list[i], nullptr, nullptr, nullptr)); - nvte_cast_output_list.emplace_back( - make_tensor(cast_output_dptr_list[i], cast_output_shape_list[i], cast_output_type_list[i], - amax_dptr_list[i], scale_dptr_list[i], scale_inv_dptr_list[i])); - nvte_transposed_output_list.emplace_back( - make_tensor(transposed_output_dptr_list[i], transposed_output_shape_list[i], - transposed_output_type_list[i], amax_dptr_list[i], scale_dptr_list[i], - scale_inv_dptr_list[i])); - } - - // Check tensor lists - NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(), - "Number of input and T output tensors must match"); - // Launch TE kernel - nvte_multi_cast_transpose(nvte_input_list.size(), 
nvte_input_list.data(), - nvte_cast_output_list.data(), nvte_transposed_output_list.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype) { - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < scale_list.size(); ++i) { - scale_dptr_list.push_back(scale_list[i].data_ptr()); - amax_dptr_list.push_back(amax_list[i].data_ptr()); - scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr()); + if (with_fused_kernel) { + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); + } else { + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], + at::cuda::getCurrentCUDAStream()); + } } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, amax_dptr_list, scale_inv_dptr_list, - otype); + return py_output_objects_list; } -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype) { - using namespace transformer_engine; - - std::vector cast_output_list; - std::vector transposed_output_list; - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < input_list.size(); ++i) { - auto input_i = input_list[i]; - // construct cast output tensors - auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte); - cast_output_list.push_back(cast_output_i); - // construct transposed output tensors - auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte); - transposed_output_list.push_back(transposed_output_i); - // construct amax/scale/scale_inv dptr lists - amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i])); - scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i])); - scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i])); - } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, amax_dptr_list, scale_inv_dptr_list, - otype); +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output) { + using namespace transformer_engine::pytorch; - return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); -} + const auto dim = input.dim(); + NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; + if (input.dim() > 2) { + input = input.view({-1, input.size(dim - 1)}); + } size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); - auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); - if (M == 0 || N == 0) return output; - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - 
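The reworked fp8_transpose in this hunk accepts inputs with more than two dimensions by viewing them as a 2-D matrix before calling nvte_transpose: all leading dimensions collapse into M and the innermost dimension becomes N. A small sketch of that shape handling; the helper name is illustrative, not part of the extension:

#include <cstddef>
#include <utility>
#include <vector>

// Mirror of the input.view({-1, input.size(dim - 1)}) step in fp8_transpose:
// e.g. a {4, 8, 16} tensor is treated as a 32x16 matrix and transposed to 16x32.
std::pair<std::size_t, std::size_t> flatten_for_transpose(const std::vector<std::size_t> &shape) {
  std::size_t M = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) M *= shape[i];
  const std::size_t N = shape.empty() ? 0 : shape.back();
  return {M, N};
}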
-void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); + at::Tensor out; + if (output.has_value()) { + out = *output; + } else { + out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + } + if (M == 0 || N == 0) return out; auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - nvte_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); + return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp new file mode 100644 index 0000000000..d2607e4ed0 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "common.h" +#include "pybind.h" + +namespace transformer_engine::pytorch { +namespace detail { + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { + const at::Tensor &data = tensor.attr("_data").cast(); + const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast(); + float *scale_inv_dptr = reinterpret_cast(scale_inv.data_ptr()); + const DType dtype = tensor.attr("_fp8_dtype").cast(); + + const auto &shape = getTensorShape(data); + + bool transpose_valid = !tensor.attr("_transpose_invalid").cast(); + std::optional transpose = std::nullopt; + if (transpose_valid) { + transpose = tensor.attr("_transpose").cast>(); + } + + auto ret = TensorWrapper(quantizer->get_scaling_mode()); + + ret.set_rowwise_data(data.data_ptr(), dtype, shape); + if (transpose_valid && transpose != std::nullopt) { + const auto &transpose_shape = getTensorShape(*transpose); + ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape); + } + + const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type()); + const auto scale_inv_shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + quantizer->set_quantization_params(&ret); + return ret; +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0, + scale_inv_colwise_shape); + } + + quantizer->set_quantization_params(&ret); + return ret; +} + +} // namespace detail + +} // namespace transformer_engine::pytorch diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h b/transformer_engine/pytorch/csrc/extensions/util.cpp old mode 100755 new mode 100644 similarity index 53% rename from tests/pytorch/custom_ort_ops/custom_op_library.h rename to transformer_engine/pytorch/csrc/extensions/util.cpp index 747e6c5083..5f49383d11 --- a/tests/pytorch/custom_ort_ops/custom_op_library.h +++ b/transformer_engine/pytorch/csrc/extensions/util.cpp @@ -4,15 +4,11 @@ * See LICENSE for license information. 
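These converters read a fixed set of private attributes off the Python-side tensor objects when building a TensorWrapper. A minimal Python-side sketch of producing a Float8Tensor whose fields line up with what NVTETensorFromFloat8Tensor expects; the Float8Quantizer constructor arguments follow the usage shown later in this patch, and the shapes and dtype are placeholders:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    # Per-tensor scaling: one scale and one amax slot drive the quantization.
    scale = torch.ones(1, dtype=torch.float32, device="cuda")
    amax = torch.zeros(1, dtype=torch.float32, device="cuda")
    quantizer = Float8Quantizer(scale, amax, tex.DType.kFloat8E4M3)

    x_fp8 = quantizer(torch.randn(128, 128, device="cuda"))
    # x_fp8 now carries _data, _scale_inv, _fp8_dtype (and, if enabled, _transpose),
    # which the converter above maps onto the rowwise/columnwise fields of the
    # NVTE TensorWrapper; note that the same _scale_inv is reused for both directions.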
************************************************************************/ -#pragma once -#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "util.h" -#ifdef __cplusplus -extern "C" { -#endif +#include "ATen/cuda/CUDAContextLight.h" -ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); - -#ifdef __cplusplus +bool non_tn_fp8_gemm_supported() { + int major = at::cuda::getCurrentDeviceProperties()->major; + return major >= 10; } -#endif diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h new file mode 100644 index 0000000000..0679528b94 --- /dev/null +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -0,0 +1,73 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#define PYBIND11_DETAILED_ERROR_MESSAGES // TODO remove + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#include +#include +#include +#include + +#include "common.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +extern PyTypeObject *Float8TensorPythonClass; +extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject *MXFP8TensorPythonClass; +extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8QuantizerClass; + +void init_extension(); + +void init_float8_extension(); + +void init_mxfp8_extension(); + +namespace detail { + +inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool IsFloat8Tensor(PyObject *obj) { + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; +} + +inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } + +inline bool IsMXFP8Tensor(PyObject *obj) { + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; +} + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); + +template +std::unique_ptr CreateQuantizer(const py::handle quantizer) { + return std::make_unique(quantizer); +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); + +std::unique_ptr CreateMXFP8Params(const py::handle params); + +inline bool IsFloatingPointType(at::ScalarType type) { + return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; +} + +constexpr std::array custom_types_converters = { + std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + CreateQuantizer)}; + +} // namespace detail + +} // namespace transformer_engine::pytorch + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp deleted file mode 100644 index 203b575a0d..0000000000 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ /dev/null @@ -1,414 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
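For reference, the condition that non_tn_fp8_gemm_supported() checks on the C++ side can be mirrored from Python with stock PyTorch. This is only an illustration of the condition (SM major version 10, i.e. Blackwell, or newer), not a binding exported by this patch:

    import torch

    def non_tn_fp8_gemm_supported() -> bool:
        # Same condition as the C++ helper above: compute capability major >= 10.
        return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 10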
- * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common/util/cuda_runtime.h" -#include "common/util/system.h" -#include "extensions.h" - -namespace { -transformer_engine::DType reverse_map_dtype(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} -} // namespace - -at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = - cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); - return output; -} - -at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &scale, - at::Tensor output, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, - fp8_tensor); - return output; -} - -at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv, - int64_t fp8_tensor, int64_t itype, int64_t otype) { - transformer_engine::DType itype_arg = reverse_map_dtype(itype); - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); - return output; -} - -at::Tensor gelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = gelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor relu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = relu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor reglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = reglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor geglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - 
if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = geglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor swiglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = swiglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor qgelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = qgelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor srelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = srelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - int64_t A_type, int64_t transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, int64_t B_type, int64_t transb, at::Tensor D, - at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, at::Tensor bias, - int64_t bias_type, at::Tensor pre_gelu_out, int64_t grad, - at::Tensor workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - te_gemm(A, A_scale_inverse, A_type_arg, transa_arg, B, B_scale_inverse, B_type_arg, transb_arg, D, - D_scale, D_type_arg, D_amax, bias, bias_type_arg, pre_gelu_out, grad_arg, workspace, - workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -std::vector te_grouped_gemm_ts( - std::vector A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type, - int64_t transa, std::vector B, at::Tensor B_scale_inverse, int64_t B_offset, - int64_t B_type, int64_t transb, std::vector D, int64_t D_offset, at::Tensor D_scale, - int64_t D_type, at::Tensor D_amax, std::vector bias, int64_t bias_type, - std::vector pre_gelu_out, int64_t grad, std::vector workspace, - int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, - B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias, - bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor te_grouped_gemm_single_output_ts( - std::vector A, std::vector A_scale_inverse, int64_t A_offset, - int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, - int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, - int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, - std::vector bias, int64_t bias_type, std::vector pre_gelu_out, - int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. 
- // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, - B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, - D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, - pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, - int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - at::Tensor output = layernorm_fwd_fp8_inf(input, weight, bias, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, const int64_t sm_margin, - const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = - layernorm_fwd_inf(input, weight, bias, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - at::Tensor output = rmsnorm_fwd_fp8_inf(input, weight, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - const int64_t sm_margin, const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = rmsnorm_fwd_inf(input, weight, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -TORCH_LIBRARY(tex_ts, m) { - m.def("cast_to_fp8_ts", &cast_to_fp8_ts); - m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts); - m.def("cast_from_fp8_ts", &cast_from_fp8_ts); - m.def("gelu_ts", &gelu_ts); - m.def("relu_ts", &relu_ts); - m.def("geglu_ts", &geglu_ts); - m.def("reglu_ts", ®lu_ts); - m.def("swiglu_ts", &swiglu_ts); - m.def("qgelu_ts", &qgelu_ts); - m.def("srelu_ts", &srelu_ts); - m.def("te_gemm_ts", &te_gemm_ts); - m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); - m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); - m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); - m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); - m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); - m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts); -} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h new file mode 100644 index 0000000000..cbdf0833ed --- /dev/null +++ 
b/transformer_engine/pytorch/csrc/util.h @@ -0,0 +1,12 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ + +bool non_tn_fp8_gemm_supported(); + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e6d63ab9e4..aa5964bc4a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -7,6 +7,7 @@ from contextlib import contextmanager, AbstractContextManager, ContextDecorator from functools import lru_cache +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -20,7 +21,11 @@ from .utils import safely_set_viewless_tensor_data from .constants import dist_group_type from .fp8 import FP8GlobalStateManager -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Quantizer, Float8Tensor +from .tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor._internal.float8_tensor_base import Float8TensorBase +from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -815,7 +820,7 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) # Bypass the function if we are using only 1 GPU. 
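The import hunk above replaces the old flat float8_tensor module with the new tensor subpackage. Spelled out as absolute imports (only the symbols visible in this hunk; nothing else about these modules is implied):

    # Pre-2.0:
    #   from transformer_engine.pytorch.float8_tensor import Float8Tensor
    # After this patch, distributed.py pulls quantizers and tensor classes from
    # the tensor subpackage instead:
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
    from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor, Quantizer
    from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
    from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase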
@@ -836,57 +841,232 @@ def reduce_scatter_along_first_dim( return output, handle +def _all_gather_fp8( + input_: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: Optional[Float8Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: + """All-gather FP8 tensor along first dimension.""" + world_size = get_distributed_world_size(process_group) + + # Output tensor dims + if out_shape is None: + out_shape = list(input_.size()) + out_shape[0] *= world_size + + # Quantize input tensor if needed + if not isinstance(input_, Float8TensorBase): + assert isinstance(quantizer, Float8Quantizer) + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(columnwise=False) + input_ = quantizer(input_) + quantizer.set_usage(columnwise=init_columnwise_usage) + + # Construct output tensor + out: Float8TensorBase + if isinstance(quantizer, Float8Quantizer): + dtype = torch.float32 + device = "cuda" + if isinstance(input_, Float8Tensor): + dtype = input_.dtype + device = input_.device + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + elif isinstance(input, Float8Tensor): + out = input_.make_like(input_, shape=out_shape) + out._data = torch.empty_like( + out_shape, + dtype=torch.uint8, + device=input_.device, + ) + out._transpose = None + out._transpose_invalid = True + else: + raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + out._scale_inv = input_._scale_inv + + # Perform communication + handle = torch.distributed.all_gather_into_tensor( + out._data, + input_._data.contiguous(), + group=process_group, + async_op=async_op, + ) + + # Make sure FP8 transpose is populated if needed + if out._transpose is not None: + if handle is not None: + handle.wait() + handle = None + if not isinstance(out, Float8Tensor): + raise RuntimeError("FP8TensorBase does not support FP8 transpose yet") + out._create_transpose() + + return out, handle + + +def _all_gather_mxfp8( + input_: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: MXFP8Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: + """All-gather MXFP8 tensor along first dimension.""" + + # Tensor dims + world_size = get_distributed_world_size(process_group) + in_shape = list(input_.size()) + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # Gather MXFP8 data for row-wise usage + if quantizer.rowwise_usage and not quantizer.columnwise_usage: + + # Cast input tensor to MXFP8 if needed + if not isinstance(input_, MXFP8TensorBase): + input_ = quantizer(input_) + + # Construct MXFP8 output tensor + dtype = torch.float32 + device = "cuda" + if isinstance(input_, MXFP8Tensor): + dtype = input_.dtype + device = input_.device + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = input_._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as coalescing_manager: + 
torch.distributed.all_gather_into_tensor( + out._rowwise_data, + input_._rowwise_data, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = coalescing_manager if async_op else None + return out, handle + + # Gather in high precision and quantize for column-wise usage + if isinstance(input_, QuantizedTensor): + input_ = input_.dequantize(dtype=torch.bfloat16) + out = torch.empty( + out_shape, + dtype=input_.dtype, + device=input_.device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, input_, group=process_group) + out = quantizer(out) + return out, None + + def gather_along_first_dim( input_: torch.Tensor, process_group: dist_group_type, async_op: bool = False, -) -> tuple[torch.Tensor, Any]: + quantizer: Optional[Quantizer] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-gather tensors and concatenate along first dimension.""" # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: + if quantizer is not None and not isinstance(input_, QuantizedTensor): + input_ = quantizer(input_) return input_, None - # Allocate output tensor - output_shape = list(input_.size()) - output_shape[0] *= world_size - if isinstance(input_, Float8Tensor): - output = Float8Tensor.make_like( + # Output tensor dims + out_shape = list(input_.size()) + out_shape[0] *= world_size + + # FP8 case + if isinstance(input_, Float8TensorBase) or isinstance(quantizer, Float8Quantizer): + return _all_gather_fp8( input_, - data=torch.empty( - output_shape, - dtype=torch.uint8, - device=input_.device, - ), + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, ) - src = input_._data.contiguous() - dst = output._data - else: - output = torch.empty( - output_shape, + + # MXFP8 case + if isinstance(input_, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + assert isinstance(quantizer, MXFP8Quantizer) + return _all_gather_mxfp8( + input_, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + + # High-precision communication for quantized tensors + if quantizer is not None: + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." + ) + if isinstance(input_, QuantizedTensor): + input_ = input_.dequantize() + out = torch.empty( + out_shape, dtype=input_.dtype, device=input_.device, memory_format=torch.contiguous_format, ) - src = input_.contiguous() - dst = output + torch.distributed.all_gather_into_tensor(out, input_, group=process_group) + out = quantizer(out) + return out, None - # Launch all-gather + # Dequantize quantized tensor if not supported + if isinstance(input_, QuantizedTensor): + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." 
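With the changes above, gather_along_first_dim() dispatches on the quantizer: a Float8Quantizer takes the quantize-then-gather path, an MXFP8Quantizer gathers row-wise data and scales (or falls back to a high-precision gather for column-wise usage), and anything else degrades to a plain gather. A call-site sketch; the shard, process group, and quantizer are placeholders supplied by the caller:

    from transformer_engine.pytorch.distributed import gather_along_first_dim

    # x: local shard, e.g. [seq/tp, batch, hidden]; quantizer: Float8Quantizer,
    # MXFP8Quantizer, or None for a plain high-precision gather.
    out, handle = gather_along_first_dim(x, tp_group, async_op=True, quantizer=quantizer)
    if handle is not None:
        handle.wait()  # async work handle (or the coalescing manager in the MXFP8 path)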
+ ) + input_ = input_.dequantize() + + # Communication for plain PyTorch tensors + out = torch.empty( + out_shape, + dtype=input_.dtype, + device=input_.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - dst, - src, + out, + input_.contiguous(), group=process_group, async_op=async_op, ) - return output, handle + return out, handle def allreduce( input_: torch.Tensor, tp_group: Optional[dist_group_type] = None, async_op: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. @@ -907,12 +1087,13 @@ def _fsdp_scatter_tensors( if fsdp_group is not None: for t in tensors: if isinstance(t, torch.Tensor): - target = t._data if isinstance(t, Float8Tensor) else t - shapes.append(target.data.shape) - safely_set_viewless_tensor_data( - target, - split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + shapes.append(target.data.shape) + safely_set_viewless_tensor_data( + target, + split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), + ) else: shapes.append(None) return shapes @@ -928,10 +1109,11 @@ def _fsdp_gather_tensors( for s, t in zip(shapes, tensors): if isinstance(t, torch.Tensor): assert s is not None, "Internal TE error." - target = t._data if isinstance(t, Float8Tensor) else t - safely_set_viewless_tensor_data( - target, gather_split_1d_tensor(target.data, fsdp_group).view(s) - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + safely_set_viewless_tensor_data( + target, gather_split_1d_tensor(target.data, fsdp_group).view(s) + ) def _is_te_module(module): diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py deleted file mode 100755 index 79b839edfd..0000000000 --- a/transformer_engine/pytorch/export.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Export utilities for TransformerEngine""" -from contextlib import contextmanager - -_IN_ONNX_EXPORT_MODE = False - - -@contextmanager -def onnx_export( - enabled: bool = False, -) -> None: - """ - Context manager for exporting to ONNX. - - .. 
code-block:: python - - with onnx_export(enabled=True): - torch.onnx.export(model) - - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable export - """ - - global _IN_ONNX_EXPORT_MODE - onnx_export_state = _IN_ONNX_EXPORT_MODE - try: - _IN_ONNX_EXPORT_MODE = enabled - yield - finally: - _IN_ONNX_EXPORT_MODE = onnx_export_state - - -def is_in_onnx_export_mode() -> bool: - """Returns True if onnx export mode is enabled, False otherwise.""" - return _IN_ONNX_EXPORT_MODE diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8554cc7443..a771e3bb75 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,6 +4,6 @@ """Tensor class with FP8 data""" -from .tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b1b6165777..254bcf12e1 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -3,6 +3,9 @@ # See LICENSE for license information. """FP8 utilities for TransformerEngine""" +from __future__ import annotations + +import abc import os from contextlib import contextmanager from collections import deque @@ -10,7 +13,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling from .constants import dist_group_type from .utils import get_device_compute_capability @@ -33,12 +36,21 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -def get_default_fp8_recipe() -> DelayedScaling: +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for MXFP8 execution." 
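check_mxfp8_support() gates the MXFP8 recipe on Blackwell-class GPUs (compute capability 10.0 or higher), and get_default_fp8_recipe() below uses the same cutoff. A usage sketch that picks a recipe explicitly and runs a layer under it; the layer and input shapes are arbitrary placeholders:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

    # Same capability cutoff as check_mxfp8_support() above.
    use_mxfp8 = torch.cuda.get_device_capability() >= (10, 0)
    recipe = MXFP8BlockScaling() if use_mxfp8 else DelayedScaling()

    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda")
    inp = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(inp)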
+ + +def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return MXFP8BlockScaling() return DelayedScaling() -def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype: +def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -47,7 +59,7 @@ def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) - return torch.float8_e5m2fn -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -56,7 +68,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t return tex.DType.kFloat8E5M2 -def get_fp8_max(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -81,7 +93,6 @@ class FP8GlobalStateManager: global_amax_buffer = {} global_amax_history_buffer = {} global_scale_buffer = {} - global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" @@ -89,6 +100,8 @@ class FP8GlobalStateManager: autocast_to_fp8_params = {} fp8_param_to_autocast = {} skip_fp8_weight_update_tensor = None + mxfp8_available = None + reason_for_no_mxfp8 = "" @classmethod def reset(cls) -> None: @@ -104,12 +117,15 @@ def reset(cls) -> None: cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} - cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} + cls.autocast_to_fp8_params = {} + cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None + cls.mxfp8_available = None + cls.reason_for_no_mxfp8 = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -130,6 +146,13 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 + @classmethod + def is_mxfp8_available(cls) -> Tuple[bool, str]: + """Return if MXFP8/current scaling support is available.""" + if cls.mxfp8_available is None: + cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() + return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -154,7 +177,7 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: bool, - fp8_recipe: DelayedScaling, + fp8_recipe: Recipe, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" @@ -188,6 +211,9 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ + if fp8_meta["recipe"].mxfp8(): + return + # Every module must call this function exactly once since # the amax tensors are static. 
Ensures that compatibility # with non-graphed modules is maintained. @@ -208,14 +234,12 @@ def add_fp8_tensors_to_global_buffer( cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @@ -249,7 +273,7 @@ def is_first_fp8_module(cls): return tmp @classmethod - def get_fp8_recipe(cls) -> DelayedScaling: + def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" if cls.FP8_RECIPE is not None: return cls.FP8_RECIPE @@ -261,7 +285,7 @@ def get_fp8_group(cls) -> Union[dist_group_type, None]: return cls.FP8_DISTRIBUTED_GROUP @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]: + def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( cls.FP8_ENABLED, @@ -335,7 +359,6 @@ def reduce_and_update_fp8_tensors( contiguous_amax, cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -343,19 +366,18 @@ def reduce_and_update_fp8_tensors( else: split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - for amax_history, scale, scale_inv in zip( + for amax_history, scale in zip( cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], ): _amax_and_scale_update( - amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe + amax_history, scale, get_fp8_max(recipe, forward), recipe ) @classmethod def get_unique_autocast_key( cls, - recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, group: Optional[dist_group_type] = None, ): """ @@ -369,7 +391,7 @@ def fp8_autocast_enter( cls, enabled: bool = False, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -392,6 +414,9 @@ def fp8_autocast_enter( if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() assert fp8_available, reason_for_no_fp8 + if isinstance(fp8_recipe, MXFP8BlockScaling): + mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() + assert mxfp8_available, reason_for_no_mxfp8 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -408,12 +433,15 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - """Copy the scaling factors and amaxes for recompute forward phase to ensure both forward steps are numerically same. 
""" + + if fp8_meta["recipe"].mxfp8(): + return + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" to_copy = [ fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), ] if buffer_position_key in fp8_meta: @@ -432,10 +460,12 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ + if fp8_meta["recipe"].mxfp8(): + return + # Store updated amaxes and scales from phase 1 post forward. fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -444,18 +474,20 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non # Replace amaxes and scales with stashed values for phase 2 forward fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" + + if fp8_meta["recipe"].mxfp8(): + return + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"]) @contextmanager -def fp8_model_init(enabled: bool = True) -> None: +def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None: """ Context manager for FP8 initialization of parameters. @@ -479,22 +511,27 @@ def fp8_model_init(enabled: bool = True) -> None: precision copies of weights are already present in the optimizer. * inference, where only the FP8 copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. This functionality is *EXPERIMENTAL*. """ _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe @contextmanager def fp8_autocast( enabled: bool = True, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -529,7 +566,7 @@ def fp8_autocast( data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` + fp8_recipe: recipe.Recipe, default = `None` recipe used for FP8 training. 
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the fp8 tensors @@ -639,7 +676,6 @@ def _compute_scaling_factor( def _amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor, - scale_inv: torch.Tensor, fp8_max: float, recipe: DelayedScaling, ) -> None: @@ -650,7 +686,6 @@ def _amax_and_scale_update( ) new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) scale.copy_(new_scale) - scale_inv.copy_(1.0 / new_scale) amax_history.copy_(new_amax_history) @@ -662,3 +697,152 @@ def split_and_copy( """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) + + +class RecipeState(abc.ABC): + """Configuration and state for a quantization recipe. + + This is a builder class for quantizers, which are in turn builder + classes for quantized tensors. + + This class may pack together the state for multiple quantizers, + which is helpful for applying fused kernels with less overhead. + + """ + + @staticmethod + def create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> RecipeState: + """Factory method to create the state for a quantization recipe + + Parameters + ---------- + recipe: Recipe + Quantization recipe. + mode: {"forward", "backward"} + Training stage where quantization will be performed. + num_quantizers: int, default = 1 + Number of quantizers to create state for. + device: torch.device, default = default CUDA device + Device for quantized tensors. + + Returns + ------- + RecipeState: + Quantization recipe state. + + """ + + cls = None + if recipe.delayed(): + cls = DelayedScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState + else: + raise ValueError("{recipe.__class__.__name__} is not supported") + return cls( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + @abc.abstractmethod + def make_quantizers(self) -> list: + """Convert recipe state to quantizers. + + Quantizers are builder classes for quantized tensors. They are + typically used to convert a high-precision tensor (e.g. in + FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + +class DelayedScalingRecipeState(RecipeState): + """State for FP8 quantization with per-tensor delayed scaling. + + Delayed scaling recipe requires a scaling factor (applied when + casting to FP8) and a history of max-abs values ("amax") from + recent FP8 casts for updating the scaling factor. The scale update + is handled externally by `FP8GlobalStateManager`. + + """ + + recipe: DelayedScaling + mode: str + dtype: tex.DType + scale: torch.Tensor + amax_history: torch.Tensor + + def __init__( + self, + recipe: DelayedScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_quantizers, + dtype=torch.float32, + device=device, + ) + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. 
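RecipeState is the new indirection between a recipe and the quantizers a module actually uses; DelayedScalingRecipeState packs the scale and amax history for several quantizers into shared buffers so they can be reduced together. A small sketch of the factory path (whether RecipeState is re-exported anywhere beyond transformer_engine.pytorch.fp8 is not shown here):

    import torch
    from transformer_engine.common.recipe import DelayedScaling
    from transformer_engine.pytorch.fp8 import RecipeState

    # One packed state for three quantizers (e.g. input/weight/output of a GEMM).
    state = RecipeState.create(
        DelayedScaling(),
        mode="forward",
        num_quantizers=3,
        device=torch.device("cuda"),
    )
    q_input, q_weight, q_output = state.make_quantizers()
    # Each Float8Quantizer views one slot of state.scale / state.amax_history.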
+ from .tensor.float8_tensor import Float8Quantizer + + return [ + Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) + for i in range(self.num_quantizers) + ] + + +class MXFP8BlockScalingRecipeState(RecipeState): + """Configuration for MXFP8 quantization. + + MXFP8 quantization does not require state. + + """ + + recipe: MXFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MXFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.mxfp8_tensor import MXFP8Quantizer + + return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 3853e70048..83b316aad4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -11,7 +11,7 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.common.recipe import DelayedScaling, Recipe from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, @@ -556,12 +556,16 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: DelayedScaling, -) -> List[Any]: + fp8_recipe: Recipe, +) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ + + if not isinstance(fp8_recipe, DelayedScaling): + return None + fp8_tensors = [] for module in modules: for m in module.modules(): @@ -579,9 +583,13 @@ def save_fp8_tensors( def restore_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_tensors: List[Any], + fp8_tensors: Optional[List[Any]], ) -> None: """Restore FP8 tensors.""" + + if fp8_tensors is None: + return + for module in modules: for m in module.modules(): module_tensors = fp8_tensors.pop(0) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 2be291e4f9..cd18808465 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,29 +4,27 @@ """Internal function used by multiple modules.""" -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +import os +from typing import Any, List, Optional, Tuple, Union, Callable from dataclasses import dataclass +from functools import reduce +from operator import mul as multiply_op import torch from .. 
import cpp_extensions as tex -from ..export import is_in_onnx_export_mode -from ..fp8 import get_fp8_te_dtype +from ..constants import TE_DType from ..utils import get_default_init_method +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) -def _get_normalization_func( - normalization: str, fp8_output: bool, is_grad_enabled: bool, forward: bool -): + +def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { - ("LayerNorm", True, True): tex.layernorm_fwd_fp8, - ("LayerNorm", True, False): tex.layernorm_fwd_fp8_inf, - ("LayerNorm", False, True): tex.layernorm_fwd_noalloc, - ("LayerNorm", False, False): tex.layernorm_fwd_inf, - ("RMSNorm", True, True): tex.rmsnorm_fwd_fp8, - ("RMSNorm", True, False): tex.rmsnorm_fwd_fp8_inf, - ("RMSNorm", False, True): tex.rmsnorm_fwd_noalloc, - ("RMSNorm", False, False): tex.rmsnorm_fwd_inf, + "LayerNorm": tex.layernorm_fwd, + "RMSNorm": tex.rmsnorm_fwd, } bwd_normalization_funcs = { "LayerNorm": tex.layernorm_bwd, @@ -34,81 +32,79 @@ def _get_normalization_func( } if forward: - return fwd_normalization_funcs[(normalization, fp8_output, is_grad_enabled)] - assert not fp8_output, "FP8 output is not supported in backward normalization!" - assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" + return fwd_normalization_funcs[normalization] return bwd_normalization_funcs[normalization] -def _apply_normalization( +def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor: + """Reorder FP8 transposes after Userbuffers gather. + + The all-gather is performed in-place in the Float8Tensor's + row-wise data, and afterwards we need to do a transpose to get the + correct ordering. This misuses data fields in Float8Tensor and + should be considered an evil hack. It would be best to move + transpose logic into CommOverlap::get_buffer. 
+ + Responsibility for fixing: adener, tmoon + + """ + assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor" + assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1" + assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data" + assert ( + fp8_tensor._data.shape[0] % tp_size == 0 + ), "Leading dimension of data is not divisble by TP size" + + data = fp8_tensor._data + batched_size = reduce(multiply_op, data.shape[1:]) + interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size] + transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size] + fp8_tensor._transpose = ( + data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape) + ) + + fp8_tensor._transpose_invalid = False + fp8_tensor._data = None + + return fp8_tensor + + +def apply_normalization( inputmat: torch.Tensor, ln_out: torch.Tensor, ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], eps: float, - fp8_out: bool, - fp8_meta: Dict[str, Any], + output_quantizer, + output_dtype, normalization: str, fwd_ln_sm_margin: int, zero_centered_gamma: bool, - is_grad_enabled: bool, - fp8_scale: Optional[torch.Tensor] = None, - fp8_amax: Optional[torch.Tensor] = None, - fp8_scale_inv: Optional[torch.Tensor] = None, ): - normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) + """Apply normalization to input.""" + normalization_func = _get_normalization_func(normalization, True) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) - if fp8_out: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - - if is_grad_enabled: - output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out" - output_kwarg = {output_key: ln_out} - output = normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - **output_kwarg, - ) - else: - return ( - normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - ), - None, - None, - ) - else: - if is_grad_enabled: - output = normalization_func(*inputs, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma) - else: - return ( - normalization_func(*inputs, eps, fwd_ln_sm_margin, zero_centered_gamma), - None, - None, - ) - if normalization == "RMSNorm": - output = (ln_out, None, output[1]) - elif normalization == "LayerNorm": - output = (ln_out, output[1], output[2]) - return output + + split_mxfp8_cast = False + if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer): + split_mxfp8_cast = True + + output = normalization_func( + *inputs, + eps, + None if split_mxfp8_cast else ln_out, + None if split_mxfp8_cast else output_quantizer, + TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + + return ( + (output_quantizer.quantize(output[0], out=ln_out), *output[1:]) + if split_mxfp8_cast + else output + ) class _NoopCatFunc(torch.autograd.Function): @@ -202,7 +198,7 @@ def backward( return None, *grad_inputs -def _noop_cat( +def noop_cat( tensors: List[torch.Tensor], dim: int = 0, ) -> torch.Tensor: @@ -217,8 +213,6 @@ def _noop_cat( raise ValueError("Attempted to concatenate 0 tensors") 
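The reordering performed by _fix_gathered_fp8_transpose above can be hard to picture; the following self-contained snippet reproduces it with plain tensors (sizes are arbitrary, and the uint8 data stands in for the gathered FP8 bytes):

    import torch

    tp_size, rows_per_rank, cols = 2, 4, 8
    # After the in-place all-gather, rows from each TP rank sit in contiguous
    # blocks along dim 0 of the row-wise data.
    data = torch.arange(tp_size * rows_per_rank * cols, dtype=torch.uint8).view(-1, cols)

    # Same view/transpose/view sequence as in the helper: row i of the result
    # concatenates row i from every rank, which is the layout expected of the
    # column-wise ("transpose") buffer.
    interleaved = data.view(tp_size, rows_per_rank, cols)
    fixed = interleaved.transpose(0, 1).contiguous().view(rows_per_rank, cols * tp_size)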
if len(tensors) == 1: return tensors[0] - if is_in_onnx_export_mode(): - return torch.cat(tensors, dim=dim) return _NoopCatFunc.apply(dim, *tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8de0b733a9..d0f9525135 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -18,12 +18,14 @@ import torch.nn.functional as F import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe + from ._common import _ParameterInitMeta -from ..export import is_in_onnx_export_mode from ..fp8 import ( - get_default_fp8_recipe, - get_fp8_te_dtype, + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, FP8GlobalStateManager, + RecipeState, ) from ..distributed import ( gather_along_first_dim, @@ -31,13 +33,10 @@ in_fp8_activation_recompute_phase, _fsdp_gather_tensors, ) -from ..cpp_extensions import ( - fp8_cast_transpose_fused, - fp8_cast_transpose_bgrad_fused, - cast_to_fp8, -) from ..constants import dist_group_type -from ..float8_tensor import Float8Tensor +from ..tensor import QuantizedTensor, Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -48,6 +47,7 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 +_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] @@ -295,34 +295,43 @@ def get_method(name): raise KeyError(f"Given layer name {name} does not exist.") def get_default_config(name): + global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY method = get_method(name) is_reduce_scatter = name in layers_reduce_scatter_overlap + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() default_cfg = { "method": method, "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": False, - "num_splits": 4 if method == "pipeline" else tp_size, + "set_sm_margin": not method == "ring_exchange", + "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, "use_ce": True, "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + "pipeline_rs_overlap_first_gemm": False, } return default_cfg def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, cga_size: int = 2, - set_sm_margin: int = 0, + set_sm_margin: bool = False, num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + aggregate: bool = False, + atomic_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, + comm_priority: int = 0, + gemm_priority: int = 0, + pipeline_rs_overlap_first_gemm: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -373,6 +382,8 @@ def add_ub( atomic_gemm=atomic_gemm, use_ce=use_ce, aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, ) else: ub_obj = tex.CommOverlap( @@ -386,6 +397,9 @@ def add_ub( num_comm_sm=num_sm, set_sm_margin=set_sm_margin, atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj @@ -439,8 +453,8 @@ def __init__(self) -> None: self.fp8_meta = {} 
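# A minimal sketch of overriding the new per-layer Userbuffers options added
# above (comm_priority, gemm_priority, pipeline_rs_overlap_first_gemm). The
# dict keys mirror get_default_config(); the initialize_ub() signature and the
# "qkv_fprop" layer name are assumptions for this example.
import torch
from transformer_engine.pytorch.module.base import initialize_ub

ub_cfgs = {
    "qkv_fprop": {
        "method": "ring_exchange",
        "num_splits": 8,                          # tp_size splits for ring_exchange
        "gemm_priority": 0,                       # defaults come from tex.get_stream_priority_range()
        "comm_priority": -1,
        "pipeline_rs_overlap_first_gemm": False,
    },
}
initialize_ub([4096, 8192], tp_size=8, use_fp8=True, dtype=torch.bfloat16, ub_cfgs=ub_cfgs)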
self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} self.tp_group = None self.tp_size = 1 self.sequence_parallel = False @@ -448,7 +462,7 @@ def __init__(self) -> None: self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.fsdp_wrapped = False self.fsdp_group = None - self._fp8_workspaces: Dict[str, Float8Tensor] = {} + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None # Names of attributes that can be set quickly (see __setattr__ @@ -499,6 +513,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) ) + # Update quantizers with new amax pointers. + self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() + # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ @@ -516,37 +533,38 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history ) - def set_meta_tensor(self, fwd: bool) -> None: + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + # Return early if recipe state matches recipe if self.fp8_meta_tensors_initialized: - # Handle changed amax history size. - self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) - return + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) + return + if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros( - self.fp8_meta["recipe"].amax_history_len, - num_fp8_tensors, - dtype=torch.float32, - device="cuda", + # Initialize recipe state and quantizers + recipe_state = RecipeState.create( + recipe, + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors, ) - def init_fp8_meta_tensors(self) -> None: + self.fp8_meta[fp8_meta_tensor_key] = recipe_state + self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + + def init_fp8_meta_tensors(self, recipe: Recipe) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True) - self.set_meta_tensor(False) + self.set_meta_tensor(True, recipe) + self.set_meta_tensor(False, recipe) + self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -559,7 +577,6 @@ def get_fp8_meta_tensors(self) -> None: with torch.no_grad(): for key in (fwd_key, bwd_key): fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) - fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) return fp8_meta_tensors @@ -570,17 +587,13 @@ def reset(key): if key in self.fp8_meta: if fp8_meta_tensors is None: self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].scale_inv.copy_( - torch.ones_like(self.fp8_meta[key].scale_inv) - ) self.fp8_meta[key].amax_history.copy_( torch.zeros_like(self.fp8_meta[key].amax_history) ) else: assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." 
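# A minimal sketch of the recipe-state path used by set_meta_tensor() above:
# RecipeState.create() replaces the manual tex.FP8TensorMeta allocation and
# make_quantizers() yields one Quantizer per fp8 tensor slot. The recipe
# arguments here are only an example.
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.fp8 import RecipeState

recipe = DelayedScaling(amax_history_len=1024)
recipe_state = RecipeState.create(recipe, mode="forward", num_quantizers=3)
quantizers = recipe_state.make_quantizers()   # drop-in for self.quantizers["scaling_fwd"]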
self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) - self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) - self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) with torch.no_grad(): reset("scaling_fwd") @@ -624,12 +637,12 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Copy tensors to CPU and store state = {} - state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) - state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) - state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv) - state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) - state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) - state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv) + state["recipe"] = self.fp8_meta["recipe"] + if state["recipe"].delayed(): + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) # Store other pickelable values extra = {} @@ -667,12 +680,12 @@ def set_extra_state(self, state: torch.Tensor) -> None: # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] + self.fp8_meta["recipe"] = state["recipe"] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: """Helper function to copy tensor from CPU @@ -684,12 +697,11 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load tensors - copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) - copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) - copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv) - copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) - copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) - copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv) + if self.fp8_meta["recipe"].delayed(): + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) torch.cuda.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: @@ -729,7 +741,7 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" fp8_params = [] for param in self.parameters(recurse=False): - if isinstance(param, Float8Tensor) and param.requires_grad: + if isinstance(param, QuantizedTensor) and param.requires_grad: fp8_params.append(param) if len(fp8_params) == 0: return None @@ -742,22 +754,28 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or 
self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors() - - if self.fp8 or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. + if self.fp8_parameters or fp8_enabled: if ( self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] ): + # FP8 init has already been run and recipe is the same, don't do anything. return - - # Set FP8, recipe, and other FP8 metadata self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + if fp8_enabled: + # Set FP8 and other FP8 metadata self.fp8_meta["num_gemms"] = num_gemms self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() @@ -766,17 +784,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() @contextmanager def prepare_forward( self, inp: torch.Tensor, - is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument num_gemms: int = 1, allow_non_contiguous: bool = False, ) -> Generator[torch.Tensor, None, None]: @@ -798,7 +814,7 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) - if self.fp8 and self.sequence_parallel: + if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( "Amax reduction across tensor parallel group is " "necessary when using sequence parallelism with FP8." @@ -838,110 +854,64 @@ def set_nccl_overlap_warning_if_tp(self) -> None: @staticmethod def grad_output_preprocess( - ctx, grad_output: torch.Tensor, row_parallel_mode: bool + ctx, + grad_output: torch.Tensor, + row_parallel_mode: bool, + quantizer: Optional[Quantizer], ) -> Tuple[Union[torch.Tensor, None], ...]: """Utility function for backward. Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. + R1: gathered `grad_output`. + R2: bias gradient on R1. """ - if isinstance(grad_output, Float8Tensor): - grad_output._data = grad_output._data.contiguous() - else: - grad_output = grad_output.contiguous() - grad_output_mat = grad_output.view(-1, grad_output.shape[-1]) + grad_output = grad_output.reshape((-1, grad_output.shape[-1])) + grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - # No-FP8 case: bgrad is fused with wgrad for this case. + # Non-FP8 case: bgrad is fused with wgrad for this case. 
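# A minimal sketch of the new grad_output_preprocess() contract described above:
# callers unpack two values instead of four, since the returned quantized tensor
# carries both the row-wise data and the transpose needed for wgrad. `ctx`,
# `grad_out` and `quantizer` are placeholders supplied by the calling autograd
# function.
grad_out, grad_bias = TransformerEngineBaseModule.grad_output_preprocess(
    ctx, grad_out, row_parallel_mode=True, quantizer=quantizer
)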
if not ctx.fp8: if gather_grad_output: if not ctx.ub_overlap_ag: - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) + grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) else: - ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) - grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FP8 case with non-FP8 wgrad - if gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - assert ( - not ctx.ub_overlap_ag - ), "override_linear_precision.wgrad not supported with UB AG overlap" - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - elif gather_grad_output: + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) + return grad_output, None + + # FP8 with all-gather: unfused bgrad, fused cast + transpose + if gather_grad_output: + grad_bias = None if ctx.use_bias: - grad_bias = grad_output_mat.sum(dim=0) - else: - grad_bias = None + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.ub_overlap_ag: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) + # Quantize the gradient if needed + if not isinstance( + grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + ): + grad_output = quantizer(grad_output) + + # Copy into communication buffer, and replace original gradient with it + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) else: - grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) - if not isinstance(grad_output_mat, Float8Tensor): - cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out=grad_output_c, + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=quantizer, ) - else: - grad_output_c = grad_output_mat - if not ctx.ub_overlap_ag: - grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) - if not isinstance(grad_output_c, Float8Tensor): - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - else: - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) - grad_output_t = None + return grad_output, grad_bias - return grad_output_mat, grad_output_c, grad_output_t, grad_bias - - # FP8 case without gather: cast, transpose, bgrad fused + # FP8 without all-gather: fused bgrad + cast + transpose + grad_bias = None if ctx.use_bias: - grad_output_mat_no_fp8 = grad_output_mat - if isinstance(grad_output_mat, Float8Tensor): - grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype) - grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( - grad_output_mat_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if isinstance(grad_output_mat, Float8Tensor): - grad_output_c = grad_output_mat - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c, grad_output_t = fp8_cast_transpose_fused( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) + if 
isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_output_t = None - if not isinstance(grad_output_mat, Float8Tensor): - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_c = grad_output_mat - grad_bias = None - - return grad_output_mat, grad_output_c, grad_output_t, grad_bias + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_output = quantizer(grad_output) + return grad_output, grad_bias def register_parameter(self, name, param, **kwargs): """ @@ -978,21 +948,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as Float8Tensor + # If primary weights are in fp8, wrap the parameter as FP8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=param.device, - ) # Dummy buffer to avoid overwriting amax history - param = Float8Tensor.to_float8( - param, - fp8_meta=self.fp8_meta, - fp8_meta_index=fp8_meta_index, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), - ) + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + assert ( + quantizer is not None + ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + quantizer.internal = False + param = quantizer(param) # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but @@ -1004,17 +968,16 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: def forward(self): """Needs override.""" - def get_fp8_workspace( + def get_weight_workspace( self, *, tensor: Optional[torch.Tensor] = None, - fp8_meta_forward: Optional[bool] = None, - fp8_meta_index: Optional[int] = None, + quantizer: Optional[Quantizer] = None, cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: dist_group_type = None, - ) -> Float8Tensor: + fsdp_group: Optional[dist_group_type] = None, + ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values The workspace buffer may be cached for future function calls. @@ -1024,13 +987,9 @@ def get_fp8_workspace( tensor : torch.Tensor, optional Values to copy into workspace. Required if the workspace is being constructed or updated. - fp8_meta_forward: bool, optional - Whether to access FP8 meta tensors for the forward pass or - backward pass. Required if the workspace is being - constructed. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if the - workspace is being constructed. + quantizer: Quantizer, optional + Quantizer used to cast the weights. Required if the + workspace is being constructed or updated. cache_name: str, optional Key for caching. update_workspace: bool, default = `True` @@ -1052,61 +1011,24 @@ def get_fp8_workspace( # for models initialized with Fp8 primary weights. 
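# A minimal sketch of the renamed get_weight_workspace() documented above.
# `module` (a TE module instance) and `weight_quantizer` (one of
# module.quantizers["scaling_fwd"]) are assumptions for this example.
weight_fp8 = module.get_weight_workspace(
    tensor=module.weight,
    quantizer=weight_quantizer,
    cache_name="weight",            # reuse the cached buffer across microbatches
    update_workspace=True,          # recast because the high-precision weight changed
    skip_update_flag=None,
)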
if ( out is not None - and not isinstance(out, Float8Tensor) + and tensor is not None and fsdp_group is not None - and out._data.shape != tensor.data.shape + and out.data.shape != tensor.data.shape ): _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) # Construct workspace if needed if out is None: - - # FP8 data - if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: + if tensor is None or quantizer is None: raise ValueError( - "tensor, fp8_meta_forward, and fp8_meta_index kwargs " - "must be provided to construct FP8 workspace" - ) - fp8_dtype = get_fp8_te_dtype( - self.fp8_meta["recipe"], - fprop_tensor=fp8_meta_forward, - ) - data = torch.empty_like(tensor, dtype=torch.uint8) - scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) - - # Transpose cache - with_transpose_cache = torch.is_grad_enabled() - if ( - not with_transpose_cache - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose_cache = True - data_transpose = None - if with_transpose_cache: - data_transpose = torch.empty( - (tensor.size(-1), tensor.numel() // tensor.size(-1)), - dtype=torch.uint8, - device=tensor.device, + "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - - # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=self.fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - data_transpose=data_transpose, - ) + out = quantizer(tensor) # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out - update_workspace = True - skip_update_flag = None + return out # Update workspace if needed if skip_update_flag is not None: @@ -1114,17 +1036,10 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if is_in_onnx_export_mode(): - # ONNX export does not support fused cast-transpose - # kernel and requires that FP8 scales can be - # represented with constant ops. 
- transpose_cache = out._transpose - out._transpose = None - out.quantize_(tensor) - out._scale_inv.fill_(out._scale_inv.item()) - out._transpose = transpose_cache - else: + if hasattr(out, "quantize_"): out.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, out, skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 1034398875..2549d45728 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -36,7 +35,7 @@ def forward( total_row = sum(padded_m_splits) out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) - multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + tex.fused_multi_row_padding(inp.view(-1, in_features), out, m_splits, padded_m_splits) if is_grad_enabled: ctx.m_splits = m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index b0832b0848..479b91d396 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -56,8 +55,8 @@ def backward(ctx, grad_output: torch.Tensor): [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device ) # FP8 pad input for forward, FP8 input transpose for backward wgrad - multi_padding_fused( - grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + tex.fused_multi_row_padding( + grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits ) return (grad_input, None, None, None) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 65023e493b..2f9de58984 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,7 +3,7 @@ # See LICENSE for license information. 
"""GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List, Dict, Any +from typing import Union, Optional, Callable, Tuple, List import torch @@ -16,7 +16,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, @@ -28,21 +28,26 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( - cast_to_fp8, - fp8_cast_transpose_bgrad_fused, - fp8_multi_cast_transpose_fused, - fp8_grouped_gemm, - grouped_gemm, + general_grouped_gemm, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor import Float8Tensor, QuantizedTensor -from ..export import is_in_onnx_export_mode +from ..tensor.float8_tensor import Float8Tensor from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + + __all__ = ["GroupedLinear"] @@ -60,202 +65,141 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizers: List[Quantizer], + weight_quantizers: List[Quantizer], + output_quantizers: List[Quantizer], + grad_output_quantizers: List[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, sequence_parallel: bool, activation_dtype: torch.dtype, - fp8_meta_offsets: Dict[str, int], is_grad_enabled: bool, - weights_fp8: List[Union[Float8Tensor, None]], - *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], + module, + skip_fp8_weight_update, + *weights_and_biases, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] + device = inp.device + + # TODO Support MXFP8 # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): + raise NotImplementedError("GroupedLinear does not yet support MXFP8") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" inputmats = torch.split(inp.view(-1, in_features), m_splits) if fp8: - for i in range(num_gemms): - assert_dim_for_fp8_exec(inputmats[i]) - assert_dim_for_fp8_exec(weights[i]) + assert_dim_for_fp8_exec(*inputmats, *weights) # Cast input to expected dtype inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] - inputmats_t = [] - inputmat_scale_inv = None - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weights[0].requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - indices = list( - range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms) + weight_requires_grad = weights[0].requires_grad + + if input_quantizers[0] is not None: + for input_quantizer in input_quantizers: + input_quantizer.set_usage( + rowwise=True, + columnwise=(is_grad_enabled and weight_requires_grad), ) - 
inputmats, inputmats_t = fp8_multi_cast_transpose_fused( - inputmats_no_fp8, - fp8_meta["scaling_fwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() ) - else: - # FP8 input for forward - inputmats = [ - cast_to_fp8( - inputmats_no_fp8[i], - fp8_meta["scaling_fwd"], - fp8_meta_offsets["input"] + i, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + if weight_quantizers[0] is not None: + for weight_quantizer in weight_quantizers: + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + if output_quantizers[0] is not None: + for output_quantizer in output_quantizers: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8: + inputmats = tex.fused_multi_quantize( + inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] + ) + weights_fp8 = [] + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + if not isinstance(weights[0], QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, ) - for i in range(num_gemms) - ] + weights_fp8.append(weight_fp8) + else: + weights_fp8 = weights - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. 
- if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 + bias_dtype = activation_dtype + weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases - # Use FP8 weights - if weights_fp8[0] is None: - weights_fp8 = weights - assert all(isinstance(w, Float8Tensor) for w in weights_fp8) - - out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + out = torch.empty( + [sum(m_splits), weights_fp8[0].size(0)], + dtype=activation_dtype, + device=device, + ) - _ = fp8_grouped_gemm( - [w._data for w in weights_fp8], - [w._scale_inv for w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - fp8_dtype_forward, - inputmats, - inputmat_scale_inv, - 0, - fp8_dtype_forward, - [out], - activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - else: - # Cast for native AMP - weights = [cast_if_needed(w, activation_dtype) for w in weights] - biases = ( - [cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases - ) + _ = general_grouped_gemm( + weights_fp8, + inputmats, + [out], + activation_dtype, + get_multi_stream_cublas_workspace(), + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=_2X_ACC_FPROP, + ) - if fp8_calibration: + if fp8_calibration: + for i in range(num_gemms): + # amax of input for i in range(num_gemms): - # amax of input - amin, amax = inputmats[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = ( - torch.max(-amin, amax).float() - ) - # amax of weight - amin, amax = weights[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = ( - torch.max(-amin, amax).float() - ) + input_quantizers[i].calibrate(inputmats[i]) + for i in range(num_gemms): + weight_quantizers[i].calibrate(weights[i]) - out = torch.empty( - [sum(m_splits), weights[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + if is_grad_enabled: - _ = grouped_gemm( - weights, - inputmats, - torch.split(out, m_splits), - activation_dtype, - get_multi_stream_cublas_workspace(), - bias=biases, - use_bias=use_bias, - ) + saved_inputs, saved_weights = [], [] + ctx.weights_shape_1 = weights[0].shape[1] - if is_grad_enabled: - saved_inputmats = [None] * num_gemms - saved_inputmats_t = [None] * num_gemms - if weights[0].requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if not inputmats_t: - saved_inputmats = inputmats - else: - saved_inputmats_t = inputmats_t - if cpu_offloading: - for t in saved_inputmats_t: - t.activation_offloading = True - else: - saved_inputmats = inputmats_no_fp8 - - if cpu_offloading: - if fp8: - for w in weights_fp8: - if w is not None: - w.weight_offloading = True - for w in weights: - w.weight_offloading = True - for t in saved_inputmats: - if t is not None: - t.activation_offloading = True - - ctx.save_for_backward( - inputmat_scale_inv, - *saved_inputmats, - *saved_inputmats_t, - *weights, - *weights_fp8, - *[ - w.main_grad if cpu_offloading and 
fuse_wgrad_accumulation else None - for w in weights - ], - ) + tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.weights_requires_grad = weights[0].requires_grad + ctx.device = device + ctx.saved_inputs = saved_inputs + ctx.saved_weights = saved_weights + ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.inp_shape = inp.shape - ctx.fp8_meta_offsets = fp8_meta_offsets ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): @@ -271,66 +215,42 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_GroupedLinear_backward"): - ( - inputmat_scale_inv, - *saved_tensors, - ) = ctx.saved_tensors - inputmats = saved_tensors[: ctx.num_gemms] - inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms] - weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] - weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] - main_grads = saved_tensors[4 * ctx.num_gemms :] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + N = ctx.num_gemms + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] + main_grads = saved_tensors[3 * N :] + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO for i in ctx.num_gemms: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w # preprocess grad_output + grad_output = grad_output.contiguous() grad_output_mats = torch.split( grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits ) - grad_output_c = [None] * ctx.num_gemms - grad_output_t = [None] * ctx.num_gemms + grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) if ctx.use_bias: for i in range(ctx.num_gemms): - grad_biases[i], grad_output_c[i], grad_output_t[i] = ( - fp8_cast_transpose_bgrad_fused( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] ) else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - indices = list( - range( - ctx.fp8_meta_offsets["grad_output"], - ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms, - ) - ) - grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( - grad_output_mats, - ctx.fp8_meta["scaling_bwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_backward, - ) - else: - for i in range(ctx.num_gemms): - grad_output_c[i] = cast_to_fp8( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + 
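# A minimal sketch of the prepare_for_saving / restore_from_saved pair used in
# the forward and backward above to stash quantized tensors on ctx. `tensors`
# and `ctx` are placeholders for the autograd function's inputs and context.
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)

tensors_to_save, tensor_objects = prepare_for_saving(*tensors)
ctx.save_for_backward(*tensors_to_save)      # plain torch.Tensors for autograd
ctx.tensor_objects = tensor_objects          # tensor metadata kept on ctx

# ...later, in backward():
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)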
grad_output = tex.fused_multi_quantize( + grad_output_mats, + None, + ctx.grad_output_quantizers, + TE_DType[ctx.activation_dtype], + ) + else: + grad_output = grad_output_mats if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -340,111 +260,57 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.requires_dgrad: - if ctx.fp8: - dgrad = torch.empty( - (sum(ctx.m_splits), weights_fp8[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - fp8_grouped_gemm( - [w.transpose_2d() for w in weights_fp8], - [w._scale_inv for w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - weights_fp8[0]._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - [dgrad], - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=ctx.m_splits, - use_split_accumulator=_2X_ACC_DGRAD, - ) - else: - dgrad = torch.empty( - (sum(ctx.m_splits), weights[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - grouped_gemm( - weights, - grad_output_mats, - torch.split(dgrad, ctx.m_splits), - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NN", - grad=True, - ) + dgrad = torch.empty( + (sum(ctx.m_splits), ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) - if weights[0].requires_grad: + general_grouped_gemm( + weights, + grad_output, + torch.split(dgrad, ctx.m_splits), + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=_2X_ACC_DGRAD, + ) + + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation: wgrad_list = [w.main_grad for w in weights] else: wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) + torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmats_t[0] is None: - for i in range(ctx.num_gemms): - if isinstance(inputmats[i], Float8Tensor): - inputmats_t[i] = inputmats[i].transpose_2d() - else: - inputmats_t[i] = tex.fp8_transpose( - inputmats[i], fp8_dtype_backward - ) - fp8_grouped_gemm( - [ - inp._data if isinstance(inp, Float8Tensor) else inp - for inp in inputmats_t - ], - [inputmat_scale_inv], - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - use_split_accumulator=_2X_ACC_WGRAD, - ) - else: - grouped_gemm( - inputmats, - grad_output_mats, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - ) - else: # WGRAD - _, grad_biases, _ = grouped_gemm( + _, grad_biases_, _ = general_grouped_gemm( inputmats, - grad_output_mats, + grad_output, wgrad_list, ctx.activation_dtype, get_multi_stream_cublas_workspace(), layout="NT", grad=True, - use_bias=ctx.use_bias, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_wgrad_into_param_main_grad, ) + for i in range(ctx.num_gemms): + if 
grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ # Deallocate input tensor clear_tensor_data(*inputmats) - clear_tensor_data(*inputmats_t) def handle_custom_ddp_from_mcore(w, wgrad): - if w.requires_grad: + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): w.grad_added_to_main_grad = True if getattr(w, "zero_out_wgrad", False): @@ -478,22 +344,24 @@ def handle_custom_ddp_from_mcore(w, wgrad): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, - None, # m_splits - None, # use_bias - None, # is_first_microbatch - None, # fp8 - None, # fp8_calibration - None, # fp8_meta - None, # fuse_wgrad_accumulation - None, # cpu_offloading - None, # sequence_parallel - None, # activation_dtype - None, # fp8_meta_offsets + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, # is_grad_enabled None, # is_grad_enabled - None, # weights_fp8 *wgrad_list, *grad_biases, ) @@ -718,7 +586,7 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp: + with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -727,29 +595,32 @@ def forward( w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors ] - weight_tensors_fp8 = [None] * self.num_gemms + input_quantizers, weight_quantizers, output_quantizers = ( + [None] * self.num_gemms, + [None] * self.num_gemms, + [None] * self.num_gemms, + ) + grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms if self.fp8: + input_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] for i in range(self.num_gemms): - if isinstance(weight_tensors[i], Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. 
- if weight_tensors[i]._transpose is not None: - weight_tensors[i].transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_tensors_fp8[i] = self.get_fp8_workspace( - tensor=weight_tensors[i], - fp8_meta_forward=True, - fp8_meta_index=self._offsets["weight"] + i, - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + input_quantizers[i].internal = True + weight_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + weight_quantizers[i].internal = True + if torch.is_grad_enabled(): + grad_output_quantizers = [ + self.quantizers["scaling_bwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + grad_output_quantizers[i].internal = True if torch.is_grad_enabled(): linear_fn = _GroupedLinear.apply @@ -764,14 +635,17 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_output_quantizers, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.sequence_parallel, self.activation_dtype, - self._offsets, torch.is_grad_enabled(), - weight_tensors_fp8, + self, + skip_fp8_weight_update, *weight_tensors, *bias_tensors, ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 189464cf80..60c73a8d7d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -5,12 +5,14 @@ """LayerNormLinear API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn import init -from .. 
import cpp_extensions as tex +import transformer_engine_torch as tex from .base import ( get_workspace, @@ -20,7 +22,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -40,14 +42,22 @@ _fsdp_scatter_tensors, _fsdp_gather_tensors, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import _apply_normalization, _noop_cat -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormLinear"] @@ -64,15 +74,18 @@ def forward( ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, - weight_fp8: Optional[torch.Tensor], bias: torch.Tensor, use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -87,13 +100,16 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, normalization: str, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_ag: bool, ub_name: str, - fp8_output: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring # Make sure input dimensions are compatible @@ -102,8 +118,7 @@ def forward( assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) + assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -111,205 +126,183 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_overlap_ag: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled): - ub_overlap_ag = False - if ub_overlap_ag: - dim_size = list(inputmat.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ub_name + "_fprop") - if return_layernorm_output: - # First prepare LN output in higher precision, - # which will be later copied to a FP8 UB - ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format) + tp_world_size = get_distributed_world_size(tp_group) + ub_overlap_ag_fprop = ( + ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output + ) + + weight_requires_grad = 
weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad + with_input_all_gather = parallel_mode == "column" and sequence_parallel + + if fp8: + if ( + any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + + # Configure quantizer for normalization output + with_quantized_norm = fp8 and not return_layernorm_output + if with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(input_quantizer, MXFP8Quantizer): + with_quantized_norm = False else: - ln_out = ub_obj_lnout.get_ubuf_output(0) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, + ) + + ub_obj_fprop = None + ln_out = None + if ub_overlap_ag_fprop: + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + elif with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - # Objects for FP8 cast - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out_scale_inv = None - if fp8: - ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - - # Launch normalization kernel - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, - fp8_scale_inv=ln_out_scale_inv, ) - - # Column Parallel Linear - ln_out_gathered = False - ub_algo = None - if ub_overlap_ag: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - if not return_layernorm_output: - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ln_out_return = ln_out if return_layernorm_output else None + + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication + if with_input_all_gather and not ub_overlap_ag_fprop: + with_quantized_all_gather = fp8 + if return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(input_quantizer if with_quantized_all_gather else None), + ) + if return_layernorm_output and return_layernorm_output_gathered: + ln_out_return = ln_out_total + if fp8 and not with_quantized_all_gather: + ln_out_total = input_quantizer(ln_out_total) + else: + if ub_overlap_ag_fprop: + ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif parallel_mode == "column" and sequence_parallel: - ln_out_gathered = True - ln_out_total, _ = 
gather_along_first_dim(ln_out, tp_group) + if fp8: + if not isinstance(ln_out, QuantizedTensor): + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + ln_out = input_quantizer(ln_out) + elif backward_needs_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + ln_out_total = ln_out + + # Cast weight to expected dtype + weightmat = weight + quantized_weight = False + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) else: - ln_out_total = ln_out + if not isinstance(weight, QuantizedTensor): + quantized_weight = True + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=True) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) - # If residual connection is after LN, we need `ln_out_return` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(ln_out_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG if fp8: - if ub_overlap_ag: - ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) - tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - out=ln_out_fp8, - scale_inv=ln_out_scale_inv, - ) - ln_out = torch.empty_like(ln_out_fp8) - else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=ln_out_scale_inv, - ) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total - - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. 
- # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - ln_out_scale_inv.fill_(ln_out_scale_inv.item()) - - if fp8_output: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) - else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - activation_dtype, - ) - out, _ = tex.fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - output_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - ) - if output_dtype == torch.uint8: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) - else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - out, _, _ = tex.gemm( - weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
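The removed calibration branch above recorded amax statistics by hand; the new code delegates this to `input_quantizer.calibrate(...)`, but the underlying bookkeeping is still just max(|x|) written into an amax history. A small self-contained sketch of that step (names illustrative):

import torch

def calibrate_amax(amax_history, tensor):
    """Record max(|tensor|) into slot 0 of a rolling amax history."""
    amin, amax = tensor.aminmax()
    amax_history[0] = torch.maximum(-amin, amax).float()
    return amax_history

hist = torch.zeros(4)
x = torch.randn(8, 8)
calibrate_amax(hist, x)
assert torch.isclose(hist[0], x.abs().max())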
+ ln_out_total = ub_obj.get_buffer(input_quantizer) + + out, *_, rs_out = general_gemm( + weightmat, + ln_out_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) + if not weight.requires_grad: + if not return_layernorm_output: + ln_out = ln_out_total = None + clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - weight.weight_offloading = True + if fp8 and weightmat is not None: + set_offloading_param(weightmat, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(weight, "weight_offloading", True) - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -319,25 +312,34 @@ def forward( fsdp_group, mu, rsigma, - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + weightmat if quantized_weight else None, ln_out if weight.requires_grad else None, ) - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, + weightmat, + weight, + bias, ln_weight, + ln_out, mu, rsigma, - weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - ln_out if weight.requires_grad else None, - ln_out_scale_inv, ) - + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.quantized_weight = quantized_weight + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.input_quantizer = input_quantizer + ctx.owns_input = inputmat is not inp + ctx.weight = weight ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -349,14 +351,13 @@ def forward( ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = ( - return_layernorm_output_gathered and ln_out_gathered - ) + ctx.return_layernorm_output_gathered = return_layernorm_output_gathered ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -368,10 +369,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if parallel_mode == "row" and 
sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -389,23 +393,42 @@ def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_outputs[0], Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ - 0 - ]._scale_inv with torch.cuda.nvtx.range("_LayerNormLinear_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, + weight, + _, + bias, ln_weight, + ln_out, mu, rsigma, - weight, - weight_fp8, - main_grad, - ln_out, - ln_out_scale_inv, - ) = ctx.saved_tensors + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -415,56 +438,93 @@ def backward( ctx.fsdp_shapes, mu, rsigma, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight if ctx.fp8 and ctx.quantized_weight else None, ln_out, ) + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. 
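The row-parallel epilogue above either reduce-scatters the GEMM output along its first dimension (sequence parallel) or all-reduces it (plain tensor parallel). A minimal sketch with torch.distributed collectives, assuming an initialized process group and a leading dimension divisible by the TP world size; the function name is illustrative.

import torch
import torch.distributed as dist

def row_parallel_epilogue(out, tp_group, sequence_parallel):
    if sequence_parallel:
        world = dist.get_world_size(tp_group)
        chunk = torch.empty(out.shape[0] // world, *out.shape[1:],
                            dtype=out.dtype, device=out.device)
        dist.reduce_scatter_tensor(chunk, out, group=tp_group)  # partial sums -> local shard
        return chunk
    dist.all_reduce(out, group=tp_group)  # full output on every rank
    return out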
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) weight.main_grad = main_grad - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device + ) + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. 
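The bulk-overlap setup above hides the input all-gather behind the dgrad GEMM. The same idea expressed with plain NCCL collectives (not the userbuffers path) is to launch the gather asynchronously, compute dgrad while it is in flight, and wait on the handle only where wgrad needs the gathered tensor; a sketch assuming an initialized process group:

import torch
import torch.distributed as dist

def overlapped_backward(ln_out, grad_output, weight, tp_group):
    # ln_out: [s_local, in], grad_output: [s_local * world, out], weight: [out, in]
    world = dist.get_world_size(tp_group)
    ln_out_total = torch.empty(ln_out.shape[0] * world, ln_out.shape[1],
                               dtype=ln_out.dtype, device=ln_out.device)
    work = dist.all_gather_into_tensor(ln_out_total, ln_out,
                                       group=tp_group, async_op=True)
    dgrad = grad_output @ weight            # overlaps with the all-gather
    work.wait()                             # gathered input is needed from here on
    wgrad = grad_output.t() @ ln_out_total  # wgrad GEMM
    return dgrad, wgrad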
+ ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True) + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + if ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" + ctx, + grad_outputs[0], + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_wgrad = False - - # Column Parallel Linear - # Overlap input AG with dgrad + # Prepare GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None if ( - weight.requires_grad - and (not ctx.ub_bulk_dgrad) + ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -472,218 +532,129 @@ def backward( else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - dgrad_size = list(grad_output.size()) - dgrad_size[1] = weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub(ctx.ub_name + "_wgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub(ctx.ub_name + "_dgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + # dgrad GEMM + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = weight.size(1) - rs_out = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=grad_output.device - ) - if ub_obj_dgrad.is_p2p_overlap(): - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], 
fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT1 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) - - # DGRAD: Evaluated unconditionally to feed into Linear backward - _ = tex.fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ( - grad_output_c._data - if isinstance(grad_output_c, Float8Tensor) - else grad_output_c - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out_type, - get_workspace(), - out=dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - clear_tensor_data(grad_output_c) - else: - # DGRAD: Evaluated unconditionally to feed into Linear backward - _, _, _ = tex.gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - out=dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + dgrad, *_ = general_gemm( + weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + + # Launch tensor-parallel communication + dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) - dgrad, handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + # Compute grad weight tensor wgrad = None - if weight.requires_grad: - if ctx.fp8: - # WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=dgrad.device - ) - dgrad = extra_output_tensor + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + # FP8 GEMM on Hopper only supports TN layout so the gathered input must have + # a valid transpose. 
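When the all-gather ran on columnwise (transposed) data, the gathered buffer holds the per-rank transposes stacked along the leading dimension and has to be re-interleaved before it can serve as the wgrad operand, which is what the `_fix_gathered_fp8_transpose` call below does. A plain-tensor illustration of that reshuffle:

import torch

world, s_local, in_dim = 4, 3, 5
chunks = [torch.randn(s_local, in_dim) for _ in range(world)]
gathered_t = torch.cat([c.t() for c in chunks], dim=0)  # what the all-gather produces
wanted = torch.cat(chunks, dim=0).t()                   # what the wgrad GEMM expects

fixed = (gathered_t.view(world, in_dim, s_local)
         .transpose(0, 1)
         .reshape(in_dim, world * s_local))
assert torch.equal(fixed, wanted)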
+ if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) else: - dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - ( - grad_output_t._data - if isinstance(grad_output_t, Float8Tensor) - else grad_output_t - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, grad_output_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - wgrad, _, _ = tex.gemm( - ln_out_total_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c) + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + ln_out_total._create_transpose() + else: - # WGRAD - wgrad, grad_bias, _ = tex.gemm( - ln_out_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. 
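Both backward GEMMs consume `grad_output`: dgrad uses it directly ("NN" layout) while wgrad needs its transpose ("NT"), which is why both rowwise and columnwise usages are requested above. Reference math in plain PyTorch for a linear layer y = x @ W.t():

import torch

x = torch.randn(6, 4, requires_grad=True)   # GEMM input (ln_out_total)
W = torch.randn(8, 4, requires_grad=True)   # weight, [out_features, in_features]
y = x @ W.t()
dy = torch.randn_like(y)                    # grad_output
y.backward(dy)
assert torch.allclose(x.grad, dy @ W, atol=1e-5)       # dgrad GEMM
assert torch.allclose(W.grad, dy.t() @ x, atol=1e-5)   # wgrad GEMM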
+ grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device ) + + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + wgrad, grad_bias_, *_, rs_out = general_gemm( + ln_out_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) + + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out + else: + dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if not ctx.return_layernorm_output: + # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme clear_tensor_data(ln_out_total) - if ctx.ub_bulk_wgrad: - dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.parallel_mode == "column" - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = dgrad.view(inputmat.shape) + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None # Residual gradient + dgrad = dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None if ctx.normalization == "LayerNorm": @@ -696,6 +667,7 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) elif ctx.normalization == "RMSNorm": dgrad, dgamma = tex.rmsnorm_bwd( dgrad, @@ -705,14 +677,12 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) dbeta = None clear_tensor_data(mu) clear_tensor_data(rsigma) - if not ctx.use_bias: - grad_bias = None - - if weight.requires_grad: + if ctx.requires_wgrad: # Handle custom DDP from mcore. 
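With `fuse_wgrad_accumulation` enabled, the wgrad GEMM above writes or accumulates directly into the persistent `main_grad` buffer and the parameter's `.grad` stays unset (the mcore DDP handling below then only marks `grad_added_to_main_grad`). A minimal sketch of that accumulation contract, names illustrative:

import torch

def wgrad_step(grad_output, inp, main_grad=None, accumulate=False):
    wgrad = grad_output.t() @ inp
    if main_grad is not None:
        if accumulate:
            main_grad += wgrad        # later microbatches accumulate
        else:
            main_grad.copy_(wgrad)    # first microbatch overwrites
        return None                    # no separate wgrad tensor is returned
    return wgrad

main_grad = torch.zeros(8, 4)
dy, x = torch.randn(16, 8), torch.randn(16, 4)
wgrad_step(dy, x, main_grad, accumulate=False)   # microbatch 0
wgrad_step(dy, x, main_grad, accumulate=True)    # microbatch 1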
if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): weight.grad_added_to_main_grad = True @@ -724,12 +694,7 @@ def backward( requires_grad=False, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + wgrad = None elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -739,23 +704,26 @@ def backward( FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + # if ctx.fp8 and not isinstance(weight, QuantizedTensor): + # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, wgrad, - None, # weight_fp8 grad_bias, None, # use_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -770,13 +738,16 @@ def backward( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad None, # ub_overlap_rs_dgrad - None, # ub_overlap_ag + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name - None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -887,10 +858,11 @@ def __init__( parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -907,13 +879,6 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_ag = ub_overlap_ag - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): - assert ub_name is not None, "Userbuffer name [string] is not set." 
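The long tuple of None entries above follows the torch.autograd.Function contract: backward() must return exactly one value per argument of forward(), with None for the non-tensor arguments (flags, quantizers, process groups, the module handle, and so on). A small self-contained example of that pattern:

import torch

class _ScaleByFlag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, scale, some_flag):
        ctx.scale = scale
        return inp * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per forward arg: inp, scale, some_flag
        return grad_out * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
_ScaleByFlag.apply(x, 2.0, True).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))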
- self.ub_name = ub_name if tp_group is None: self.tp_size = tp_size @@ -939,9 +904,49 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column-parallel overlaps + self.ub_overlap_ag_fprop = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_overlap_rs_dgrad = ( + ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + + # Row-parallel overlaps + self.ub_overlap_rs_fprop = ( + ub_overlap_rs and self.sequence_parallel and self.parallel_mode == "row" + ) + self.ub_overlap_ag_dgrad = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "row" + ) + if any( + [ + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + ] + ): + assert ub_name is not None, "Userbuffer name [string] is not set." + self.ub_name = ub_name + self.eps = eps layer_norm_weight = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_weight", @@ -950,7 +955,7 @@ def __init__( ) if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) @@ -1034,7 +1039,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -1159,7 +1166,9 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch) as inp: + with self.prepare_forward( + inp, allow_non_contiguous=False # removed .contiguous from inside the layer + ) as inp: # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] @@ -1171,35 +1180,20 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: bias_tensor = getattr(self, self.bias_names[0]) # Unused - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. 
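The constructor logic that replaces the removed flags gates each userbuffer overlap on the parallel mode and sequence parallelism, and disables the bulk overlaps when the dgrad reduce-scatter overlap is active. A condensed restatement of that gating (function name illustrative):

def resolve_ub_overlaps(parallel_mode, sequence_parallel, *, ub_overlap_ag,
                        ub_overlap_rs, ub_overlap_rs_dgrad, ub_bulk_dgrad, ub_bulk_wgrad):
    col = sequence_parallel and parallel_mode == "column"
    row = sequence_parallel and parallel_mode == "row"
    rs_dgrad = ub_overlap_rs_dgrad and col
    return {
        "ag_fprop": ub_overlap_ag and col,
        "rs_dgrad": rs_dgrad,
        "bulk_dgrad": ub_bulk_dgrad and col and not rs_dgrad,
        "bulk_wgrad": ub_bulk_wgrad and col and not rs_dgrad,
        "rs_fprop": ub_overlap_rs and row,
        "ag_dgrad": ub_overlap_ag and row,
    }

flags = resolve_ub_overlaps("column", True, ub_overlap_ag=True, ub_overlap_rs=False,
                            ub_overlap_rs_dgrad=True, ub_bulk_dgrad=True, ub_bulk_wgrad=True)
assert flags["bulk_dgrad"] is False  # superseded by the dgrad reduce-scatter overlap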
- if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1212,15 +1206,18 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, weight_tensor, - weight_fp8, bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1235,13 +1232,16 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_ag, self.ub_name, - fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1258,3 +1258,27 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self, fp8_output): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7bcbb1eb7d..88eebc8e6c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -5,12 +5,16 @@ """LayerNormMLP API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn.parameter import Parameter from torch.nn import init +import transformer_engine_torch as tex + from .base import ( get_workspace, _ub_communicators, @@ -20,7 +24,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -35,6 +39,7 @@ assert_dim_for_fp8_exec, clear_tensor_data, requires_grad, + non_tn_fp8_gemm_supported, ) from ..distributed import ( 
set_tensor_model_parallel_attributes, @@ -45,30 +50,39 @@ use_reentrant_activation_recompute, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, - _fsdp_gather_tensors, ) -from .. import cpp_extensions as tex - -from ..constants import dist_group_type, TE_DType +from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ._common import _apply_normalization -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ._common import apply_normalization, _fix_gathered_fp8_transpose +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param + +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormMLP"] def _act_func(activation: str): funcs = { - "gelu": (tex.gelu, tex.dgelu), - "relu": (tex.relu, tex.drelu), - "geglu": (tex.geglu, tex.dgeglu), - "reglu": (tex.reglu, tex.dreglu), - "swiglu": (tex.swiglu, tex.dswiglu), - "qgelu": (tex.qgelu, tex.dqgelu), - "srelu": (tex.srelu, tex.dsrelu), + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") @@ -87,19 +101,24 @@ def forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, - fc1_weight_fp8: Optional[torch.Tensor], fc1_bias: torch.Tensor, use_fc1_bias: bool, fc2_weight: torch.Tensor, - fc2_weight_fp8: Optional[torch.Tensor], fc2_bias: torch.Tensor, use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + fc1_input_quantizer: Optional[Quantizer], + fc1_weight_quantizer: Optional[Quantizer], + fc2_input_quantizer: Optional[Quantizer], + fc2_weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_fc2_output_quantizer: Optional[Quantizer], + grad_fc1_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -108,7 +127,7 @@ def forward( activation_dtype: torch.dtype, return_layernorm_output: bool, return_layernorm_output_gathered: bool, - bias_gelu_nvfusion: bool, + bias_gelu_fusion: bool, set_parallel_mode: bool, is_grad_enabled: bool, fwd_ln_sm_margin: int, @@ -116,26 +135,34 @@ def forward( zero_centered_gamma: bool, activation: str, normalization: str, + ub_overlap_ag: bool, + ub_overlap_rs: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, gemm_gelu_fusion: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + in_features, inp_shape = ln_weight.numel(), inp.shape # Make sure input dimensions are compatible - in_features = ln_weight.numel() - 
inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(fc1_weight) - assert_dim_for_fp8_exec(fc2_weight) + assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + if ( + any([ub_overlap_ag, ub_overlap_rs]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) activation_func = _act_func(activation)[0] + device = inp.device # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -143,314 +170,250 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + # for standard fp8: layernorm output = FP8 + # only output of the linear is returned + # for return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + with_quantized_norm = fp8 and not return_layernorm_output + tp_world_size = get_distributed_world_size(tp_group) - if ub_overlap_ag: - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_overlap_ag = False + ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output + ub_overlap_rs = ub_overlap_rs and is_grad_enabled + with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag + backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + + # Configure quantizer for normalization output + if fp8 and fc1_input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_quantized_norm: + if with_input_all_gather_nccl: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(fc1_input_quantizer, MXFP8Quantizer): + with_quantized_norm = False + else: + fc1_input_quantizer.set_usage( + rowwise=True, + columnwise=backwards_needs_fc1_input, + ) + + ub_obj_lnout = None + ln_out = None if ub_overlap_ag: ub_obj_lnout = get_ub("fc1_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) - else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) + elif not with_quantized_norm: ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + fc1_input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, ) - # Column Parallel Linear + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False - ub_algo_ag = None - if ub_overlap_ag: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif set_parallel_mode and sequence_parallel: + with_quantized_all_gather = fp8 + if with_input_all_gather_nccl: + if 
return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(fc1_input_quantizer if with_quantized_all_gather else None), + ) ln_out_gathered = True - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: - ln_out_total = ln_out + with_quantized_all_gather = False + if ub_overlap_ag: + ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) + else: + ln_out_total = ln_out # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. + ln_out_return = None if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8: - if ub_overlap_ag: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if fp8 and not with_quantized_all_gather: + ln_out_total = fc1_input_quantizer(ln_out_total) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[ + slice_start:slice_end, ... + ] # TODO(pgadzinski) - check this # pylint: disable=fixme else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total - - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - - # Use FP8 weights - if fc1_weight_fp8 is None: - fc1_weight_fp8 = fc1_weight - if fc2_weight_fp8 is None: - fc2_weight_fp8 = fc2_weight - - assert isinstance(fc1_weight_fp8, Float8Tensor) - assert isinstance(fc2_weight_fp8, Float8Tensor) - - # Perform FP8 GEMM - fp8_gemm_args = [ - fc1_weight_fp8._data, - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - ] - fp8_gemm_kwargs = { - "bias": fc1_bias, - "use_bias": use_fc1_bias, - "use_split_accumulator": _2X_ACC_FPROP, - "ub_algo": ub_algo_ag if ub_overlap_ag else None, - "ub": ub_obj_lnout if ub_overlap_ag else None, - "extra_output_tensor": ln_out if ub_overlap_ag else None, - } - if gemm_gelu_fusion: - fp8_gemm_args[8] = torch.uint8 # out_dtype - fp8_gemm_kwargs.update( - { - "gelu": True, - "out_index": tex.FP8FwdTensors.GEMM2_INPUT, - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "D_dtype": fp8_dtype_forward, - } + ln_out = ln_out_total + + # Cast weights to expected dtype + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + if not fp8: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) + else: + # If weights are not quantized, we call get_weight_workspace, + # which handles weight caching etc. 
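When the LN output is quantized only after the all-gather, the code above slices this rank's shard back out of the gathered tensor so the backward pass saves just the local chunk. A small sketch of that rank-based slicing:

import torch

def local_chunk(gathered, rank, local_rows):
    start = rank * local_rows
    return gathered[start:start + local_rows, ...]

world, rows, cols = 4, 3, 8
shards = [torch.full((rows, cols), float(r)) for r in range(world)]
gathered = torch.cat(shards, dim=0)
assert torch.equal(local_chunk(gathered, rank=2, local_rows=rows), shards[2])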
+ if not isinstance(fc1_weight, QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs) - if not is_grad_enabled: - clear_tensor_data(ln_out_total) - - # Perform activation - if gemm_gelu_fusion: - gelu_out, fc1_out = fp8_gemm_out - else: - fc1_out, _ = fp8_gemm_out - gelu_out = activation_func( - fc1_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, + if not isinstance(fc2_weight, QuantizedTensor): + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( - None, - None, - None, - activation_dtype, - ) - rs_out = None - ub_algo_rs = None - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight_fp8.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - - if ub_obj_fc2out.is_fp8_ubuf(): - fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT - fc2_meta_tensor = fp8_meta["scaling_fwd"] - fc2_te_type = fp8_dtype_forward - out_type = torch.uint8 - ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index]) + # Cast biases to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + if fc1_bias is not None: + fc1_bias = cast_if_needed(fc1_bias, bias_dtype) + if fc2_bias is not None: + fc2_bias = cast_if_needed(fc2_bias, bias_dtype) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if fc1_input_quantizer is not None: + fc1_input_quantizer.calibrate(ln_out_total) + if fc1_weight_quantizer is not None: + fc1_weight_quantizer.calibrate(fc1_weight) + + # FC1 GEMM + + # There are 2 fussions possible: + # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion, + # - bias_gelu_fusion - only for full precision. 
+ # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer + if activation != "gelu": + gemm_gelu_fusion = bias_gelu_fusion = False + else: + if fp8: + assert not bias_gelu_fusion, "Bias gelu fusion is supported only for full precision" else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight_fp8.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - - _ = tex.fp8_gemm( - fc2_weight_fp8._data, - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - gelu_out, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - out_type, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=fc2_out_index, - fp8_meta_tensor=fc2_meta_tensor, - D_dtype=fc2_te_type, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + gemm_gelu_fusion = True + if gemm_gelu_fusion and bias_gelu_fusion: + gemm_gelu_fusion = False + + fc1_outputs = general_gemm( + fc1_weight_final, + ln_out_total, + get_workspace(), + quantization_params=( + fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + ), + out_dtype=activation_dtype, + bias=( + fc1_bias if not bias_gelu_fusion else None + ), # otherwise bias is added later (fused with gelu) + gelu=gemm_gelu_fusion, + accumulate=_2X_ACC_FPROP, + ub=ub_obj_lnout, + ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, + ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): + clear_tensor_data(ln_out_total) + + # ACTIVATION - sometimes activation is fused with the GEMM above. 
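The fusion choice above can be summarized as: both fusions apply only to the "gelu" activation, bias+GELU fusion is a high-precision-only path that wins when both are requested, and GEMM+GELU fusion is otherwise the default for high precision and opt-in for FP8. A compact restatement (function name illustrative):

def pick_fc1_fusion(activation, fp8, bias_gelu_fusion, gemm_gelu_fusion):
    if activation != "gelu":
        return False, False  # (gemm_gelu_fusion, bias_gelu_fusion)
    if fp8:
        assert not bias_gelu_fusion, "Bias gelu fusion is supported only for full precision"
    else:
        gemm_gelu_fusion = True
    if gemm_gelu_fusion and bias_gelu_fusion:
        gemm_gelu_fusion = False  # bias+GELU fusion takes precedence
    return gemm_gelu_fusion, bias_gelu_fusion

assert pick_fc1_fusion("gelu", fp8=False, bias_gelu_fusion=True, gemm_gelu_fusion=True) == (False, True)
assert pick_fc1_fusion("gelu", fp8=True, bias_gelu_fusion=False, gemm_gelu_fusion=True) == (True, False)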
+ + fc1_out_without_bias = None + + if bias_gelu_fusion: + fc1_out = None + fc1_out_without_bias, *_ = fc1_outputs + act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) + elif gemm_gelu_fusion: + act_out, _, fc1_out, _ = fc1_outputs else: - # Cast for native AMP - fc1_weight = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight = cast_if_needed(fc2_weight, activation_dtype) - fc1_bias = cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias - - if fp8_calibration: - # amax of fc1 input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc1 weight - amin, amax = fc1_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - fc1_outputs = tex.gemm( - fc1_weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=fc1_bias, - use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, - gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) - if not is_grad_enabled and not return_layernorm_output: - clear_tensor_data(ln_out_total) + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, fc2_input_quantizer) - if bias_gelu_nvfusion: - fc1_out, _, _ = fc1_outputs - gelu_out = bias_gelu_fused(fc1_out, fc1_bias) - else: - if activation == "gelu": - gelu_out, _, fc1_out = fc1_outputs - else: - fc1_out, _, _ = fc1_outputs - gelu_out = activation_func( - fc1_out, None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype] - ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - if fp8_calibration: - # amax of fc2 input - amin, amax = gelu_out.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc2 weight - amin, amax = fc2_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = torch.max( - -amin, amax - ).float() - - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - _ = tex.gemm( - fc2_weight, - gelu_out, - activation_dtype, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + if not is_grad_enabled: + clear_tensor_data(fc1_out) + + if fp8_calibration: + fc2_input_quantizer.calibrate(act_out) + fc2_weight_quantizer.calibrate(fc2_weight) + + ub_obj_fc2out = None + rs_out = None + fc2_out = None + if ub_overlap_rs: + ub_obj_fc2out = get_ub("fc2_fprop") + dim_size = list(act_out.size()) + 
dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc2_weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) + else: + dim_size = list(act_out.size()) + dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + + # FC2 GEMM + _ = general_gemm( + fc2_weight_final, + act_out, + get_workspace(), + out_dtype=activation_dtype, + bias=fc2_bias, + quantization_params=output_quantizer, + out=fc2_out, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj_fc2out, + ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, + extra_output=rs_out, + ) + if not is_grad_enabled: + clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) if is_grad_enabled: if cpu_offloading: - if fp8 and fc1_weight_fp8 is not None: - fc1_weight_fp8.weight_offloading = True - if fp8 and fc2_weight_fp8 is not None: - fc2_weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - fc1_weight.weight_offloading = True - fc2_weight.weight_offloading = True - if fc1_bias is not None: - fc1_bias.weight_offloading = True - - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True - fc1_out.activation_offloading = True - gelu_out.activation_offloading = True + if fp8 and fc1_weight_final is not None: + set_offloading_param(fc1_weight_final, "weight_offloading", True) + if fp8 and fc2_weight_final is not None: + set_offloading_param(fc2_weight_final, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(fc1_weight, "weight_offloading", True) + set_offloading_param(fc2_weight, "weight_offloading", True) + set_offloading_param(fc1_bias, "weight_offloading", True) + + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) + set_offloading_param(fc1_out, "activation_offloading", True) + set_offloading_param(fc1_out_without_bias, "activation_offloading", True) + set_offloading_param(act_out, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -461,45 +424,68 @@ def forward( mu, rsigma, ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + fc1_out_without_bias if bias_gelu_fusion else fc1_out, + act_out, + fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) - ctx.save_for_backward( + if not fc1_weight.requires_grad: + if not return_layernorm_output: + clear_tensor_data(ln_out) + ln_out = None + if not fc2_weight.requires_grad: + clear_tensor_data(act_out) + act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, + ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer + fc1_weight_final, + fc1_bias, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight_final, + fc2_bias, mu, rsigma, - ln_out if fc1_weight.requires_grad 
else None, - fc1_out, - gelu_out if fc2_weight.requires_grad else None, - fc1_weight, - fc1_weight_fp8, - fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc2_weight, - fc2_weight_fp8, - fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc1_bias, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) + if fuse_wgrad_accumulation: + ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None + ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer + ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc1_input_quantizer = fc1_input_quantizer + + ctx.fc1_weight_requires_grad = fc1_weight.requires_grad + ctx.fc2_weight_requires_grad = fc2_weight.requires_grad + ctx.fc1_weight = fc1_weight + ctx.fc2_weight = fc2_weight + + ctx.device = device ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_fc1_bias = use_fc1_bias ctx.use_fc2_bias = use_fc2_bias + ctx.use_bias = ctx.use_fc1_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape ctx.tp_group = tp_group ctx.tp_size = tp_size - ctx.bias_gelu_nvfusion = bias_gelu_nvfusion + ctx.bias_gelu_fusion = bias_gelu_fusion ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output_gathered = ( return_layernorm_output_gathered and ln_out_gathered @@ -511,7 +497,10 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag - ctx.requires_dgrad = inp.requires_grad + + ctx.requires_dgrad = ( + inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad + ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( @@ -547,499 +536,366 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormMLP_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, ln_weight, - mu, - rsigma, ln_out, - fc1_out, - gelu_out, fc1_weight, - fc1_weight_fp8, - fc1_weight_main_grad, - fc2_weight, - fc2_weight_fp8, - fc2_weight_main_grad, fc1_bias, - fwd_scale_inverses, - ) = ctx.saved_tensors - - # Gather saved autograd context tensors when running with FSDP - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight, + fc2_bias, mu, rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if 
ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + fc1_weight_main_grad = ( + ctx.fc1_main_grad + if fc1_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc1_weight_requires_grad + else None + ) + fc2_weight_main_grad = ( + ctx.fc2_main_grad + if fc2_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc2_weight_requires_grad + else None ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad) - fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad) - + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. + if ctx.fuse_wgrad_accumulation: fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad - activation_func = _act_func(ctx.activation)[1] - - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not fc1_weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("fc1_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - if ctx.ub_overlap_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_ag = False + # TODO: Fix this # pylint: disable=fixme + # Gather saved autograd context tensors when running with FSDP + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + # _fsdp_gather_tensors( + # ctx.fsdp_group, + # ctx.fsdp_shapes, + # mu, + # rsigma, + # ln_out, + # fc1_out_without_bias if bias_gelu_nvfusion else fc1_out,, + # gelu_out, + # fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + # ) + + # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required + ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_fc2_output_quantizer is not None: + ctx.grad_fc2_output_quantizer.set_usage( + rowwise=True, + columnwise=True, + ) - ub_algo = None + ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - dim_size = list(grad_outputs[0].size()) - dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub("fc2_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, - grad_output_c, - grad_output_t, fc2_bias_grad, - ) = TransformerEngineBaseModule.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not 
fc1_weight.requires_grad: - ctx.ub_bulk_wgrad = False - # Column Parallel Linear - # Overlap input AG with dgrad + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ) + + # Prepare FC1 GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None if ( - fc1_weight.requires_grad - and (not ctx.ub_bulk_dgrad) - and ctx.set_parallel_mode + ctx.fc1_weight_requires_grad + and ctx.tensor_parallel and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + quantizer = None + if ctx.fp8: + quantizer = ctx.fc1_input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + # There are 5 possible fusion paths + # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, + # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize + # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize + # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize + # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + fc2_dgrad_gemm_gelu_fusion = ( + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + ) - fc2_wgrad = None - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FC2 DGRAD; Unconditional - fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_fp8.transpose_2d(), - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, - ) - if ctx.ub_overlap_ag: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - clear_tensor_data(grad_output_c) - - # FC2 WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if fc2_weight.requires_grad: - gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) - clear_tensor_data(gelu_out) - fc2_wgrad, _ = tex.fp8_gemm( - gelu_out_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(gelu_out_t, grad_output_t) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( - fc2_dgrad, - fc1_out, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - else: - dgelu = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - 
fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused( - dgelu, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - clear_tensor_data(fc1_out) - else: - if fc2_weight.requires_grad: - gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts( - gelu_out, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - clear_tensor_data(gelu_out) - fc2_wgrad, _, _ = tex.gemm( - gelu_out_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=False, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out_c) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( - fc2_dgrad, fc1_out, fc1_bias - ) - else: - dgelu_no_fp8 = activation_func( - fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype] - ) - fc1_bias_grad = dgelu_no_fp8.sum(dim=0) - clear_tensor_data(fc1_out) - - dgelu = tex.cast_to_fp8( - dgelu_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - dgelu_t = None + # FC2 DGRAD; Unconditional + gemm_output, *_ = general_gemm( + fc2_weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=( + ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ), # high precision to activation + out_dtype=ctx.activation_dtype, + gelu=fc2_dgrad_gemm_gelu_fusion, + gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_fc2_dgrad, + ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None, + ) + if fc2_dgrad_gemm_gelu_fusion: + dact = gemm_output + fc2_dgrad = None + else: + fc2_dgrad = gemm_output - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - # Get/alloc fc1_dgrad - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device - ) + # FC2 WGRAD + if ctx.fc2_weight_requires_grad: + if isinstance(act_out, QuantizedTensor): + act_out.update_usage(rowwise_usage=True, columnwise_usage=True) - # FP8 RS - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT2 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + if isinstance(grad_output, QuantizedTensor): + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight_fp8.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo 
= tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.fp8_gemm( - fc1_weight_fp8.transpose_2d(), - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - dgelu, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - out_type, - get_workspace(), - out=fc1_dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - else: - # FC2 DGRAD; Unconditional - fc2_dgrad, _, _ = tex.gemm( - fc2_weight, + fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( + act_out, grad_output, - ctx.activation_dtype, get_workspace(), - layout="NN", - gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == "gelu"), + out_dtype=ctx.activation_dtype, + quantization_params=None, # wgrad in high precision + layout="NT", grad=True, - gelu_input=fc1_out, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, + accumulate=accumulate_wgrad_into_param_main_grad, + use_split_accumulator=_2X_ACC_WGRAD, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + if fc2_bias_grad is None: + fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ + clear_tensor_data(act_out) + + # bias computation + fc1_bias_grad = None + fuse_gemm_and_bias_fc1_wgrad = False + if ctx.grad_fc1_output_quantizer is not None: + ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.bias_gelu_fusion: + # Fusion: gemm, bias + gelu + assert ctx.activation == "gelu" + assert not ctx.fp8 + fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) + if ctx.grad_fc1_output_quantizer is not None: + dact = ctx.grad_fc1_output_quantizer(dact) + elif _act_func(ctx.activation)[2] is not None and ctx.fp8: + # Fusion: gemm, bias + gelu + quantize + dbias_dact_quantize_func = _act_func(ctx.activation)[2] + fc1_bias_grad, dact = dbias_dact_quantize_func( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + ) # quantize bgrad gelu fused + else: + # Fusion: gemm + gelu, + if not fc2_dgrad_gemm_gelu_fusion: + activation_func_bwd = _act_func(ctx.activation)[1] + dact = activation_func_bwd( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), None + ) # activation in high precision - # FC2 WGRAD - if fc2_weight.requires_grad: - fc2_wgrad, fc2_bias_grad, _ = tex.gemm( - gelu_out, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_fc2_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out) - - if ctx.bias_gelu_nvfusion and ctx.activation == "gelu": - fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) - else: - if ctx.activation != "gelu": - fc2_dgrad = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - - # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM - # and will not be calculated in case wgrad is not required. 
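# Aside (illustrative sketch, not part of the patch): in the bias_gelu_fusion path
# above, fc1_bias_grad and dact come from the fused bgrad_dgelu_fused kernel. The
# unfused reference below shows the math, assuming the tanh approximation of GELU
# used by the original Megatron-style fusion; the fused kernel may differ in detail.
import torch

def bgrad_dgelu_reference(fc2_dgrad, fc1_out_without_bias, fc1_bias):
    # Recompute the GELU input (FC1 output with the bias re-applied), then
    # backpropagate through tanh-GELU.
    x = fc1_out_without_bias + fc1_bias
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    dgelu_dx = 0.5 * x * (1.0 - tanh_out * tanh_out) * (
        0.79788456 + 0.1070322243 * x * x
    ) + 0.5 * (1.0 + tanh_out)
    dact = dgelu_dx * fc2_dgrad       # gradient w.r.t. the FC1 pre-activation
    fc1_bias_grad = dact.sum(dim=0)   # bias gradient reduces over the token dimension
    return fc1_bias_grad, dact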
- if not fc1_weight.requires_grad: - fc1_bias_grad = fc2_dgrad.sum(dim=0) - - # Overwrite data. Deleting the tensor does not release underlying memory. - clear_tensor_data(fc1_out) - dgelu = fc2_dgrad - - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + if ctx.fp8: + fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + fuse_gemm_and_bias_fc1_wgrad = ( + True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 ) + # it may not be calculated in case wgrad is not required. + if fc1_bias is not None: + if not ctx.fc1_weight_requires_grad and fc1_bias.requires_grad: + fc1_bias_grad = dact.sum(dim=0) + + # Overwrite data. Deleting the tensor does not release underlying memory. + clear_tensor_data(fc1_out, fc1_out_without_bias) + + # Set UB algo and UB obj for fc1_dgrad/wgrad bulk/pipelined overlap + ub_obj_fc1_dgrad = None + ub_obj_fc1_wgrad = None + ub_type_fc1_dgrad = None + fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] + fc1_dgrad_rs_out = None + fc1_dgrad_bulk = None + if ctx.ub_overlap_rs_dgrad: + # Overlap DGRAD+RS + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.RS + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + ) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + else: if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.gemm( - fc1_weight, - dgelu, - ctx.activation_dtype, - get_workspace(), - out=fc1_dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) + # Overlap ln_out all-gather with DGRAD compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. 
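# Aside (illustrative sketch, not part of the patch): the NOTE above says the
# all-gather may run over the leading dimension of each rank's *transposed* chunk,
# which is not the transpose of the gathered tensor until it is re-interleaved.
# A framework-free toy of the fix that _fix_gathered_fp8_transpose performs
# conceptually (shapes and names here are illustrative only):
import torch

tp, M, K = 4, 8, 6
chunks = [torch.randn(M // tp, K) for _ in range(tp)]        # per-rank rowwise shards
gathered = torch.cat(chunks, dim=0)                          # [M, K] rowwise gather

# All-gather applied to the columnwise (transposed) shards instead:
gathered_t_raw = torch.cat([c.t() for c in chunks], dim=0)   # [tp * K, M // tp]

# Re-interleave so the result equals the transpose of the rowwise gather:
fixed = torch.cat(gathered_t_raw.chunk(tp, dim=0), dim=1)    # [K, M]
assert torch.equal(fixed, gathered.t())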
+ ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.AG + ub_obj_fc1_dgrad.copy_into_buffer( + ln_out, ctx.fc1_input_quantizer, local_chunk=True + ) + + if ctx.ub_bulk_wgrad: + # Overlap FC1 DGRAD reduce-scatter with WGRAD compute + ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) + + # FC1 DGRAD: Unconditional + fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( + fc1_weight, + dact, + get_workspace(), + out=fc1_dgrad_bulk, + out_dtype=ctx.activation_dtype, + layout="NN", + grad=True, + ub=ub_obj_fc1_dgrad, + ub_type=ub_type_fc1_dgrad, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad - if ctx.set_parallel_mode and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + fc1_dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + fc1_dgrad = fc1_dgrad_rs_out + elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) - fc1_dgrad, handle = reduce_scatter_along_first_dim( - fc1_dgrad, ctx.tp_group, async_op=True + fc1_dgrad, fc1_dgrad_work = reduce_scatter_along_first_dim( + fc1_dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.set_parallel_mode and ctx.tensor_parallel: - fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel: + fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + # FC1 WGRAD fc1_wgrad = None - if fc1_weight.requires_grad: - if ctx.fp8: - # FC1 WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device - ) - fc1_dgrad = extra_output_tensor - else: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - fc1_wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dgelu_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, dgelu_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - fc1_wgrad, _, _ = tex.gemm( - ln_out_total_c, - dgelu_no_fp8, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - 
extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c, dgelu_no_fp8) + if ctx.fc1_weight_requires_grad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer) + if ctx.fp8: + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + ln_out_total._create_transpose() + else: - # FC1 WGRAD - fc1_wgrad_outputs = tex.gemm( - ln_out_total, - dgelu, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=not ctx.bias_gelu_nvfusion, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + # Make sure GEMM inputs have expected data + if isinstance(ln_out_total, QuantizedTensor): + ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True) + if isinstance(dact, QuantizedTensor): + dact.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" ) - clear_tensor_data(ln_out_total, dgelu) - if ctx.bias_gelu_nvfusion: - fc1_wgrad, _, _ = fc1_wgrad_outputs - else: - fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs - if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + fc1_wgrad_outputs = general_gemm( + ln_out_total, + dact, + get_workspace(), + out_dtype=ctx.activation_dtype, + layout="NT", + grad=fuse_gemm_and_bias_fc1_wgrad, + bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_fc1_wgrad, + ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.set_parallel_mode - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + clear_tensor_data(ln_out_total, dact) - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = fc1_dgrad.view(inputmat.shape) + if fuse_gemm_and_bias_fc1_wgrad: + fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs + else: + fc1_wgrad, *_ = fc1_wgrad_outputs + + if ctx.ub_bulk_wgrad: + if ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad = fc1_dgrad_rs_out + else: + fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True) + + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if fc1_dgrad_work is not None: + fc1_dgrad_work.wait() + fc1_dgrad_work = None # Residual gradient + dgrad = fc1_dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None if ctx.normalization == "LayerNorm": @@ -1062,10 +918,9 @@ 
def backward( ctx.zero_centered_gamma, ) dbeta = None - clear_tensor_data(mu) - clear_tensor_data(rsigma) + clear_tensor_data(mu, rsigma) - if fc1_weight.requires_grad: + if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): fc1_weight.grad_added_to_main_grad = True @@ -1077,18 +932,13 @@ def backward( requires_grad=False, ) else: - fc1_wgrad = torch.empty( - fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc1_wgrad = None elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: fc1_wgrad = None - if fc2_weight.requires_grad: + if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): fc2_weight.grad_added_to_main_grad = True @@ -1100,12 +950,7 @@ def backward( requires_grad=False, ) else: - fc2_wgrad = torch.empty( - fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc2_wgrad = None elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1114,34 +959,37 @@ def backward( if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + # FIX THIS # Scatter Fp8 tranposed-weight buffers - if ctx.fp8: - _fsdp_scatter_tensors( - ctx.fsdp_group, - fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, - ) - + # if ctx.fp8: + # _fsdp_scatter_tensors( + # ctx.fsdp_group, + # fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, + # ) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, fc1_wgrad, - None, # fc1_weight_fp8 - # Due to bias gelu nvfusion available in the bf16 case, fc1_bias_grad is calculated at - # different paths and this confused the linter. 
- fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment + fc1_bias_grad if ctx.use_fc1_bias else None, None, # use_fc1_bias - fc2_wgrad, - None, # fc2_weight_fp8 + fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_bias_grad if ctx.use_fc2_bias else None, None, # use_fc2_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # fc1_input_quantizer + None, # fc1_weight_quantizer + None, # fc2_input_quantizer + None, # fc2_weight_quantizer + None, # output_quantizer + None, # grad_fc2_output_quantizer + None, # grad_fc1_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -1150,7 +998,7 @@ def backward( None, # activation_dtype None, # return_layernorm_output None, # return_layernorm_output_gathered - None, # bias_gelu_nvfusion + None, # bias_gelu_fusion None, # set_parallel_mode None, # is_grad_enabled None, # fwd_ln_sm_margin @@ -1158,13 +1006,15 @@ def backward( None, # zero_centered_gamma None, # activation None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad - None, # ub_overlap_rs_dgrad - None, # ub_overlap_rs None, # ub_overlap_ag + None, # ub_overlap_rs + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # gemm_gelu_fusion None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -1285,11 +1135,11 @@ def __init__( set_parallel_mode: bool = False, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ) -> None: super().__init__() @@ -1308,11 +1158,7 @@ def __init__( ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag + # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) @@ -1337,6 +1183,16 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.size_per_partition = divide(ffn_hidden_size, self.tp_size) + self.ub_overlap_ag = ub_overlap_ag and self.sequence_parallel + self.ub_overlap_rs = ub_overlap_rs and self.sequence_parallel + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad and self.sequence_parallel + self.ub_bulk_wgrad = ( + ub_bulk_wgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1357,7 +1213,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["reglu", "geglu", "swiglu"]: + if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1491,61 +1347,30 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + with 
self.prepare_forward(inp, num_gemms=2) as inp: + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers() # Get weight tensors fc1_weight = self.fc1_weight - fc1_bias = self.fc1_bias + fc1_bias = self.fc1_bias if self.use_bias else None fc2_weight = self.fc2_weight - fc2_bias = self.fc2_bias + fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.from_float8() if isinstance(fc2_weight, Float8Tensor): fc2_weight = fc2_weight.from_float8() - # Cast weights to FP8 if needed - fc1_weight_fp8 = None - fc2_weight_fp8 = None - if self.fp8: - update_workspace = is_first_microbatch is None or is_first_microbatch - if isinstance(fc1_weight, Float8Tensor): - if fc1_weight._transpose is not None: - fc1_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc1_weight" - fc1_weight_fp8 = self.get_fp8_workspace( - tensor=fc1_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - if isinstance(fc2_weight, Float8Tensor): - if fc2_weight._transpose is not None: - fc2_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc2_weight" - fc2_weight_fp8 = self.get_fp8_workspace( - tensor=fc2_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): self.bias_gelu_nvfusion = False @@ -1561,19 +1386,24 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, fc1_weight, - fc1_weight_fp8, fc1_bias, self.use_bias, fc2_weight, - fc2_weight_fp8, fc2_bias, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1582,7 +1412,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion, + self.bias_gelu_nvfusion and not self.fp8, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1590,13 +1420,15 @@ def forward( self.zero_centered_gamma, self.activation, self.normalization, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_rs, self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.gemm_gelu_fusion, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1613,3 +1445,48 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self): + ( + 
fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = [None] * 8 + if self.fp8: + fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_quantizer.internal = False # temporary + fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer.internal = True + fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_quantizer.set_usage( + rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + ) + fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer.internal = True + if torch.is_grad_enabled(): + grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ] + grad_fc2_output_quantizer.internal = True + grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ] + grad_fc1_output_quantizer.internal = True + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] + grad_input_quantizer.internal = True + + return ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5893c4ea3c..460ce87bc6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,9 +3,9 @@ # See LICENSE for license information. """Linear API""" +from typing import Callable, Dict, Optional, Tuple, Union from functools import reduce from operator import mul as multiply_op -from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -19,15 +19,15 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import _noop_cat -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ._common import noop_cat, _fix_gathered_fp8_transpose +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, requires_grad, + non_tn_fp8_gemm_supported, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -35,23 +35,25 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, + is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) from ..cpp_extensions import ( - fp8_gemm, - gemm, - fp8_cast_transpose_fused, - cast_to_fp8, + general_gemm, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param __all__ = ["Linear"] @@ -64,15 +66,17 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: Union[Float8Tensor, torch.Tensor], - weight_fp8: Optional[Float8Tensor], + weight: 
torch.Tensor, inp: torch.Tensor, - bias: torch.Tensor, - use_bias: bool, + bias: Optional[torch.Tensor], is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -89,293 +93,186 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, - fp8_output: bool, + fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - is_input_fp8 = isinstance(inp, Float8Tensor) # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" - inputmat = inp.view(-1, in_features) - if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop - ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop - - # Cast input to expected dtype - inputmat = cast_if_needed(inputmat, activation_dtype) - inputmat_t = None - inputmat_no_fp8 = inputmat - inputmat_scale_inv = None - - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if isinstance(inputmat, Float8Tensor): - inputmat_scale_inv = inputmat._scale_inv - else: - inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - else: - # FP8 input for forward - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. 
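# Aside (illustrative sketch, not part of the patch): the input handling above
# replaces the explicit cast_to_fp8 / fp8_cast_transpose_fused calls with a
# quantizer object that is told which byte layouts the GEMMs will need. A toy
# model of that contract (ToyQuantizer is invented for illustration; TE's
# Quantizer classes also carry scales, amax history, FP8 dtype, etc., and this
# snippet needs a PyTorch build with float8 dtypes to run):
import torch

class ToyQuantizer:
    def __init__(self):
        self.rowwise = True
        self.columnwise = False

    def set_usage(self, rowwise=None, columnwise=None):
        # Declare which layouts downstream GEMMs will read.
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise

    def __call__(self, x):
        # Stand-in for the FP8 cast: produce only the requested layouts.
        data = x.to(torch.float8_e4m3fn) if self.rowwise else None
        data_t = x.t().contiguous().to(torch.float8_e4m3fn) if self.columnwise else None
        return data, data_t

q = ToyQuantizer()
q.set_usage(rowwise=True, columnwise=True)   # e.g. when the backward needs the transpose
data, data_t = q(torch.randn(8, 16))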
- if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) - - # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop: - inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) - else: - inputmat_total = inputmat - + backward_needs_input = is_grad_enabled and weight.requires_grad + + # Prepare input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + inputmat = inp + inputmat_total = None + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + own_quantized_input = False if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - if fp8_output: - out_index, meta_tensor, out_tedtype, out_pttype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) - else: - out_index, meta_tensor, out_tedtype, out_pttype = ( - None, - None, - None, - activation_dtype, + if ( + any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" ) - ub_obj = None - ub_algo = None - rs_out = None - inputmat_data = ( - inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total - ) - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") - out = ub_obj.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj.is_p2p_overlap(): - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj.is_fp8_ubuf(): - out_index = tex.FP8FwdTensors.GEMM1_OUTPUT - meta_tensor = fp8_meta["scaling_fwd"] - out_tedtype = fp8_dtype_forward - out_pttype = torch.uint8 - ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) - - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") - assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." 
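# Aside (rough correspondence inferred from this diff, not an official table): the
# removed code above picked a specific CommOverlapAlgo enum per case, while the new
# general_gemm calls pass only a CommOverlapType plus a bulk_overlap flag, so the
# pipelined/atomic and P2P/ring variants are no longer chosen at the call site.
OLD_ALGO_TO_NEW_ARGS = {
    "BULK_OVERLAP_AG":        ("CommOverlapType.AG", {"bulk_overlap": True}),
    "BULK_OVERLAP_RS":        ("CommOverlapType.RS", {"bulk_overlap": True}),
    "SPLIT_PIPELINED_AG_P2P": ("CommOverlapType.AG", {}),
    "ATOMIC_GEMM_AG_P2P":     ("CommOverlapType.AG", {}),
    "SPLIT_PIPELINED_RS":     ("CommOverlapType.RS", {}),
    "SPLIT_PIPELINED_RS_P2P": ("CommOverlapType.RS", {}),
    "ATOMIC_GEMM_RS":         ("CommOverlapType.RS", {}),
    "ATOMIC_GEMM_RS_P2P":     ("CommOverlapType.RS", {}),
}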
- ub_obj.copy_input_to_ubuf(inputmat_data, True) - ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) - if ub_obj.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - out_tedtype = TE_DType[activation_dtype] - out_pttype = activation_dtype - dim_size = list(inputmat_total.size()) - dim_size[0] *= tp_size - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) - + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_input_all_gather_nccl: + assert not isinstance( + inputmat, QuantizedTensor + ), "All gather of fp8 input is not supported" + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, + ) else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) - - _ = fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - inputmat_data, - inputmat_scale_inv, - 0, - fp8_dtype_forward, - out_pttype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=out, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_tedtype, - ) - if fp8_output: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, ) + if not isinstance(inputmat, QuantizedTensor): + inputmat = input_quantizer(inputmat) + elif backward_needs_input: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + inputmat_total = inputmat else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = inputmat_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - ub_obj = None - ub_algo = None - rs_out = None - if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") - out = ub_obj.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - - elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") - ub_obj.copy_input_to_ubuf(inputmat_total, True) - dim_size = list(inputmat_total.size()) - dim_size[0] *= tp_size # all-gathered sequence length - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - + inputmat = cast_if_needed(inp, activation_dtype) + if with_input_all_gather_nccl: + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: - dim_size = list(inputmat_total.size()) - dim_size[1] = 
out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + inputmat_total = inputmat - _ = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - out=out, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out, - ) + # Cast weight to expected dtype + weightmat = weight + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) + else: + if not isinstance(weight, QuantizedTensor): + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) + + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(inputmat_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + out_dtype = activation_dtype + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
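# Aside (illustrative sketch, not part of the patch): a toy picture of the
# AG-overlap staging done just below with copy_into_buffer / get_buffer. Each rank
# writes its local sequence chunk into the shared "fprop" buffer, and the GEMM is
# launched on the gathered view while the all-gather of other chunks is still in
# flight (the real buffer lives in Userbuffers device memory; names are invented).
import torch

tp, chunk_tokens, hidden = 4, 8, 32
rank = 1
ub_buffer = torch.zeros(tp * chunk_tokens, hidden)   # communication buffer
local_chunk = torch.randn(chunk_tokens, hidden)

# copy_into_buffer(..., local_chunk=True): place this rank's chunk at its offset.
ub_buffer[rank * chunk_tokens:(rank + 1) * chunk_tokens] = local_chunk

# get_buffer(...): the GEMM consumes the full gathered buffer as its input.
inputmat_total = ub_buffer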
+ ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) + inputmat_total = ub_obj.get_buffer(input_quantizer) + + out, *_, rs_out = general_gemm( + weightmat, + inputmat_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=out_dtype, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) if is_grad_enabled: saved_inputmat = None - saved_inputmat_t = None - if weight.requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t is None: - saved_inputmat = inputmat - else: - saved_inputmat_t = inputmat_t - if cpu_offloading: - saved_inputmat_t.activation_offloading = True - else: - saved_inputmat = inputmat_no_fp8 + if backward_needs_input: + if own_quantized_input and isinstance(inputmat, QuantizedTensor): + inputmat.update_usage(rowwise_usage=False) + saved_inputmat = inputmat - if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - weight.weight_offloading = True - - if saved_inputmat is not None: - saved_inputmat.activation_offloading = True + if cpu_offloading: + set_offloading_param(weight, "weight_offloading", True) + set_offloading_param(weightmat, "weight_offloading", True) + if saved_inputmat is not None: + set_offloading_param(saved_inputmat, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, - saved_inputmat, # None if fp8 == False - saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + saved_inputmat, + weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, ) - ctx.save_for_backward( + # TODO(ksivamani): Check memory usage + tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, - saved_inputmat_t, - inputmat_scale_inv, + weightmat, weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, + bias, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta + ctx.input_quantizer = input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias + ctx.use_bias = bias is not None ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape @@ -388,8 +285,10 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad - ctx.is_input_fp8 = is_input_fp8 + ctx.requires_wgrad = weight.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False + ctx.owns_input = saved_inputmat is not inp + ctx.is_input_fp8 = not own_quantized_input if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -397,34 +296,53 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if parallel_mode == 
"row": - if ub_overlap_rs_fprop: - out = rs_out - elif sequence_parallel: + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) - # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp_shape[1:-1], out_features) + out = out.view(-1, *inp_shape[1:-1], out_features) + return out @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_output, Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ] = grad_output._scale_inv with torch.cuda.nvtx.range("_Linear_backward"): - ( - inputmat, - inputmat_t, - inputmat_scale_inv, - weight, - weight_fp8, - main_grad, - ) = ctx.saved_tensors + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + saved_tensors = ctx.saved_tensors + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, weight.requires_grad) + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -433,105 +351,89 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.fsdp_group, ctx.fsdp_shapes, inputmat, - inputmat_t, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight_fp8, ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) - weight.main_grad = main_grad - - tp_world_size = get_distributed_world_size(ctx.tp_group) - ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad - ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad - ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad - ctx.ub_obj_gradout = None + ub_obj_dgrad = None ub_obj_wgrad = None - ub_algo_wgrad = None - ub_algo_dgrad = None - rs_out = None - dgrad = None + ub_type_dgrad = None + ub_type_wgrad = None dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - dgrad = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device - ) + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute ctx.ub_obj_gradout = 
get_ub(ctx.ub_name + "_dgrad") - dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) - if ctx.ub_obj_gradout.is_p2p_overlap(): - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS rs_out = torch.empty( dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device ) - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - inputmat_data = ( - inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat - ) - ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) - inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) - if isinstance(inputmat, Float8Tensor): - inputmat._data = inputmat_ubuf - else: - inputmat = inputmat_ubuf + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True) if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") - dgrad = ub_obj_wgrad.get_ubuf_output(1) - + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" + ctx, + grad_output, + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) - # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) + # Prepare input tensor + # Note: Perform tensor-parallel communication if needed inputmat_total = None - inputmat_t_total = None - inputmat_gather_handle = None + inputmat_total_work = None if ( - weight.requires_grad + ctx.requires_wgrad and ctx.parallel_mode == "column" and ctx.sequence_parallel and not ctx.ub_bulk_dgrad ): - inputmat_total, inputmat_gather_handle = gather_along_first_dim( - inputmat, ctx.tp_group, async_op=ctx.requires_dgrad + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + inputmat_total, inputmat_total_work = gather_along_first_dim( + inputmat, + ctx.tp_group, + async_op=True, + quantizer=quantizer, ) else: inputmat_total = inputmat - inputmat_t_total = inputmat_t + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -539,185 
+441,132 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - output_dtype = ctx.activation_dtype + # Compute grad input tensor + dgrad = None + dgrad_work = None if ctx.requires_dgrad: - if ctx.fp8: - if ctx.is_input_fp8 or ( - ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf() - ): - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8BwdTensors.GRAD_INPUT1, - ctx.fp8_meta["scaling_bwd"], - fp8_dtype_backward, - torch.uint8, + + # Update quantizer + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # dgrad GEMM + dgrad, *_, rs_out = general_gemm( + weight_fp8, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + + # Launch tensor-parallel communication + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): - ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - ctx.activation_dtype, - ) + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - if dgrad is None: - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - dgrad_shape[0] = dgrad_shape[0] * tp_world_size - dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) - - if ctx.requires_dgrad: - if ctx.fp8: - _ = fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - output_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo_dgrad, - ub=ctx.ub_obj_gradout, - out=dgrad, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - extra_output_tensor=rs_out, - ) + # Compute grad weight tensor + wgrad = None + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + if inputmat._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + inputmat_total = _fix_gathered_fp8_transpose( + inputmat_total, ctx.tp_size + ) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. 
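# (The call below materializes the columnwise, i.e. transposed, copy of the gathered FP8 input that the wgrad GEMM further down consumes.)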
+ inputmat_total._create_transpose() - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out - elif output_dtype == torch.uint8: - dgrad = Float8Tensor( - data=dgrad, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1, - fp8_dtype=fp8_dtype_backward, - dtype=ctx.activation_dtype, - ) else: - _ = gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NN", - grad=True, - ub_algo=ub_algo_dgrad, - ub=ctx.ub_obj_gradout, - out=dgrad, - extra_output_tensor=rs_out, + if inputmat_total_work is not None: + # Synchronize tensor-parallel communication + inputmat_total_work.wait() + inputmat_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device ) - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out - - if inputmat_gather_handle is not None: - inputmat_gather_handle.wait() - - # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) - dgrad_reduce_handle = None - if ctx.requires_dgrad and ctx.parallel_mode == "column": - if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): - dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True - ) - elif ctx.tensor_parallel and not ctx.sequence_parallel: - dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + wgrad, grad_bias_, _, rs_out = general_gemm( + inputmat_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) - wgrad = None - if weight.requires_grad: - if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_overlap_ag: - if isinstance(grad_output_c, Float8Tensor): - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - if inputmat_t_total is None: - if isinstance(inputmat_total, Float8Tensor): - inputmat_t_total = inputmat_total.transpose_2d() - else: - inputmat_t_total = tex.fp8_transpose( - inputmat_total, fp8_dtype_backward - ) - wgrad, _ = fp8_gemm( - ( - inputmat_t_total._data - if isinstance(inputmat_t_total, Float8Tensor) - else inputmat_t_total - ), - inputmat_scale_inv, - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - 
accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) - else: - # WGRAD - wgrad, grad_bias, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub=ub_obj_wgrad, - ub_algo=ub_algo_wgrad, - ) + dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) - if ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_ubuf_output(0) + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ # Deallocate input tensor - clear_tensor_data(inputmat_total) - clear_tensor_data(inputmat_t_total) - - # Wait for dgrad reduce-scatter or all-reduce - if dgrad_reduce_handle is not None: - dgrad_reduce_handle.wait() + if ctx.owns_input: + clear_tensor_data(inputmat_total) + # Don't return grad bias if not needed if not ctx.use_bias: grad_bias = None - if weight.requires_grad: + # Synchronize tensor parallel communication + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None + + if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): + if ( + ctx.fuse_wgrad_accumulation + and weight is not None + and hasattr(weight, "grad_added_to_main_grad") + ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): wgrad = torch.zeros( @@ -727,12 +576,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], requires_grad=False, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + wgrad = None elif ctx.fuse_wgrad_accumulation: wgrad = None else: @@ -742,19 +586,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): + if ctx.fp8 and not isinstance(weight, QuantizedTensor): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) - return ( wgrad, - None, # weight_fp8 dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, - None, # use_bias None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -773,6 +618,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_name None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -865,6 +712,7 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, @@ -878,6 +726,8 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name if device == "meta": assert parameters_split is None, "Cannot split 
module parameters on 'meta' device." @@ -903,17 +753,32 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel # Column parallel TP overlap options - self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag - self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs - self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad - self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad - if self.ub_overlap_rs_dgrad: - self.ub_bulk_dgrad = False - self.ub_bulk_wgrad = False + self.ub_overlap_ag_fprop = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag + ) + self.ub_overlap_rs_dgrad = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_dgrad + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_wgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_wgrad + and not self.ub_overlap_rs_dgrad + ) # Row parallel TP overlap options - self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs - self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_fprop = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs + ) + self.ub_overlap_ag_dgrad = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag + ) if any( [ @@ -928,19 +793,6 @@ def __init__( assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." self.ub_name = ub_name - assert not ( - self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop - ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." - assert not ( - self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad - ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." - assert not ( - self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) - ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." - - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name - # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1017,7 +869,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -1084,6 +938,7 @@ def forward( inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. 
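For orientation, a minimal usage sketch of the updated forward signature, assuming a plain te.Linear module, te.fp8_autocast, and illustrative shapes: the new fp8_grad flag mirrors fp8_output, and the two flags only determine whether _get_quantizers returns an output quantizer and a grad-input quantizer, so the corresponding GEMMs can emit quantized tensors instead of tensors in the activation dtype.

import torch
import transformer_engine.pytorch as te

# Illustrative module and input (hypothetical sizes).
linear = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16)
inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True):
    # fp8_output selects an output quantizer for the fprop GEMM;
    # fp8_grad selects a grad-input quantizer for the dgrad GEMM.
    # With both False (the defaults), results stay in the activation dtype.
    out = linear(inp, fp8_output=False, fp8_grad=False)

out.sum().backward()

Both flags default to False, so existing call sites keep their current behavior.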
@@ -1115,7 +970,6 @@ def forward( with self.prepare_forward( inp, - is_first_microbatch, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: @@ -1129,36 +983,25 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused - - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=self.fsdp_group, - ) + bias_tensor = None + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output, fp8_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor): + weight_tensor._quantizer = weight_quantizer if torch.is_grad_enabled(): linear_fn = _Linear.apply @@ -1168,14 +1011,16 @@ def forward( args = [None] args += ( weight_tensor, - weight_fp8, inp, - bias_tensor, - self.apply_bias and not self.gemm_bias_unfused_add, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1194,12 +1039,38 @@ def forward( self.ub_name, fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = linear_fn(*args) - if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) if self.return_bias: return out, cast_if_needed(bias_tensor, self.activation_dtype) return out + + def _get_quantizers(self, fp8_output, fp8_grad): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + input_quantizer, + weight_quantizer, + 
output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 26bceab737..bb826e552e 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -11,11 +11,11 @@ from transformer_engine_torch import FP8TensorMeta from ..fp8 import FP8GlobalStateManager -from ..tensor import Float8Tensor +from ..tensor.float8_tensor import Float8Tensor from ..utils import ( - canonicalize_device, # pylint: disable=unused-import - canonicalize_dtype, # pylint: disable=unused-import - devices_match, # pylint: disable=unused-import + canonicalize_device, + canonicalize_dtype, + devices_match, ) @@ -61,12 +61,9 @@ def convert_tensor( # Note: torch.Tensor.to ignores memory_format kwarg (see # https://github.com/pytorch/pytorch/issues/132020). data = data.contiguous(memory_format=memory_format) - return Float8Tensor.make_like( - tensor, - data=data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - ) + out = Float8Tensor.make_like(tensor, dtype=dtype) + out.data = data + return out # Convert standard PyTorch tensor tensor = tensor.to(device=device, dtype=dtype) @@ -85,46 +82,14 @@ def reshape( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor | Float8Tensor: - """Reshape tensor, keeping same data if possible - - If the input is a Float8Tensor, this function attempts to preserve - the cached transpose if available and valid. If a cached transpose - is present, it is interpreted as the transpose of a 2D matrix - where the width matches the innermost tensor dimension. - - """ - - # Make sure tensor is in expected format + """Reshape tensor, keeping same data if possible""" tensor = convert_tensor( tensor, device=device, dtype=dtype, memory_format=torch.contiguous_format, ) - - # Return immediately if tensor already has desired shape - shape = list(shape) - if len(shape) == tensor.dim(): - if sum(1 for d in shape if d == -1) > 1: - raise ValueError( - "Attempted to reshape tensor with " - f"shape={tuple(tensor.size())} into shape={tuple(shape)}" - ) - if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1): - return tensor - - # Reshape FP8 tensor - # Note: Preserve cached transpose if possible - if is_float8_tensor(tensor): - out = Float8Tensor.make_like( - tensor, - data=tensor._data.view(shape), - fp8_attrs=tensor._fp8_attrs, - ) - return out - - # Reshape standard PyTorch tensor - return tensor.view(shape) + return tensor.reshape(*shape) def maybe_autocast_dtype( diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 7ad6e70929..45c78bea87 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -10,20 +10,12 @@ import torch -import transformer_engine_torch -from ...constants import TE_DType -from ...cpp_extensions import ( - geglu as tex_geglu, - gelu as tex_gelu, - reglu as tex_reglu, - relu as tex_relu, - swiglu as tex_swiglu, - fp8_dswiglu_cast_transpose_fused, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +import transformer_engine_torch as tex +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor from ...utils import clear_tensor_data, devices_match from ..op import BasicOperation, OperationContext +from .._common import reshape class 
_ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): @@ -93,43 +85,23 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - with_fp8_output = False - output_fp8_meta = None - output_dtype = TE_DType[dtype] - output_fp8_scale_inv = None - if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: - with_fp8_output = True - fp8_meta = next_op.get_fp8_meta("input") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - output_fp8_meta = fp8_meta[fp8_meta_key] - output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0: + quantizer = next_op.get_quantizer("forward", 0) + else: + quantizer = None # Launch kernel y = self._activation_forward_impl( - x, - output_fp8_meta, - 0, - output_dtype, - scale_inv=output_fp8_scale_inv, + reshape(x, (-1, x.size(-1))), + quantizer, ) # Check output tensor if y.dim() != x.dim(): y = y.reshape(list(x.shape[:-1]) + [-1]) - if with_fp8_output: - y = Float8Tensor( - data=y, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=output_dtype, - fp8_scale_inv=output_fp8_scale_inv, - dtype=dtype, - ) # Save state for backward pass - ctx.save_for_backward(x) + ctx.save_for_backward(x.detach()) ctx.fp8_enabled = fp8_enabled ctx.prev_op = prev_op @@ -154,7 +126,11 @@ def op_backward( dy = dy.contiguous() # Launch kernel - dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + dx = self._activation_backward_impl( + reshape(dy, (-1, dy.size(-1))), + reshape(x, (-1, x.size(-1))), + None, + ) # Check grad input tensor if dx.size() != x.size(): @@ -181,10 +157,10 @@ class GELU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_gelu(*args, **kwargs) + return tex.gelu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dgelu(*args, **kwargs) + return tex.dgelu(*args, **kwargs) class ReLU(_ActivationOperation): @@ -197,10 +173,10 @@ class ReLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_relu(*args, **kwargs) + return tex.relu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.drelu(*args, **kwargs) + return tex.drelu(*args, **kwargs) class GEGLU(_ActivationOperation): @@ -232,10 +208,10 @@ class GEGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_geglu(*args, **kwargs) + return tex.geglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dgeglu(*args, **kwargs) + return tex.dgeglu(*args, **kwargs) class ReGLU(_ActivationOperation): @@ -261,10 +237,10 @@ class ReGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_reglu(*args, **kwargs) + return tex.reglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dreglu(*args, **kwargs) + return tex.dreglu(*args, **kwargs) class SwiGLU(_ActivationOperation): @@ -299,92 +275,7 @@ class SwiGLU(_ActivationOperation): """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex_swiglu(*args, 
**kwargs) + return tex.swiglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return transformer_engine_torch.dswiglu(*args, **kwargs) - - def op_backward( - self, - ctx: OperationContext, - grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, tuple[()]]: - - # Saved tensors from forward pass - (x,) = ctx.saved_tensors - - # Tensor attributes - dtype = x.dtype - device = x.device - - # Check grad output tensor - dy = grad_output - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - if not devices_match(dy.device, device) or dy.dtype != dtype: - dy = dy.to(device=device, dtype=dtype) - if not dy.is_contiguous(): - dy = dy.contiguous() - - # Check if FP8 is enabled - with_fp8_grad_input = False - grad_input_fp8_meta = None - grad_input_dtype = TE_DType[dtype] - grad_input_fp8_scale_inv = None - if ( - ctx.fp8_enabled - and ctx.prev_op is not None - and ctx.prev_op.num_fp8_scales("grad_output") > 0 - ): - with_fp8_grad_input = True - fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - grad_input_fp8_meta = fp8_meta[fp8_meta_key] - grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) - - # Launch kernel - if with_fp8_grad_input: - # Fused with FP8 cast-transpose - input_dims = x.size() - flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] - flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] - dx = torch.empty(input_dims, dtype=torch.uint8, device=device) - dx_t = torch.empty( - (flat_input_dims[1], flat_input_dims[0]), - dtype=torch.uint8, - device=device, - ) - fp8_dswiglu_cast_transpose_fused( - dy.reshape(flat_output_dims), - x.reshape(flat_input_dims), - grad_input=dx.reshape(flat_input_dims), - grad_input_transpose=dx_t, - otype=grad_input_dtype, - fp8_meta=grad_input_fp8_meta, - fp8_meta_index=0, - scale_inv=grad_input_fp8_scale_inv, - ) - dx = Float8Tensor( - data=dx, - fp8_meta=grad_input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=grad_input_dtype, - fp8_scale_inv=grad_input_fp8_scale_inv, - dtype=dtype, - ) - dx._transpose = dx_t - dx._transpose_invalid = False - else: - # Standard impl - dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) - if dx.size() != x.size(): - dx = dx.reshape(x.size()) - - # Note: This fails if op is preceeded by an identity op like Quantize(forward=False) - # # Clear input tensor if possible - # if ctx.prev_op is not None: - # clear_tensor_data(x) - - return dx, () + return tex.dswiglu(*args, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index 2dd1d1b75e..15b1f65d85 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllGather(BasicOperation): @@ -45,47 +42,12 @@ def op_forward( prev_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: - - # Trivial case + out: torch.Tensor if self.process_group_size 
== 1: - return input_ - - # Tensor dimensions - input_dims = input_.size() - if not input_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(input_dims)} " - f"over {self.process_group_size} processes" - ) - output_dims = list(input_dims) - output_dims[0] *= self.process_group_size - - # Perform all-gather - x = convert_tensor(input_, memory_format=torch.contiguous_format) - y = None - if is_float8_tensor(x): - y = Float8Tensor.make_like( - x, - data=torch.empty( - output_dims, - dtype=torch.uint8, - device=x.device, - ), - ) - torch.distributed.all_gather_into_tensor( - y._data, - x._data, - group=self.process_group, - ) + out = input_.detach() else: - y = torch.empty(output_dims, dtype=x.dtype, device=x.device) - torch.distributed.all_gather_into_tensor( - y, - x, - group=self.process_group, - ) - return y + out, _ = gather_along_first_dim(input_, self.process_group) + return out def op_backward( self, @@ -110,8 +72,8 @@ def op_backward( # Check output gradient tensor dy = grad_output - if is_float8_tensor(dy): - dy = dy.from_float8() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dy = dy.contiguous() # Perform reduce-scatter diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index c5178d2d91..892e120da1 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,33 +12,24 @@ import torch -from transformer_engine.pytorch.cpp_extensions import ( - FP8TensorMeta, - fp8_gemm, - gemm, -) -from transformer_engine.pytorch.distributed import ( +from transformer_engine.pytorch.module.base import get_workspace +from ...cpp_extensions import general_gemm +from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - get_fp8_te_dtype, -) -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +from ...fp8 import FP8GlobalStateManager +from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...tensor import Quantizer, QuantizedTensor +from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ..op import BasicOperation, OperationContext from .._common import ( canonicalize_device, canonicalize_dtype, - convert_tensor, devices_match, - is_float8_tensor, - reshape, ) from ...utils import clear_tensor_data @@ -110,17 +101,8 @@ def __init__( self.in_features: int = in_features self.out_features: int = out_features - # Weight tensor device - defer_param_init = False + # Weight tensor attributes device = canonicalize_device(device) - if device.type == "meta": - defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device - - # Weight tensor datatype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") @@ -147,16 +129,14 @@ def __init__( out_features=out_features, ) - # Whether weight tensor is natively in FP8 - self._with_fp8_parameters: 
bool = FP8GlobalStateManager.with_fp8_parameters() - if self._with_fp8_parameters: - self._fp8_metas = self._make_fp8_metas() + # Whether weight tensor is natively quantized + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() # Initialize parameters if needed weight = torch.empty( self.local_out_features, self.local_in_features, - device="meta", + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -164,7 +144,7 @@ def __init__( self.register_parameter("weight", weight) self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] self._rng_state_tracker_function = rng_state_tracker_function - if not defer_param_init: + if weight.device.type != "meta": self.reset_parameters() # Whether to accumulate weight gradient into main_grad @@ -273,43 +253,48 @@ def _canonicalize_tensor_parallelism( local_out_features, ) - def num_fp8_scales(self, mode: str) -> int: - if mode in ("input", "param", "grad_output"): + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 + if mode == "backward": return 1 return 0 def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda" or is_float8_tensor(weight): - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Allocate buffer if needed + if isinstance(weight, QuantizedTensor): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values - init_context = contextlib.nullcontext + init_context = contextlib.nullcontext() if self._rng_state_tracker_function is not None: - init_context = self._rng_state_tracker_function().fork - with init_context(): + init_context = self._rng_state_tracker_function().fork() + with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - # Cast to FP8 if needed - if self._with_fp8_parameters: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=self.device, - ) # Dummy buffer to avoid overwriting amax history - weight = Float8Tensor.to_float8( - weight, - fp8_meta=self.get_fp8_meta("param"), - fp8_meta_forward=True, - fp8_meta_index=0, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), + # Quantize if needed + if self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 1) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), ) + with torch.no_grad(): + weight = quantizer(weight) # Save updated parameter if not isinstance(weight, torch.nn.Parameter): @@ -318,8 +303,33 @@ def reset_parameters(self) -> None: def pre_forward(self, *args, **kwargs) -> None: super().pre_forward(*args, **kwargs) - if self.weight.device.type == "meta": + + # Initialize weights if needed + weight = self.weight + if weight.device.type == "meta": self.reset_parameters() + weight = self.weight + + # Configure quantizers + if FP8GlobalStateManager.is_fp8_enabled(): + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + + # Specify required tensor formats + is_grad_enabled = torch.is_grad_enabled() + weight_requires_grad = is_grad_enabled and 
weight.requires_grad + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if isinstance(weight_quantizer, Float8Quantizer) and isinstance( + weight, Float8TensorBase + ): + weight._quantizer = weight_quantizer @staticmethod def _functional_forward( @@ -327,17 +337,17 @@ def _functional_forward( weight: torch.Tensor, *, bias: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, out: Optional[torch.Tensor] = None, accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - output_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + output_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Functional API for forward pass @@ -366,16 +376,14 @@ def _functional_forward( parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. - weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - output_fp8_meta: dict, optional - FP8 metadata for casting output tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + output_quantizer: Quantizer, optional + Builder class for quantized output tensor. 
Returns ------- @@ -390,17 +398,6 @@ def _functional_forward( """ - # Check device - if device is None: - device = weight.device if out is None else out.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - if out is not None and not devices_match(out.device, device): - raise ValueError( - f"Output tensor has invalid device (expected {device}, got {out.device})" - ) - # Check datatype if dtype is None: dtype = weight.dtype if out is None else out.dtype @@ -410,36 +407,88 @@ def _functional_forward( if out is not None and out.dtype != dtype: raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})") - # Check input tensor dims - input_dims = tuple(input.size()) - weight_dims = tuple(weight.size()) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - - # Check output tensor dims - output_dims: list[int] - if out is None: - output_dims = list(input_dims) - output_dims[0] = -1 - output_dims[-1] = weight_dims[0] + # Check input tensor + x_local = input + x = None + x_async = None + with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + own_quantized_x_local = False + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(rowwise=True) + if with_x_all_gather: + input_quantizer.set_usage(columnwise=False) + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + own_quantized_x_local = True + x = x_local else: - output_dims = list(out.size()) - if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local + + # Check weight tensor + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(rowwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: + w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) + + # Check output tensor + y = out + if y is None: + if not with_quantized_compute: + output_quantizer = None + if tensor_parallel_mode == "row": + output_quantizer = None + elif isinstance(y, QuantizedTensor): + if not with_quantized_compute: + raise ValueError("Output tensor is quantized, but quantized compute is not enabled") + if tensor_parallel_mode == "row": raise ValueError( - f"Output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" + "Output tensor is quantized, " + "but row tensor parallelism does not support quantized output" ) + if output_quantizer is None: + output_quantizer = getattr(y, "_quantizer", None) + if output_quantizer is None: + raise ValueError("Output tensor is quantized, but 
quantizer was not provided") + else: + output_quantizer = None + if isinstance(output_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 output tensor, " + "but GEMM with MXFP8 output is not supported" + ) + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor if accumulate_into_out: - if out is None: + if y is None: raise ValueError( "Attempted to accumulate into output tensor without providing output tensor" ) @@ -448,181 +497,22 @@ def _functional_forward( "Accumulating into output tensor is not supported with row tensor parallelism" ) - # Check if FP8 is enabled - if with_fp8_compute: - if input_fp8_meta is None and not is_float8_tensor(input): - raise ValueError("No FP8 metadata was provided for casting input to FP8") - if weight_fp8_meta is None and not is_float8_tensor(weight): - raise ValueError("No FP8 metadata was provided for casting weight to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" - if out is None: - with_fp8_output = with_fp8_output and output_fp8_meta is not None - else: - if is_float8_tensor(out): - if not with_fp8_output: - raise ValueError( - "Output tensor is a Float8Tensor, but FP8 output is not supported" - ) - out._reset_caches() - else: - with_fp8_output = False - - # Check input tensor - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - with_transpose_cache = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_transpose_cache = False - x_local = Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.dequantize() - x = x_local + # Synchronize communication for input + _wait_async(x_async) x_async = None - if tensor_parallel_mode == "column" and sequence_parallel: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - async_op=True, - ) - - # Check weight tensor - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) - elif not with_fp8_compute and is_float8_tensor(w): - w = w.dequantize() - - # Check bias tensor - b = None - if bias is not None: - b = convert_tensor( - bias, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - - # Construct output tensor - y = None - if out is not None: - y = reshape(out, (-1, output_dims[-1])) - elif with_fp8_output: - fp8_dtype = get_fp8_te_dtype( - output_fp8_meta["recipe"], - fprop_tensor=True, - ) - data = torch.empty( - (x.size(0), weight_dims[0]), - dtype=torch.uint8, - device=device, - ) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - y = torch.empty( - (x.size(0), weight_dims[0]), - dtype=dtype, - 
device=device, - ) # Perform GEMM - _wait_async(x_async) - x_async = None - if with_fp8_compute: - kwargs = { - "accumulate": accumulate_into_out, - "out": y, - "bias": b, - "use_bias": (b is not None), - } - if with_fp8_output: - if y._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = y._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) - fp8_meta.scale_inv = y._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) - fp8_meta = y._fp8_meta[fp8_meta_key] - fp8_meta_index = y._fp8_meta_index - kwargs.update( - { - "out": y._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": y._fp8_dtype, - } - ) - fp8_gemm( - w._data, - w._scale_inv, - 0, - w._fp8_dtype, - x._data, - x._scale_inv, - 0, - x._fp8_dtype, - y.dtype, - get_workspace(), - **kwargs, - ) - else: - gemm( - w, - x, - y.dtype, - get_workspace(), - accumulate=accumulate_into_out, - out=y, - bias=b, - use_bias=(b is not None), - ) + y, *_ = general_gemm( + w, + x, + get_workspace(), + out_dtype=dtype, + quantization_params=output_quantizer, + accumulate=accumulate_into_out, + out=y, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ) # Reduce tensor-parallel output if needed if tensor_parallel_mode == "row": @@ -631,23 +521,29 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Reshape output tensor if needed - if out is None: - out = reshape(y, output_dims) + # Configure input tensor for backward pass + if own_quantized_x_local: + ### TODO Restore once column-wise usage is supported by itself # pylint: disable=fixme + # x_local.update_usage(rowwise_usage=False) + pass + + # Detach input tensor if needed + # Note: PyTorch autograd produces esoteric errors if we save + # input tensor as context for backward pass. + if x_local is input: + x_local = x_local.detach() - return out, x_local, w + return y, x_local, w @staticmethod def _functional_backward( grad_output: torch.Tensor, input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], - input_dims: Iterable[int], - weight_dims: Iterable[int], *, input_requires_grad: bool = True, weight_requires_grad: bool = True, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, accumulate_into_grad_weight: bool = False, @@ -656,11 +552,11 @@ def _functional_backward( tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - grad_output_fp8_meta: Optional[dict[str, Any]] = None, - grad_input_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + grad_output_quantizer: Optional[Quantizer] = None, + grad_input_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional API for backward pass @@ -674,10 +570,6 @@ def _functional_backward( weight: torch.Tensor, optional Weight tensor. Required to compute loss gradient w.r.t. input. 
- input_dims: iterable of int - Input tensor dimensions - weight_dims: iterable of int - Weight tensor dimensions input_requires_grad: bool Whether to compute loss gradient w.r.t. input tensor weight_requires_grad: bool @@ -703,21 +595,18 @@ def _functional_backward( parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. - weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - grad_output_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. output - tensor to FP8. Required if output grad is not already in - FP8. - grad_input_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. input - tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + grad_output_quantizer: Quantizer, optional + Builder class for quantized loss gradient w.r.t. output + tensor. + grad_input_quantizer: dict, optional + Builder class for quantized loss gradient w.r.t. input + tensor. Returns ------- @@ -728,13 +617,6 @@ def _functional_backward( """ - # Check device - if device is None: - device = weight.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - # Check datatype if dtype is None: dtype = weight.dtype @@ -742,109 +624,42 @@ def _functional_backward( if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - # Check tensor dims - output_dims = tuple(grad_output.size()) - input_dims = tuple(input_dims) - weight_dims = tuple(weight_dims) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if weight_dims[0] != output_dims[-1]: - raise ValueError( - f"Grad output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if grad_input is not None and tuple(grad_input.size()) != input_dims: - raise ValueError( - f"Grad input tensor (shape={tuple(grad_input.size())}) " - f"does not match expected shape ({input_dims})" - ) - - # Check grad input tensor - if not input_requires_grad: - grad_input = None - if grad_input is not None and not devices_match(grad_input.device, device): - raise ValueError( - f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})" - ) - if grad_input is not None and grad_input.dtype != dtype: - raise ValueError( - f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})" + # Check grad output tensor + dy_local = grad_output + dy = None + dy_async = None + with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel + if with_quantized_compute: + if grad_output_quantizer is None: + raise 
ValueError("Missing quantizer for grad output tensor") + grad_output_quantizer.set_usage( + rowwise=input_requires_grad, + columnwise=weight_requires_grad, ) - if accumulate_into_grad_input: - if grad_input is None: - raise ValueError( - "Attempted to accumulate into grad input tensor " - "without providing grad input tensor" - ) - if tensor_parallel_mode == "column": - raise ValueError( - "Accumulating into grad input tensor " - "is not supported with column tensor parallelism" + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + quantizer=grad_output_quantizer, ) - - # Check if FP8 is enabled - if with_fp8_compute: - if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): - raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - with_fp8_grad_input = ( - with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column" - ) - if grad_input is None: - with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None + else: + if not isinstance(dy_local, QuantizedTensor): + dy_local = grad_output_quantizer(dy_local) + dy = dy_local else: - if is_float8_tensor(grad_input): - if not with_fp8_grad_input: - raise ValueError( - "Grad input tensor is a Float8Tensor, but FP8 output is not supported" - ) - grad_input._reset_caches() + if isinstance(dy_local, QuantizedTensor): + dy_local = dy_local.dequantize() + if dy_local.dtype != dtype: + dy_local = dy_local.to(dtype=dtype) + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + ) else: - with_fp8_grad_input = False - - # Check grad output tensor - dy_async = None - dy = reshape( - grad_output, - (-1, output_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(dy): - fp8_dtype = get_fp8_te_dtype( - grad_output_fp8_meta["recipe"], - fprop_tensor=False, - ) - with_transpose_cache = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_transpose_cache = False - dy = Float8Tensor.to_float8( - dy, - fp8_meta=grad_output_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.dequantize() - if tensor_parallel_mode == "row" and sequence_parallel: - dy, dy_async = gather_along_first_dim( - dy, - tensor_parallel_group, - async_op=True, - ) + dy = dy_local # Check input tensor x = None @@ -852,35 +667,36 @@ def _functional_backward( if weight_requires_grad: if input is None: raise ValueError("Input tensor is required to compute weight grad") - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - x_local = Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=(not x_is_sharded), - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() - x = x_local - if x_is_sharded: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - 
async_op=True, - ) + x_local = input + with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(columnwise=True) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + x = x_local + else: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local # Compute grad input dx = None @@ -890,110 +706,80 @@ def _functional_backward( # Check weight tensor if weight is None: raise ValueError("Weight tensor is required to compute input grad") - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=True, - ) - elif not with_fp8_compute and is_float8_tensor(w): + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) - # Construct grad input tensor - if grad_input is not None: - dx = reshape(grad_input, (-1, input_dims[-1])) - elif with_fp8_grad_input: - fp8_dtype = get_fp8_te_dtype( - grad_input_fp8_meta["recipe"], - fprop_tensor=False, - ) - data = torch.empty( - (dy.size(0), weight_dims[1]), - dtype=torch.uint8, - device=device, - ) - dx = Float8Tensor( - data=data, - fp8_meta=grad_input_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - dx = torch.empty( - (dy.size(0), weight_dims[1]), - dtype=dtype, - device=device, - ) - - # Perform dgrad GEMM + # Synchronize tensor-parallel communication _wait_async(dy_async) dy_async = None - if with_fp8_compute: - kwargs = {"accumulate": accumulate_into_grad_input, "out": dx} - if with_fp8_grad_input: - if dx._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = dx._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty( - 1, 1, dtype=torch.float32, device=device - ) - fp8_meta.scale_inv = dx._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) - fp8_meta = dx._fp8_meta[fp8_meta_key] - fp8_meta_index = dx._fp8_meta_index - kwargs.update( - { - "out": dx._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": dx._fp8_dtype, - } + + # Check grad input tensor + dx = grad_input + if dx is None: + if not with_quantized_compute: + grad_input_quantizer = None + if tensor_parallel_mode == "column": + grad_input_quantizer = None + elif 
isinstance(dx, QuantizedTensor): + if not with_quantized_compute: + raise ValueError( + "Grad input tensor is quantized, but quantized compute is not enabled" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Grad input tensor is quantized, " + "but column tensor parallelism does not support quantized grad input" + ) + if grad_input_quantizer is None: + grad_input_quantizer = getattr(dx, "_quantizer", None) + if grad_input_quantizer is None: + raise ValueError( + "Grad input tensor is quantized, but quantizer was not provided" ) - fp8_gemm( - w.transpose_2d(), - w._scale_inv, - 0, - w._fp8_dtype, - dy._data, - dy._scale_inv, - 0, - dy._fp8_dtype, - dx.dtype, - get_workspace(), - **kwargs, - ) else: - gemm( - w, - dy, - dx.dtype, - get_workspace(), - accumulate=accumulate_into_grad_input, - layout="NN", - out=dx, + grad_input_quantizer = None + if isinstance(grad_input_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 grad input tensor, " + "but GEMM with MXFP8 output is not supported" ) + # Check if accumulating into grad input tensor + if accumulate_into_grad_input: + if dx is None: + raise ValueError( + "Attempted to accumulate into grad input tensor " + "without providing grad input tensor" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Accumulating into grad input tensor " + "is not supported with column tensor parallelism" + ) + + # Perform dgrad GEMM + dx, *_ = general_gemm( + w, + dy, + get_workspace(), + out_dtype=dtype, + quantization_params=grad_input_quantizer, + accumulate=accumulate_into_grad_input, + layout="NN", + out=dx, + use_split_accumulator=_2X_ACC_DGRAD, + grad=True, + ) + # Reduce tensor-parallel grad input if needed if tensor_parallel_mode == "column": if sequence_parallel: @@ -1009,59 +795,46 @@ def _functional_backward( async_op=True, ) - # Perform wgrad GEMM - if not weight_requires_grad: - grad_weight = None - else: - if grad_weight is None: + # Compute grad weight + dw = None + if weight_requires_grad: + + # Synchronize tensor-parallel communication + _wait_async(x_async) + _wait_async(dy_async) + x_async = None + dy_async = None + + # Check grad input tensor + dw = grad_weight + dw_dtype = dtype + if dw is None: if accumulate_into_grad_weight: raise ValueError( - "Attempted to accumulate into grad weight buffer" - "without providing grad weight" + "Attempted to accumulate into grad weight tensor " + "without providing grad weight tensor" ) - grad_weight = torch.empty( - weight_dims, - dtype=dtype, - device=device, - memory_format=torch.contiguous_format, - ) - _wait_async(dy_async) - _wait_async(x_async) - dy_async = None - x_async = None - if with_fp8_compute: - fp8_gemm( - x.transpose_2d(), - x._scale_inv, - 0, - x._fp8_dtype, - dy.transpose_2d(), - dy._scale_inv, - 0, - dy._fp8_dtype, - grad_weight.dtype, - get_workspace(), - accumulate=accumulate_into_grad_weight, - out=grad_weight, - ) else: - gemm( - x, - dy, - x.dtype, - get_workspace(), - accumulate=accumulate_into_grad_weight, - layout="NT", - out=grad_weight, - ) + dw_dtype = dw.dtype + + # Perform wgrad GEMM + dw, *_ = general_gemm( + x, + dy, + get_workspace(), + out_dtype=dw_dtype, + accumulate=accumulate_into_grad_weight, + layout="NT", + out=dw, + use_split_accumulator=_2X_ACC_WGRAD, + grad=True, + ) # Clean up and return grads _wait_async(dy_async) _wait_async(x_async) _wait_async(dx_async) - if dx is not None and grad_input is None: - grad_input = reshape(dx, input_dims) - return grad_input, grad_weight + return dx, dw def 
op_forward( self, @@ -1071,21 +844,33 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: + # Check which grads are required + input_requires_grad = ctx.requires_grad and input_.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight.requires_grad + # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = self.get_fp8_meta("input") - weight_fp8_meta = self.get_fp8_meta("param") - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = self.get_fp8_meta("grad_output") - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + + # Get quantizers + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + if next_op is not None and next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = self.get_quantizer("backward", 0) + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) + + # Configure quantizers + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + input_quantizer.set_usage(columnwise=weight_requires_grad) + weight_quantizer.set_usage(columnwise=False) # Get autocast dtype if needed dtype = None @@ -1096,27 +881,26 @@ def op_forward( output, x_local, _ = BasicLinear._functional_forward( input=input_, weight=self.weight, - device=self.device, dtype=dtype, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass ctx.save_for_backward(x_local) - ctx.with_fp8_compute = with_fp8_compute - ctx.weight_fp8_meta = weight_fp8_meta - ctx.grad_output_fp8_meta = grad_output_fp8_meta - ctx.grad_input_fp8_meta = grad_input_fp8_meta + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.dtype = dtype - ctx.input_dims = input_.size() - ctx.input_requires_grad = input_.requires_grad - ctx.weight_requires_grad = self.weight.requires_grad + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad ctx.has_prev_op = prev_op is not None return output @@ -1149,21 +933,19 @@ def op_backward( grad_output=grad_output, input=x_local, weight=self.weight, - input_dims=ctx.input_dims, - weight_dims=self.weight.size(), input_requires_grad=ctx.input_requires_grad, weight_requires_grad=ctx.weight_requires_grad, - device=self.device, 
dtype=ctx.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=ctx.with_fp8_compute, - weight_fp8_meta=ctx.weight_fp8_meta, - grad_output_fp8_meta=ctx.grad_output_fp8_meta, - grad_input_fp8_meta=ctx.grad_input_fp8_meta, + with_quantized_compute=ctx.with_quantized_compute, + input_quantizer=ctx.input_quantizer, + weight_quantizer=ctx.weight_quantizer, + grad_output_quantizer=ctx.grad_output_quantizer, + grad_input_quantizer=ctx.grad_input_quantizer, ) # Clear input tensor if possible diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 65717d5fa5..c5897486e3 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -13,13 +13,9 @@ import torch from transformer_engine_torch import layernorm_bwd, layernorm_fwd -from ...cpp_extensions import ( - layernorm_fwd_fp8, - layernorm_fwd_fp8_inf, - layernorm_fwd_inf, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...constants import TE_DType +from ...tensor import QuantizedTensor from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -213,60 +209,28 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute layer norm - y = None - means = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - b, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, means, rstdevs = layernorm_fwd_fp8(*args) - else: - data = layernorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - b, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, means, rstdevs = layernorm_fwd(*args) - else: - y = layernorm_fwd_inf(*args) + y, means, rstdevs = layernorm_fwd( + x, + w, + b, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index e3755decd6..448954fc69 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -9,8 +9,8 @@ import torch -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from 
...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext @@ -38,10 +38,10 @@ def __init__( self._quantize_forward = forward self._quantize_backward = backward - def num_fp8_scales(self, mode: str) -> int: - if mode == "input" and self._quantize_forward: + def num_quantizers(self, mode: str) -> int: + if mode == "forward" and self._quantize_forward: return 1 - if mode == "grad_output" and self._quantize_backward: + if mode == "backward" and self._quantize_backward: return 1 return 0 @@ -61,15 +61,7 @@ def op_forward( # Quantize if needed out = input_ if quantize_forward and not isinstance(out, QuantizedTensor): - fp8_meta = self.get_fp8_meta("input") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - out = Float8Tensor.to_float8( - out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + out = self.get_quantizer("forward", 0)(out) ctx.quantize_backward = quantize_backward return out @@ -81,13 +73,5 @@ def op_backward( ) -> tuple[torch.Tensor, tuple[()]]: grad_input = grad_output if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): - fp8_meta = self.get_fp8_meta("grad_output") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input = Float8Tensor.to_float8( - grad_input, - fp8_meta=fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + grad_input = self.get_quantizer("backward", 0)(grad_input) return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 03a02786b4..adfd46641b 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -9,9 +9,9 @@ import torch -from ...tensor import Float8Tensor, QuantizedTensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext -from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -45,7 +45,7 @@ def op_forward( # Trivial case if self.process_group_size == 1: - return input_ + return input_.detach() # Tensor dimensions input_dims = input_.size() @@ -74,47 +74,9 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: - - # Trivial case + grad_input: torch.Tensor if self.process_group_size == 1: - return grad_output, () - - # Tensor dimensions - output_dims = grad_output.size() - if not output_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(output_dims)} " - f"over {self.process_group_size} processes" - ) - input_dims = list(output_dims) - input_dims[0] *= self.process_group_size - - # Perform all-gather - dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) - dx = None - if isinstance(dy, Float8Tensor): - dx = Float8Tensor.make_like( - dy, - data=torch.empty( - input_dims, - dtype=torch.uint8, - device=dy.device, - ), - ) - torch.distributed.all_gather_into_tensor( - dx._data, - dy._data, - group=self.process_group, - ) + grad_input = grad_output.detach() else: - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) - torch.distributed.all_gather_into_tensor( - dx, - dy, - group=self.process_group, - ) - - return dx, () + grad_input, _ = gather_along_first_dim(grad_output, self.process_group) + return grad_input, () diff --git 
a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index 53524cdd83..1e9095169c 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -14,7 +14,6 @@ BasicOperation, OperationContext, ) -from .._common import reshape class Reshape(BasicOperation): @@ -42,11 +41,11 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: ctx.input_shape = input_.size() - return reshape(input_, self._shape) + return input_.reshape(*self._shape) def op_backward( self, ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: - return reshape(grad_output, ctx.input_shape), () + return grad_output.reshape(*ctx.input_shape), () diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 32ef242b90..c1f32af93a 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -13,13 +13,9 @@ import torch from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd -from ...cpp_extensions import ( - rmsnorm_fwd_fp8, - rmsnorm_fwd_fp8_inf, - rmsnorm_fwd_inf, -) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor +from ...constants import TE_DType from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -193,57 +189,27 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute RMSNorm - y = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, rstdevs = rmsnorm_fwd_fp8(*args) - else: - data = rmsnorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, rstdevs = rmsnorm_fwd(*args) - else: - y = rmsnorm_fwd_inf(*args) + y, _, rstdevs = rmsnorm_fwd( + x, + w, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 1ddd8d116c..e295929e98 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -73,11 +73,8 @@ def fuser_backward( grad_output=grad_output, input=x_local, weight=linear_op.weight, - 
input_dims=linear_op_ctx.input_dims, - weight_dims=linear_op.weight.size(), input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, - device=linear_op.device, dtype=grad_input.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, @@ -86,10 +83,11 @@ def fuser_backward( tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=linear_op_ctx.with_fp8_compute, - weight_fp8_meta=linear_op_ctx.weight_fp8_meta, - grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, - grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + with_quantized_compute=linear_op_ctx.with_quantized_compute, + input_quantizer=linear_op_ctx.input_quantizer, + weight_quantizer=linear_op_ctx.weight_quantizer, + grad_output_quantizer=linear_op_ctx.grad_output_quantizer, + grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) if accumulate_into_main_grad: grad_weight = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index c746f21f2c..6088b3c0db 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -83,22 +83,22 @@ def fuser_forward( raise NotImplementedError("Activations are not yet supported") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + if next_op is not None and next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -110,25 +110,24 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, dtype=dtype, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass 
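# --- Illustrative sketch (editor's note, not part of the patch) --------------
# The hunk above swaps the per-tensor FP8 metadata lookups ("input", "param",
# "grad_output") for quantizers indexed by pass. A condensed restatement of the
# lookup convention used throughout this patch; `op` and `next_op` are
# hypothetical stand-ins for the linear op and the downstream op:
#
#     input_quantizer       = op.get_quantizer("forward", 0)    # forward tensor 0: input
#     weight_quantizer      = op.get_quantizer("forward", 1)    # forward tensor 1: weight
#     grad_output_quantizer = op.get_quantizer("backward", 0)   # backward tensor 0: grad output
#     output_quantizer = None
#     if next_op is not None and next_op.num_quantizers("forward") > 0:
#         output_quantizer = next_op.get_quantizer("forward", 0)
# -----------------------------------------------------------------------------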
linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index fa7f07cb95..69b0c3ba5a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -77,19 +77,19 @@ def fuser_forward( raise ValueError("Bias operation forward does not expect keyword arguments") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -102,26 +102,25 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, out=output, accumulate_into_out=True, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + 
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index dab4c8f681..bbb27f86e6 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -4,6 +4,8 @@ """Linear layer backward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -12,11 +14,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import ( - fp8_cast_transpose_bgrad_fused, - fp8_gemm, - gemm, -) +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +47,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} ops = [] @@ -706,6 +707,8 @@ def fuse_userbuffers_backward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 1f3635eb4b..a08c0a6ef9 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -4,6 +4,8 @@ """Linear layer forward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -11,7 +13,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import fp8_gemm, gemm +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +51,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None} ops = [linear] @@ -524,6 +529,8 @@ def fuse_userbuffers_forward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 30367d2c5e..8346d31a40 100644 --- a/transformer_engine/pytorch/ops/op.py +++ 
b/transformer_engine/pytorch/ops/op.py @@ -13,13 +13,14 @@ import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import ( - DelayedScaling, +from transformer_engine.common.recipe import Recipe +from ..fp8 import ( + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, FP8GlobalStateManager, - get_default_fp8_recipe, + RecipeState, ) -from ._common import canonicalize_device +from ..tensor import Quantizer @dataclasses.dataclass @@ -174,132 +175,148 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): def __init__(self) -> None: super().__init__() - # FP8 metadata objects + # Objects for quantization + self._quantizers: Optional[dict[str, list[Quantizer]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None @property def is_fused_op(self) -> bool: return False - def num_fp8_scales( + def num_quantizers( self, mode: str, # pylint: disable=unused-argument ) -> int: - """Number of FP8 scaling factors + """Number of quantizers + + Matches number of quantized tensors used in operation. Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ return 0 - def _make_fp8_metas(self) -> dict[str, Optional[dict[str, Any]]]: - """Construct FP8 metadata""" - - # Shared objects for FP8 metadata - dtype = torch.float32 - device = canonicalize_device(None) - recipe = get_default_fp8_recipe() - - def _make_meta( - num_scales: int, - is_forward: bool, - ) -> Optional[dict[str, Any]]: - """Construct FP8 metadata for one tensor type""" - if num_scales == 0: - return None - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(num_scales, dtype=dtype, device=device) - meta.scale_inv = torch.ones(num_scales, dtype=dtype, device=device) - meta.amax_history = torch.zeros( - (recipe.amax_history_len, num_scales), - dtype=dtype, - device=device, + def _reset_quantization_recipe_state( + self, + *, + recipe: Optional[Recipe] = None, + ) -> None: + """Construct state for quantization recipe""" + + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + + # Quantization recipe state for forward and backward pass + self._fp8_metas = {"forward": None, "backward": None} + self._quantizers = {"forward": [], "backward": []} + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue + + # Construct quantization recipe state + recipe_state = RecipeState.create( + recipe, + mode=mode, + num_quantizers=num_quantizers, ) - return { - key: meta, + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + self._fp8_metas[mode] = { + fp8_meta_key: recipe_state, "recipe": recipe, - "fp8_group": None, + "fp8_group": FP8GlobalStateManager.get_fp8_group(), } - # Construct FP8 metadata for all tensor types - return { - "input": _make_meta(self.num_fp8_scales("input"), True), - "param": _make_meta(self.num_fp8_scales("param"), True), - "grad_output": _make_meta(self.num_fp8_scales("grad_output"), False), - } - - @classmethod - def _maybe_update_fp8_meta( - cls, - fp8_meta: Optional[dict[str, Any]], + # Construct builder class for quantized tensors + self._quantizers[mode] = recipe_state.make_quantizers() + + def _update_quantization_recipe_state( + self, *, - fp8_recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, ) -> None: - if fp8_meta is None: - return 
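# --- Illustrative sketch (editor's note, not part of the patch) --------------
# _reset_quantization_recipe_state above replaces the hand-built FP8TensorMeta
# objects with one RecipeState per pass, from which the quantizers are derived.
# Condensed restatement of that flow, assuming the RecipeState/Quantizer API
# introduced by this patch; `op` is a hypothetical BasicOperation:
#
#     recipe = FP8GlobalStateManager.get_fp8_recipe()
#     recipe_state = RecipeState.create(
#         recipe,
#         mode="forward",
#         num_quantizers=op.num_quantizers("forward"),
#     )
#     forward_quantizers = recipe_state.make_quantizers()
# -----------------------------------------------------------------------------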
+ """Make sure quantizer state matches quantization recipe""" - # Update FP8 recipe - if fp8_recipe is None: - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = fp8_recipe + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() - # Update FP8 communication group - fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - - # Adjust amax history length if needed - amax_history_len = fp8_recipe.amax_history_len - for is_forward in (True, False): - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if fp8_meta_key not in fp8_meta: + # Reset quantization state if needed + if self._fp8_metas is None or self._quantizers is None: + self._reset_quantization_recipe_state(recipe=recipe) + return + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: continue - meta = fp8_meta[fp8_meta_key] - curr_len = meta.amax_history.size(0) - - # Nothing to be done if amax history is already correct - if curr_len == amax_history_len: + recipe_state = self._fp8_metas[mode][fp8_meta_key] + need_to_reset_recipe_state = ( + recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + if need_to_reset_recipe_state: + self._reset_quantization_recipe_state(recipe=recipe) + return + + # Quantization recipe state for forward and backward pass + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: continue - # Reallocate amax history - with torch.no_grad(): - if curr_len > amax_history_len: - meta.amax_history = meta.amax_history[:amax_history_len].clone() - else: - meta.amax_history = torch.nn.functional.pad( - meta.amax_history, - pad=(0, 0, 0, amax_history_len - curr_len), - ) + # Update FP8 metadata + fp8_meta = self._fp8_metas[mode] + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - # Update global buffers for amax reductions - buffer_info_key = FP8GlobalStateManager.get_buffer_info() - if buffer_info_key in fp8_meta: - fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] - for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." 
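# --- Illustrative sketch (editor's note, not part of the patch) --------------
# Standalone PyTorch restatement of the amax-history reallocation that
# _update_quantization_recipe_state performs further down in this hunk when the
# recipe's amax_history_len changes (the MXFP8 path skips this step):
import torch
import torch.nn.functional as F

def resize_amax_history(amax_history: torch.Tensor, target_len: int) -> torch.Tensor:
    """Truncate or zero-pad a (history_len, num_quantizers) amax buffer."""
    current_len = amax_history.size(0)
    if target_len < current_len:
        return amax_history[:target_len].clone()
    return F.pad(amax_history, pad=(0, 0, 0, target_len - current_len))

hist = torch.zeros(16, 4)                        # 16 history steps, 4 quantizers
assert resize_amax_history(hist, 8).shape == (8, 4)
assert resize_amax_history(hist, 32).shape == (32, 4)
# -----------------------------------------------------------------------------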
- FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history - - def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: - """FP8 metadata + # Get recipe state + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = fp8_meta[fp8_meta_key] + + # Reallocate amax history if needed + if recipe.mxfp8(): + continue + + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if current_length != target_length: + with torch.no_grad(): + if target_length < current_length: + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + else: + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + self._quantizers[mode] = recipe_state.make_quantizers() + + def get_quantizer( + self, + mode: str, + index: int, + ) -> Quantizer: + """Get builder class for quantized tensor Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - return self._fp8_metas[mode] + if self._quantizers is None: + self._reset_quantization_recipe_state() + return self._quantizers[mode][index] @torch.no_grad() def _save_fp8_metas(self) -> Optional[dict[str, Any]]: @@ -321,7 +338,6 @@ def _save_fp8_metas(self) -> Optional[dict[str, Any]]: continue out[mode][fp8_meta_key] = ( fp8_meta[fp8_meta_key].scale.clone(), - fp8_meta[fp8_meta_key].scale_inv.clone(), fp8_meta[fp8_meta_key].amax_history.clone(), ) return out @@ -346,16 +362,15 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: assert ( fp8_meta_key in self._fp8_metas[mode] ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" - scale, scale_inv, amax_history = tensors + scale, amax_history = tensors self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) - self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) def pre_forward( self, *, fp8_enabled: Optional[bool] = None, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, ) -> None: """Preprocessing before forward pass""" @@ -363,28 +378,15 @@ def pre_forward( if fp8_enabled is None: fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: - - # Construct FP8 metadata if needed - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - - # Make sure FP8 metadata matches FP8 autocast context - for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) - - # Register FP8 metadata for amax and scale update + self._update_quantization_recipe_state(recipe=fp8_recipe) if not FP8GlobalStateManager.fp8_graph_capturing(): - if self.num_fp8_scales("input"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("input"), - ) - if self.num_fp8_scales("param"): + if self.num_quantizers("forward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("param"), + self._fp8_metas["forward"], ) - if self.num_fp8_scales("grad_output"): + if self.num_quantizers("backward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - 
self.get_fp8_meta("grad_output"), + self._fp8_metas["backward"], ) @abc.abstractmethod @@ -527,13 +529,6 @@ def get_extra_state(self) -> torch.Tensor: # See: https://github.com/NVIDIA/TransformerEngine/pull/351 # See: https://github.com/NVIDIA/TransformerEngine/pull/363 - # Return immediately if op has no FP8 state - has_fp8_state = any( - self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") - ) - if not has_fp8_state: - return torch.Tensor() - def to_cpu(src: torch.Tensor) -> torch.Tensor: """Helper function to make CPU copy of tensor @@ -547,25 +542,20 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Store FP8 state state = {} - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor - if self.num_fp8_scales(mode) == 0: - state[mode] = None + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue state[mode] = {} # Store tensors if "scaling_fwd" in fp8_meta: state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv) state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) if "scaling_bwd" in fp8_meta: state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv) state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items @@ -591,7 +581,7 @@ def set_extra_state(self, state: Optional[torch.Tensor]) -> None: # Deserialize state from byte tensor state = pickle.loads(state.detach().numpy(force=True).tobytes()) - if state is None: + if state is None or len(state) == 0: return def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: @@ -606,12 +596,12 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load FP8 state - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor if mode not in state: continue - if self.num_fp8_scales(mode) == 0: + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) if fp8_meta is None: @@ -631,12 +621,10 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: if "scaling_fwd" in fp8_meta: fp8_meta_fwd = fp8_meta["scaling_fwd"] copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv) copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) if "scaling_bwd" in fp8_meta: fp8_meta_bwd = fp8_meta["scaling_bwd"] copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv) copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) # Finish CPU-GPU memory transfers diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b86c973304..d972fd96ab 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -8,24 +8,20 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from .multi_tensor_apply import multi_tensor_applier -from 
..float8_tensor import Float8Tensor def get_fp8_meta(fp8_tensor): """FP8 metadata getter.""" - if fp8_tensor._fp8_meta is None: - raise RuntimeError("FP8 meta data is not initialized.") + assert isinstance(fp8_tensor, Float8Tensor), "Fused optimizer supports only Float8Tensor class" + if fp8_tensor._quantizer is None: + raise RuntimeError("FP8 quantizer data is not initialized.") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_tensor._fp8_meta_forward, - ) + quantizer = fp8_tensor._quantizer - fp8_meta_index = fp8_tensor._fp8_meta_index - scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + scale = quantizer.scale + amax = quantizer.amax scale_inv = fp8_tensor._scale_inv return scale, amax, scale_inv @@ -237,6 +233,10 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) + assert len(scaled_state._quantizer.scale) == 1, ( + "Only scaling with one scaling factor per tensor is supported by the" + " FusedAdam." + ) else: assert scaled_state.dtype == dtype @@ -251,7 +251,7 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device) torch.div(absmax, max_range, out=scale) if isinstance(scaled_state, Float8Tensor): - scaled_state._scale_inv.copy_(scale) + scaled_state._quantizer.scale.copy_(1 / scale) scaled_state.copy_(unscaled_state) else: rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) @@ -269,7 +269,6 @@ def get_unscaled_state(self, param, state_name): state = self.state[param] dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: - assert isinstance(state[state_name], Float8Tensor) unscaled = state[state_name].float() elif dtype == torch.float16: assert state[state_name].dtype == torch.float16 @@ -343,12 +342,15 @@ def _initialize_state( data.zero_() if dtype == torch.uint8: - self.state[param][state_name] = Float8Tensor( - data=data, - dtype=torch.float32, - fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device), + quantizer = Float8Quantizer( + scale=torch.ones([1], dtype=torch.float32, device=param.device), + amax=torch.zeros([1], dtype=torch.float32, device=param.device), + fp8_dtype=tex.DType.kFloat8E4M3, ) + self.state[param][state_name] = quantizer.make_empty(param.shape) + self.state[param][state_name].quantize_(data.float()) else: + self.state[param][state_name] = data # Create scale if necessary. @@ -421,6 +423,8 @@ def load_state_dict(self, state_dict): param = id_map[k] self.state[param] = {} for name in v: + if v[name] is None: + continue if ( self.store_param_remainders and name == "master_param" diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 264b620be8..2e6167a6e0 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -48,8 +48,12 @@ def forward( # Data type check fp8 = isinstance(inp, Float8Tensor) if fp8: + assert ( + inp._quantizer.scale.ndim == 0 + ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute." 
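# --- Illustrative sketch (editor's note, not part of the patch) --------------
# With the TE 2.0 tensor refactor, the raw uint8 buffer returned by the permute
# kernels is re-wrapped into a Float8Tensor while also passing the logical shape
# and the original high-precision ("fake") dtype explicitly, as the hunks below
# do. Condensed restatement with hypothetical placeholder names:
#
#     fake_dtype = inp.dtype          # high-precision dtype carried by the input
#     permuted = Float8Tensor(
#         data=permuted_data,         # raw uint8 payload from the kernel
#         fp8_dtype=fp8_dtype,
#         fp8_scale_inv=scale_inv,
#         shape=permuted_data.shape,
#         dtype=fake_dtype,
#     )
# -----------------------------------------------------------------------------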
dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] @@ -78,7 +82,11 @@ def forward( if fp8: permuted_act = Float8Tensor( - data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=permuted_act, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=permuted_act.shape, + dtype=fake_dtype, ) ctx.row_id_map = row_id_map @@ -107,6 +115,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data else: dtype = TE_DType[permuted_act_grad.dtype] @@ -118,7 +127,11 @@ def backward( ) if ctx.fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv * ctx.topK, + shape=act_grad.shape, + dtype=fake_dtype, ) return act_grad, None, None, None @@ -167,6 +180,7 @@ def forward( if fp8: dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] @@ -181,7 +195,11 @@ def forward( if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, ) ctx.save_for_backward(inp, row_id_map, probs) @@ -207,6 +225,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype unpermuted_act_grad = unpermuted_act_grad._data else: dtype = TE_DType[unpermuted_act_grad.dtype] @@ -220,7 +239,13 @@ def backward( unpermuted_act_grad, inp, dtype, row_id_map, probs ) if ctx.fp8: - act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv) + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) if not ctx.needs_input_grad[2]: prob_grad = None diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index d3b3f03e10..20503fea2f 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -12,10 +12,9 @@ from pathlib import Path import setuptools -from torch.utils.cpp_extension import BuildExtension try: - import torch # pylint: disable=unused-import + from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e @@ -57,7 +56,7 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["torch"], - tests_require=["numpy", "onnxruntime", "torchvision"], + tests_require=["numpy", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3950c071b6..25362e1d58 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -7,11 +7,7 @@ from typing import Callable, Tuple, Union, Optional import torch from torch import nn -import torch._C._onnx as _C_onnx -from torch.onnx import _type_utils import transformer_engine_torch as tex -from transformer_engine.pytorch.export import 
is_in_onnx_export_mode -from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32 THREADS_PER_WARP = 32 @@ -32,35 +28,6 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: return _default_causal_mask[matrix_identifiers] -def _get_onnx_export_causal_mask( - seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor -) -> torch.Tensor: - """Return the causal upper triangular mask for softmax input, for ONNX export. - - ONNX does not support dynamic control-flow and requires non-square masks when - using a KV-cache (seq_k's length len(context)+len(generative) while seq_q's length is 1). - - Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct - shape for GPT context and generation phases. - In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in - the generation phase the mask is rectangular with shape (1, seq_k). - """ - assert len(onnx_causal_mask.size()) == 2 - assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1) - assert onnx_causal_mask.size(0) >= (seq_k - seq_q) >= 0 - derived_mask = onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k] - return derived_mask - - -def fp32_compute(onnx_symbolic_fn): - """A decorator that wraps an ONNX symoblic function with FP32 compute operators.""" - - def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs): - return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs) - - return wrapper - - class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -88,34 +55,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledUpperTriangMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - mask = g.op("Trilu", ones, k, upper_i=1) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): """ @@ -143,40 +82,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledAlignedCausalMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - - # rectangular causal mask aligned to the bottom right corner of Attention matrix - rows = 
inputs.size(dim=-2) - cols = inputs.size(dim=-1) - diag_shift = cols - rows + 1 - - mask = g.op("Trilu", ones, k, upper_i=diag_shift) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_aligned_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledMaskedSoftmax(torch.autograd.Function): """ @@ -203,30 +108,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic( - g: torch.Graph, inputs: torch._C.Value, mask: torch._C.Value, scale: float - ) -> torch._C.Value: - """ScaledMaskedSoftmax symbolic method""" - # Captures the logic of function scaled_masked_softmax_warp_forward. - # output = softmax(mask(input*scale) - # Computed as: - # masked_scaled = (1 - mask)*(input*scale) - # softmax_mask = mask * -10000 - # output = softmax(masked_scaled + softmax_mask) - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - # Note: type is hard coded because softmax uses FP16 or BF16 - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledSoftmax(torch.autograd.Function): """ @@ -252,15 +133,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledSoftmax symbolic method""" - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - out = g.op("Softmax", scaled) - return out - class FusedScaleMaskSoftmax(nn.Module): """ @@ -281,18 +153,6 @@ def __init__( self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 - # Users exporting to ONNX can optimize the attention mask for GPT text generation. 
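# --- Illustrative sketch (editor's note, not part of the patch) --------------
# Standalone PyTorch restatement of the bottom-right-aligned causal masking that
# the removed ONNX symbolic above expressed: a triu mask with diagonal
# sk - sq + 1, masked positions pushed to -10000 before the softmax.
import torch

def aligned_causal_masked_softmax(scores: torch.Tensor, scale: float) -> torch.Tensor:
    """scores: (..., sq, sk); mask aligned to the bottom-right of the matrix."""
    sq, sk = scores.shape[-2], scores.shape[-1]
    mask = torch.ones(sq, sk, dtype=torch.bool, device=scores.device)
    mask = torch.triu(mask, diagonal=sk - sq + 1)          # True = masked out
    return torch.softmax((scores * scale).masked_fill(mask, -10000.0), dim=-1)

probs = aligned_causal_masked_softmax(torch.randn(2, 1, 3, 5), scale=0.125)
assert probs.shape == (2, 1, 3, 5)
# -----------------------------------------------------------------------------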
- self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1")) - if self.kvcache_max_seq > 0: - self.register_buffer( - "onnx_causal_mask", - torch.triu( - torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"), - diagonal=1, - ).bool(), - persistent=False, - ) - def forward( self, inp: torch.Tensor, @@ -310,7 +170,7 @@ def forward( assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode(): + if self.is_kernel_available(mask, *inp.size()): return self.forward_fused_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale) @@ -363,8 +223,9 @@ def forward_fused_softmax( """ scale = 1.0 if scale is None else scale - if self.attn_mask_type in ["causal", "causal_bottom_right"]: - return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) + # Disable for now until unalignment bug is fixed. + # if self.attn_mask_type in ["causal", "causal_bottom_right"]: + # return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": @@ -383,13 +244,7 @@ def forward_torch_softmax( if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) - if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: - assert self.kvcache_max_seq >= seq_len_k - causal_mask = _get_onnx_export_causal_mask( - seq_len_q, seq_len_k, self.onnx_causal_mask - ) - else: - causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) if mask is None: mask = causal_mask else: diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py deleted file mode 100755 index 54eb37ecab..0000000000 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -ONNX symbolic functions for Transformer Engine - -Warnings of the type pasted below are a known Pytorch issue -(https://github.com/pytorch/pytorch/issues/81693): - -tests/test_onnx_export.py::test_export_cast_ops[112] - /opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649: - UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing, - so it may result in wrong shape inference for the exported graph. - Please consider adding it in symbolic function. (Triggered internally at - /opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.) - _C._jit_pass_onnx_graph_shape_type_inference( - - -Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access -specific entries using the index passes as `fp8_tensor`. If you fail to do this you will get -the following error when accessing a sepcific scale element (e.g. `scale_inv[fp8_tensor]`): - TypeError: 'torch._C.Value' object is not subscriptable -""" - -import torch -from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils -import torch._C._onnx as _C_onnx - -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx._internal import jit_utils - -import transformer_engine_torch as tex - - -# This file registers custom op symbolic ONNX functions and does not export any symbols. 
-__all__ = [] - - -# Custom ops spec version -VER = 1 -UNSPECIFIED_TYPE = -1 - - -def make_op_name(op_name: str) -> str: - """custom op name""" - return "trt::" + op_name - - -def get_TensorProtoDataType(t): - """Return the _C_onnx.TensorProtoDataType of the input tensor""" - try: - return { - "Float": _C_onnx.TensorProtoDataType.FLOAT, - "Half": _C_onnx.TensorProtoDataType.FLOAT16, - "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, - }[t.type().scalarType()] - except KeyError as e: - raise TypeError(f"Onnx export for dtype {t.type().scalarType()} not supported.") from e - - -def is_dtype_fp32(t): - """Check fp32 dtype""" - return t.type().scalarType() == "Float" - - -def is_dtype_fp16(t): - """Check fp16 dtype""" - return t.type().scalarType() == "Half" - - -def is_dtype_bf16(t): - """Check bf16 dtype""" - return t.type().scalarType() == "BFloat16" - - -def quantize(g, inputs, scale, fp8_tensor): - """Helper Function for Quantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - # Q inputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the input if needed. - if not is_dtype_fp32(inputs): - inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) - q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) - ) - return q_op - - -def dequantize(g, inputs, scale_inv, fp8_tensor, otype): - """Helper Function for Dequantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) - out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.float32).with_sizes(output_shape) - ) - - # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the output if needed. - if otype == int(tex.DType.kFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif otype == int(tex.DType.kBFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return out - - -def compute_in_fp32(g, inp, subgraph, *args, **kwargs): - """Wrap subgraph with casts to/from FP32 so that its precision is FP32. - - If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`; - then cast subgraphs's output back to `inp` data type. 
- """ - inp_dtype = get_TensorProtoDataType(inp) - is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT - if not is_fp32: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - sg_out = subgraph(g, inp, *args, **kwargs) - if not is_fp32: - sg_out = g.op("Cast", sg_out, to_i=inp_dtype) - return sg_out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") -def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8_noalloc""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "i", "i", "i") -def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): - """ONNX graph for cast_from_fp8""" - # pylint: disable=unused-argument - return dequantize(g, inputs, scale_inv, fp8_tensor, otype) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_gelu""" - # pylint: disable=unused-argument - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh") - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_relu""" - # pylint: disable=unused-argument - out = torch.onnx.symbolic_opset9.relu(g, inp) - if scale: - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for swiglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Sigmoid", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_swiglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_swiglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_reglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for reglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size 
is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Relu", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_reglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_reglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_geglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for geglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") - out = g.op("Mul", first, second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_geglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_geglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args( - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "v", - "v", - "i", - "v", - "i", - "v", - "i", - "i", - "i", -) -def onnx_te_gemm( - g, - weight, - weight_scale_inverse, - weight_fp8_tensor, - weight_type, - trans_weight, - inputs, - input_scale_inverse, - input_fp8_tensor, - input_type, - trans_input, - out, - out_scale, - out_type, - out_amax, - bias, - bias_type, - pre_gelu_out, - grad, - workspace, - workspaceSize, - accumulate, - use_split_accumulator, -): - """ONNX graph for te_gemm""" - # pylint: disable=unused-argument - is_fp16 = is_dtype_fp16(inputs) - is_bf16 = is_dtype_bf16(inputs) - if input_type == int(tex.DType.kFloat8E4M3): - inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) - - if weight_type == int(tex.DType.kFloat8E4M3): - weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, out_type) - - empty_tensor_size = [0] - bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size - pre_gelu_out_empty = ( - torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) == empty_tensor_size - ) - - if not bias_empty: - output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight) - else: - output = g.op("Gemm", inputs, weight, 
transA_i=trans_input, transB_i=trans_weight) - if not bias_empty: - if not pre_gelu_out_empty: - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. - output = compute_in_fp32(g, output, torch.onnx.symbolic_opset9.gelu, "tanh") - else: - if is_fp16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif is_bf16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return output - - -def _ones_like(g, inp, dtype): - """Returns a tensor filled with the scalar value 1, with the same size as input and - with dtype data-type""" - shape = g.op("Shape", inp) - # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR - # create a ConstantOfShape with type FP32 and then add a Cast to BF16. - is_bf16 = dtype == torch.bfloat16 - one = g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([1], dtype=torch.float32 if is_bf16 else dtype), - ) - if is_bf16: - one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return one - - -@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_layernorm_fwd_fp8( - g, - inputs, - weight, - bias, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for layernorm_fwd_fp8""" - # pylint: disable=unused-argument - inp_dtype = get_TensorProtoDataType(inputs) - - if inp_dtype != get_TensorProtoDataType(weight): - weight = g.op("Cast", weight, to_i=inp_dtype) - if inp_dtype != get_TensorProtoDataType(bias): - bias = g.op("Cast", bias, to_i=inp_dtype) - - ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale, fp8_tensor) - return fp8_ln - - -@symbolic_helper.parse_args("v", "v", "v", "f", "i", "b") -def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma): - """ONNX graph for layernorm_fwd""" - # pylint: disable=unused-argument - - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) - assert ndim is not None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - - if zero_centered_gamma: - inputs_dtype = inputs.type().dtype() - one = _ones_like(g, weight, inputs_dtype) - weight = g.op("Add", weight, one) - - axis = -len(normalized_shape) - ln = g.op( - "LayerNormalization", - inputs, - weight, - bias, - epsilon_f=eps, - axis_i=axis, - # This sets the LN compute precision - use FP32 always as does TE. 
- stash_type_i=_C_onnx.TensorProtoDataType.FLOAT, - ) - return ln - - -@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_rmsnorm_fwd_fp8( - g, - inp, - weight, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for rmsnorm_fwd_fp8""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma) - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "v", "f", "i", "b") -def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma): - """ONNX graph for rmsnorm_fwd""" - # pylint: disable=unused-argument - - # Check dimensions - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp) - assert ndim is not None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - axis = -len(normalized_shape) - - # Cast input tensors to FP32 if needed - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - if get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT: - weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - # Adjust zero-centered weights - if zero_centered_gamma: - one = _ones_like(g, weight, torch.float32) - weight = g.op("Add", weight, one) - - # Perform compute in FP32 - sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis]) - shape = g.op("Shape", inp, start_i=-1) - shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) - mean_squared = g.op("Div", sum_square, shape_f) - eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) - rms_squared = g.op("Add", mean_squared, eps_tensor) - rms_eps = g.op("Sqrt", rms_squared) - normalized_input = g.op("Div", inp, rms_eps) - out = g.op("Mul", weight, normalized_input) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER) -register_custom_op_symbolic("tex_ts::cast_to_fp8_noalloc_ts", onnx_cast_to_fp8_noalloc, VER) -register_custom_op_symbolic("tex_ts::cast_from_fp8_ts", onnx_cast_from_fp8, VER) -register_custom_op_symbolic("tex_ts::gelu_ts", onnx_fp8_gelu, VER) -register_custom_op_symbolic("tex_ts::relu_ts", onnx_fp8_relu, VER) -register_custom_op_symbolic("tex_ts::reglu_ts", onnx_fp8_reglu, VER) -register_custom_op_symbolic("tex_ts::geglu_ts", onnx_fp8_geglu, VER) -register_custom_op_symbolic("tex_ts::swiglu_ts", onnx_fp8_swiglu, VER) -register_custom_op_symbolic("tex_ts::te_gemm_ts", onnx_te_gemm, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_fp8_inf_ts", onnx_layernorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_inf_ts", onnx_layernorm_fwd, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_fp8_inf_ts", onnx_rmsnorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_inf_ts", onnx_rmsnorm_fwd, VER) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index aceaaf5d10..610ec2a777 100644 
--- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,10 +6,12 @@ import torch -from .float8_tensor import Float8Tensor -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer -__all__ = ["Float8Tensor", "QuantizedTensor"] +__all__ = [ + "QuantizedTensor", + "Quantizer", +] def _make_module_cast_func(dtype): @@ -22,14 +24,8 @@ def _make_module_cast_func(dtype): def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor: """Cast tensor dtype""" - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - requires_grad=tensor.requires_grad, - ) + if isinstance(tensor, QuantizedTensor): + return tensor.__class__.make_like(tensor, dtype=dtype) if tensor.is_floating_point(): return getattr(tensor, cast_func_name)() return tensor diff --git a/tests/paddle/test_sanity_import.py b/transformer_engine/pytorch/tensor/_internal/__init__.py similarity index 69% rename from tests/paddle/test_sanity_import.py rename to transformer_engine/pytorch/tensor/_internal/__init__.py index 0390f2f6a0..e13014bf75 100644 --- a/tests/paddle/test_sanity_import.py +++ b/transformer_engine/pytorch/tensor/_internal/__init__.py @@ -1,7 +1,4 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. - -import transformer_engine.paddle - -print("OK") +"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py new file mode 100644 index 0000000000..6b816db3b5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8Tensor""" + +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: Float8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._data is not None: + # Cast from FP8 + return tex.dequantize(tensor, dtype) + + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class Float8TensorBase: + """Mixin class that holds data attributes of Float8Tensor. + + Float8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. 
It should only + be instantiated directly for performance-critical internal usage. + + """ + + _data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _scale_inv: torch.Tensor + + # FP8 transpose cache + _transpose: Optional[torch.Tensor] + _transpose_invalid: bool + + def __new__( + cls, + *args, + data: Optional[torch.Tensor], + fp8_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + data_transpose: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + if cls is Float8TensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) + instance._data = data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._scale_inv = fp8_scale_inv + instance._transpose = data_transpose + instance._transpose_invalid = instance._transpose is None + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "data": self._data, + "fp8_scale_inv": self._scale_inv, + "fp8_dtype": self._fp8_dtype, + "data_transpose": self._transpose, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. + + """ + tensors = [self._data, self._transpose] + # self._data = None + # self._transpose = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list""" + self._data = tensors[0] + self._transpose = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._data, self._transpose + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromFloat8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._data.size(*args, **kwargs) + + def __repr__(self): + return ( + "Float8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py new file mode 100644 index 0000000000..d78bd55d9a --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Mixin class holding data specific for MXFP8Tensor""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromMXFP8Func(torch.autograd.Function): + """Cast from MXFP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MXFP8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, dtype) + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class MXFP8TensorBase: + """Mixin class that holds data attributes of MXFP8Tensor. + + MXFP8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. 
+ + """ + tensors = [self._rowwise_data, self._columnwise_data] + # self._rowwise_data = None + # self._columnwise_data = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromMXFP8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._rowwise_data.size(*args, **kwargs) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index d356df58dc..da788182a0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,25 +4,18 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple, Iterable import warnings import torch import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..constants import TE_DType as torch_to_transformer_engine_dtype -from ..cpp_extensions import ( - cast_from_fp8, - cast_to_fp8, - fp8_cast_transpose_fused, -) -from ..fp8 import FP8GlobalStateManager -from ..utils import devices_match -from .quantized_tensor import QuantizedTensor +from ..utils import devices_match, non_tn_fp8_gemm_supported +from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc aten = torch.ops.aten -updated_fp8_params = {} _ops_to_preserve_subclass_in_fsdp2 = { torch.ops.aten.empty_like.default, @@ -38,265 +31,142 @@ } -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute +class Float8Quantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor delayed scaling - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is also computed, which can be used + for updating the scaling factor (handled externally by + DelayedScalingRecipeState and FP8GlobalStateManager). 
""" - def get_func(self) -> Any: - return self._fp8_attrs[name] + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value + def __init__( + self, + scale: torch.Tensor, + amax: torch.Tensor, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = scale + self.amax = amax + self.dtype = fp8_dtype - def del_func(self) -> None: - del self._fp8_attrs[name] + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8Quantizer can only update Float8Tensor") - return {"fget": get_func, "fset": set_func, "fdel": del_func} + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" + # Update FP8 dtype + dst._fp8_dtype = self.dtype - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) + return dst - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, ) -> Float8Tensor: - # pylint: disable=missing-function-docstring - # Tensor attributes - dtype = tensor.dtype - if dtype not in (torch.float32, torch.bfloat16, torch.float16): - dtype = torch.float32 - device = tensor.device - if device.type != "cuda": + # Canonicalize tensor attributes + if device is None: device = torch.device("cuda") - # FP8 data buffer - if data is None: - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) - - # Check scale - if scale is None and fp8_meta is None: - scale = torch.full([1], 1, dtype=torch.float32, device=device) - if scale is not None: - scale = scale.to(device=device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=device) - elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: - scale_inv = 
scale_inv.to(device=device, dtype=torch.float32) + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) - # Transpose cache - if data_transpose is None and with_transpose_cache: + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) data_transpose = torch.empty( - (data.size(-1), data.numel() // data.size(-1)), + inner_dim, + data.numel() // inner_dim, dtype=torch.uint8, - device=tensor.device, + device=device, ) # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, + return Float8Tensor( + shape=shape, dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, data_transpose=data_transpose, + quantizer=self, ) - # Cast to FP8 tensor - out.quantize_(tensor, scale=scale, amax=amax) - - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None, None, None - + def calibrate(self, tensor: torch.Tensor) -> None: + amin, amax = tensor.aminmax() + self.amax.copy_(torch.max(-amin, amax)) -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = { - "data": tensor._data, - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - # pylint: disable=missing-function-docstring - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the Float8Tensor using the provided shape. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """Create Float8Tensor from raw uint8 data""" + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, ) - return dgrad, None - return grad.reshape(ctx.shape), None + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) -class Float8Tensor(QuantizedTensor): +class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -306,256 +176,69 @@ class Float8Tensor(QuantizedTensor): Parameters ---------- + shape: int or iterable of int + Tensor dimensions. + dtype: torch.dtype + Nominal tensor datatype. + requires_grad: bool, optional = False + Whether to compute gradients for this tensor. data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. 
- fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 - FP8 format. + Raw FP8 data in a uint8 tensor fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. + Reciprocal of the scaling factor applied when casting to FP8, + i.e. the scaling factor that must be applied when casting from + FP8 to higher precision. + fp8_dtype: transformer_engine_torch.DType + FP8 format. + data_transpose: torch.Tensor, optional + FP8 transpose data in a uint8 tensor + quantizer: Float8Quantizer, optional + Builder class for FP8 tensors """ - _data: torch.Tensor - _fp8_attrs: Dict[str, Any] - _fp8_meta: Optional[Dict[str, Any]] - _fp8_meta_forward: bool - _fp8_meta_index: Optional[int] - _fp8_dtype: TE_DType - _scale_inv: torch.Tensor - - # FP8 transpose cache - _transpose: Optional[torch.Tensor] - _transpose_invalid: bool - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - requires_grad: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=requires_grad, - device=data.device, - ) - self._data = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. - if fp8_attrs is None: - self._fp8_attrs = {} - else: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta = fp8_meta - self._fp8_meta_forward = fp8_meta_forward - self._fp8_meta_index = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - TE_DType.kFloat8E4M3, - TE_DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." 
- self._fp8_dtype = fp8_dtype - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if ( - not devices_match(fp8_scale_inv.device, self._data.device) - or fp8_scale_inv.dtype != torch.float32 - ): - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv = fp8_scale_inv - - # FP8 transpose cache - self._transpose = data_transpose - self._transpose_invalid = self._transpose is None - - return self - - def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument - """ - A hook function used in torch fsdp2, called before all-gather - return (all-gather input), (metadata) - Ref: https://github.com/pytorch/pytorch/pull/122908 - - """ - - return (self._data,), (self,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, # pylint: disable=unused-argument - *, - out: Optional[torch.Tensor] = None, - ): - """ - A hook function used in torch fsdp2, called after all-gather - return (Float8Tensor class instance of all-gathered input), (Things to free after forward) - Ref: https://github.com/pytorch/pytorch/pull/122908 - - """ - (data,) = all_gather_outputs - (sample,) = metadata - if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - return None - return Float8Tensor.make_like(sample, data=data), all_gather_outputs - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = { - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): + def __repr__(self, *, tensor_contents=None): return ( "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" + f"data={self.dequantize(dtype=self.dtype)}" ")" ) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. 
+ """ # Convert PyTorch dtype to TE dtype if dtype is None: dtype = self.dtype - dtype = torch_to_transformer_engine_dtype[dtype] - # Make sure FP8 data is in expected format - data = self._data - if data.device.type != "cuda": - data = data.cuda() - if not data.is_contiguous(): - data = data.contiguous() - if data.dim() != 2: - data = data.view(1, -1) - - # Cast from FP8 - out = cast_from_fp8( - data.view(1, -1), - None, # fp8_meta_tensor - None, # fp8_tensor - self._fp8_dtype, - dtype, - scale_inv=self._scale_inv, - ) + if torch.is_grad_enabled(): + return _FromFloat8Func.apply(self, dtype) + return _FromFloat8Func.forward(None, self, dtype) - # Make sure output is in expected format - if out.size() != self.size(): - out = out.view(self.size()) - return out + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor + Quantizer can be used for in-place operations. - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. """ - return _FromFloat8Func.apply(self, dtype) + if self._quantizer is not None: + return self._quantizer + return Float8Quantizer( + scale=torch.reciprocal(self._scale_inv), + amax=torch.empty(1, dtype=torch.float32, device=self.device), + fp8_dtype=self._fp8_dtype, + ) def quantize_( self, tensor: torch.Tensor, *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, ) -> Float8Tensor: """Update FP8 data @@ -564,181 +247,47 @@ def quantize_( ---------- tensor: torch.Tensor Tensor to copy from - scale: torch.Tensor, optional - Scaling factor to use for FP8 quantization - amax: torch.Tensor, optional - History of maximum absolute values. The first entry will - be updated with the absmax of `tensor`. 
noop_flag: torch.Tensor, optional float32 flag indicating whether to avoid performing update """ - src = tensor - dst = self - - # In-place operations invalidate transpose cache - self._reset_caches() - - # Special logic if other tensor is Float8Tensor - if isinstance(src, Float8Tensor): - - # Cast to plain tensor if FP8 dtypes don't match - if dst._fp8_dtype != src._fp8_dtype: - return dst.quantize_(src.dequantize()) - - # Directly copy FP8 data - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.detach()) - if amax is not None or dst._fp8_meta is not None: - src_amax: torch.Tensor - if src._fp8_meta is None: - src_min, src_max = src.dequantize().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - dst_amax: torch.Tensor - if amax is None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - else: - dst_amax = amax - if dst_amax.dim() > 0: - dst_amax = dst_amax[tuple([0] * dst_amax.dim())] - torch.maximum(src_amax, dst_amax, out=dst_amax) - if dst._transpose is not None: - if src._transpose is None: - dst.transpose_2d(force_compute=True, fill_cache=True) - else: - dst._transpose.copy_(src._transpose) - dst._transpose_invalid = False - return self + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self - # Convert QuantizedTensor to plain tensor - if isinstance(src, QuantizedTensor): - return dst.quantize_(src.dequantize()) + def detach(self) -> Float8Tensor: + # pylint: disable=missing-function-docstring + return Float8Tensor.make_like(self) - # Make sure input is in expected format - if src.size() != dst.size(): - src = src.expand(dst.size()) - if not devices_match(src.device, dst.device): - src = src.to(device=dst.device) - if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): - src = src.float() - if not src.is_contiguous(): - src = src.contiguous() + def _create_transpose(self): + data = self._data + if not data.is_contiguous(): + data = data.contiguous() + self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose) + self._transpose_invalid = False - # Make sure FP8 scaling factors are in expected format - if scale is not None: - if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: - scale = scale.to(device=dst.device, dtype=torch.float32) - if amax is not None: - while amax.dim() < 2: - amax = amax.unsqueeze(0) - if not devices_match(amax.device, dst.device): - raise ValueError( - f"Invalid device for amax (expected {dst.device}, found {amax.device})" - ) - if amax.dtype != torch.float32: - raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") - - # Default FP8 scaling factors - fp8_meta = None - if dst._fp8_meta is None: - if scale is None: - scale = dst._scale_inv.reciprocal() - if amax is None: - amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor" + if rowwise_usage: + assert self._data 
is not None, "Rowwise usage of the tensor was already disabled" else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta = dst._fp8_meta[fp8_meta_key] - - # Check local data - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - - # Perform FP8 cast - if dst._transpose is None: - dst_data = dst._data - if src.dim() != 2: - src = src.view(1, -1) - dst_data = dst_data.view(1, -1) - cast_to_fp8( - src, - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - out=dst_data, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - ) + if not non_tn_fp8_gemm_supported(): + if self._transpose is None or self._transpose_invalid: + self._create_transpose() + self._data = None + if columnwise_usage: + if self._transpose is None or self._transpose_invalid: + assert self._data is not None, "The tensor does not hold any data anymore" + if not non_tn_fp8_gemm_supported(): + self._create_transpose() else: - fp8_cast_transpose_fused( - src.view(-1, src.size(-1)), - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - cast_out=dst._data, - transpose_out=dst._transpose, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - noop_flag=noop_flag, - ) - dst._transpose_invalid = False - - return self - - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - data, - scale, - amax, - scale_inv, - with_transpose_cache, - data_transpose, - ) - - def detach(self) -> Float8Tensor: - # pylint: disable=missing-function-docstring - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - ) + self._transpose = None + self._transpose_invalid = True def clone(self) -> Float8Tensor: # pylint: disable=missing-function-docstring + assert self._data is not None data = self._data.detach().clone() data_transpose = None if self._transpose is not None: @@ -761,7 +310,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8Tensor: def contiguous( self, - *, memory_format: torch.memory_format = torch.contiguous_format, ) -> Float8Tensor: """Returns tensor with data in provided memory format @@ -769,148 +317,15 @@ def contiguous( Returns `self` if data is already in correct memory format. """ - if self._data.is_contiguous(memory_format=memory_format): + if self._data is not None and self._data.is_contiguous(memory_format=memory_format): return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. 
- fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated - - """ - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = ( - force_compute - or (self._transpose is None) - or self._transpose_invalid - or (noop_flag is not None) - ) - - # Return cached transpose if possible - if not need_compute: - assert self._transpose is not None - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out: Optional[torch.Tensor] = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Tensor is reshaped as a 2D matrix. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. - - """ - if self._transpose is None: - self._transpose = torch.empty( - (self.size(-1), self.numel() // self.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self.quantize_(tensor, noop_flag=noop_flag) - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. + if self._transpose is not None and self._transpose.is_contiguous( + memory_format=memory_format + ): + return self + return Float8Tensor.make_like(tensor=self, data=self._data.contiguous()) - """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) + # raise ValueError("Float8Tensor does not support different memory formats!") def _reset_caches(self) -> None: """ @@ -919,32 +334,55 @@ def _reset_caches(self) -> None: """ self._transpose_invalid = True + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + self._data = torch.Tensor() if self._data is not None else None + self._transpose = torch.Tensor() if self._transpose is not None else None + self._transpose_invalid = True + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # Slice op - if func == aten.slice.Tensor: + # View op + if func == aten.view.default: tensor = args[0] data = tensor._data - data_slice = data.__torch_dispatch__( + out_data = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_slice) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if ( + out_transpose_shape[0] != out_shape[-1] + or out_transpose_shape[1:] != out_shape[:-1] + ): + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=False, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) - # View op - if func == aten.view.default: + if func in [aten.slice.Tensor, aten.select.int]: tensor = args[0] data = tensor._data - data_view = data.__torch_dispatch__( + data_slice = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_view) + return Float8Tensor.make_like(tensor, data=data_slice, shape=data_slice.shape) # Related to FSDP2 if func == aten.split.Tensor: @@ -982,8 +420,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.clone.default: return cls.clone(args[0]) if func == torch.ops.aten.copy_.default: - # Implementation in the superclass (QuantizedTensor) returns a proper output - pass + dst, src = args[0], args[1] + # Just copy FP8 attrs if copying between Float8Tensors + if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) + if src._transpose is not None or dst._transpose is not None: + dst._create_transpose() + return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 warnings.warn( @@ -1002,6 +446,7 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, fp8_scale_inv: torch.Tensor, dtype: torch.dtype, + shape: torch.shape, ) -> Float8Tensor: """Build Float8Tensor, for use in __reduce__ @@ -1014,13 +459,14 @@ def _make_in_reduce_ex( fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), ) def _get_data(self) -> Float8Tensor: @@ -1039,12 +485,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device - # Check whether grad is required - if self.requires_grad != tensor.requires_grad: - self.requires_grad_(requires_grad=tensor.requires_grad) - # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): + + # PyTorch tensor attributes if ( # pylint: 
disable=too-many-boolean-expressions self.size() != tensor.size() or self.stride() != tensor.stride() @@ -1065,57 +509,110 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + + # Float8Tensor attributes self._data = tensor._data - self._fp8_attrs = tensor._fp8_attrs + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._scale_inv = tensor._scale_inv + self._transpose = tensor._transpose + self._transpose_invalid = tensor._transpose_invalid return - # Reallocate FP8 data if needed - if ( - self.size() != tensor.size() - or self.stride() != tensor.stride() - or self.dtype != tensor.dtype - or self.layout != tensor.layout - or not devices_match(self.device, new_device) - ): - self._data = torch.empty_like( - tensor, - dtype=torch.uint8, - device=new_device, - ) - dummy_tensor = torch.Tensor._make_wrapper_subclass( - Float8Tensor, - self._data.size(), - strides=self._data.stride(), - storage_offset=self._data.storage_offset(), - dtype=tensor.dtype, - layout=self._data.layout, - requires_grad=tensor.requires_grad, - device=self._data.device, - ) - # pylint: disable=unnecessary-dunder-call - super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) - if self._transpose is not None: - self._transpose = torch.empty( - (self._data.size(-1), self._data.numel() // self._data.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self._transpose_invalid = True - - # Copy values from other tensor - self.quantize_(tensor) + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data) - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. 
+ + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Optional[list[int]] = None, + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.view(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Tuple[int], + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.reshape(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py new file mode 100644 index 0000000000..86b13415a1 --- /dev/null +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -0,0 +1,582 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..utils import devices_match, round_up_to_nearest_multiple + +from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +class MXFP8Quantizer(Quantizer): + """Builder class for FP8 tensors with MX block scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + dividing them into groups of 32 elements, each scaled and cast + separately using current data. 
+ + """ + + dtype: TE_DType + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp8_dtype + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, MXFP8Tensor), f"Cannot store quantized MXFP8 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> MXFP8Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert ( + shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 + and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 + ), ( + f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" + f" {MXFP8_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty_like(data) + columnwise_scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) + + # Construct FP8 tensor + return MXFP8Tensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # TODO(ksivamani): No calibration needed for mxfp8? + pass + + +class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __repr__(self, *, tensor_contents=None): + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from MXFP8Tensor + + By default the resulting tensor's dtype is the + MXFP8Tensor's nominal dtype. 
+ """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromMXFP8Func.apply(self, dtype) + return _FromMXFP8Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return MXFP8Quantizer( + fp8_dtype=self._fp8_dtype, + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> MXFP8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return MXFP8Tensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """ + For MXFP8, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. + """ + assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor." + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." + self._rowwise_data = None + self._rowwise_scale_inv = None + return + + def clone(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> MXFP8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("MXFP8Tensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return MXFP8Tensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> MXFP8Tensor: + """Build MXFP8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + MXFP8Tensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> MXFP8Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a MXFP8Tensor. Otherwise + casts to FP8. 
+ + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Just copy FP8 data if other tensor is MXFP8Tensor + if isinstance(tensor, MXFP8Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + MXFP8Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting MXFP8Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.view(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + return MXFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_data = ( + grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None + ) + if grad._columnwise_data is not None: + new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1) + else: + new_columnwise_data = None + dgrad = MXFP8Tensor( + ctx.shape, + 
grad.dtype, + rowwise_data=new_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.reshape(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + + return MXFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + new_rowwise_data = grad._rowwise_data.view(*ctx.shape) + if grad._columnwise_data is not None: + columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1]) + new_columnwise_data = grad._columnwise_data.view(columnwise_shape) + dgrad = MXFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 550e113389..707382696d 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -5,23 +5,192 @@ """Tensor with quantized data""" from __future__ import annotations -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable, Any, Dict, Union +import abc +import copy import torch from torch.utils._pytree import tree_map +import transformer_engine_torch as tex + + +def prepare_for_saving( + *tensors, +) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]: + """Prepare tensors for saving. 
Needed because save_for_backward accepts only + torch.Tensor/torch.nn.Parameter types, while we want to be able to save + the internal TensorBase types too.""" + # pylint: disable=unidiomatic-typecheck # Using type instead of isinstance to check exact type + tensor_list, tensor_objects_list = [], [] + for tensor in tensors: + if tensor is None: + tensor_list.append(None) + tensor_objects_list.append(None) + elif type(tensor) in (torch.Tensor, torch.nn.Parameter): + tensor_list.append(tensor.data) + tensor_objects_list.append(None) + else: + t, t_obj = tensor.prepare_for_saving() + tensor_list.extend(t) + tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list + + +def restore_from_saved( + tensors: list[Optional[Any]], + saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], +) -> list[Optional[Any]]: + """Recombine the tensor data and metadata during backward pass.""" + tensor_objects = [] + for tensor in tensors: + if tensor is None: + tensor_objects.append(saved_tensors[0]) + saved_tensors = saved_tensors[1:] + else: + saved_tensors = tensor.restore_from_saved(saved_tensors) + tensor_objects.append(tensor) + return tensor_objects + + +class Quantizer(abc.ABC): + """Builder class for quantized tensors. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). -class _DequantizeFunc(torch.autograd.Function): - """Autograd function to convert quantized tensor to standard tensor""" + """ + + """Whether to construct quantized tensors with "row-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A * + B^T (used in linear forward). Tensor Cores prefer "TN GEMMs" (in + Fortran-style column-major order), so A and B should be in + row-major order. + + """ + rowwise_usage: bool + + """Whether to construct quantized tensors with "column-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A^T + * B (used in linear backward wgrad). Tensor Cores prefer "TN + GEMMs" (in Fortran-style column-major order), so A and B should be + in column-major order. + + """ + columnwise_usage: bool + + """Whether to instantiates tensor for purely internal usage + + Internal tensors are storage classes with minimal logic. They have + less overhead than PyTorch tensor sub-classes, but are not + compatible with PyTorch's autograd infrastructure nor PyTorch + operations. 
+ + """ + internal: bool + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + self.rowwise_usage = rowwise + self.columnwise_usage = columnwise + self.internal = False + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"rowwise_usage={self.rowwise_usage}, " + f"columnwise_usage={self.columnwise_usage}, " + f"internal={self.internal}, " + ")" + ) + + @abc.abstractmethod + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Quantize tensor in-place""" + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + ) -> QuantizedTensor: + """Quantize tensor""" + if out is not None: + return self.update_quantized(tensor, out) + if (not self.internal) and torch.is_grad_enabled(): + return _QuantizeFunc.apply(tensor, self) + return _QuantizeFunc.forward(None, tensor, self) + + def multi_quantize(self, list_of_tensors): + """Quantize multiple tensors""" + list_of_output_tensors = [] + for tensor in list_of_tensors: + list_of_output_tensors.append(self.quantize(tensor)) + return list_of_output_tensors + + def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor""" + return self.quantize(tensor) + + @abc.abstractmethod + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Construct quantized tensor with uninitialized data""" + + @abc.abstractmethod + def calibrate(self, tensor: torch.Tensor) -> None: + """Calibrate quantizer state + + Updates quantization state as if quantizing a tensor, but + without actually performing the quantization. + + """ + + def set_usage( + self, + *, + rowwise: Optional[bool] = None, + columnwise: Optional[bool] = None, + ) -> None: + """Set how the quantized tensor is expected to be used + + See documentation for `rowwise_usage` and `columnwise_usage` + variables. + + """ + if rowwise is not None: + self.rowwise_usage = rowwise + if columnwise is not None: + self.columnwise_usage = columnwise + + def copy(self) -> Quantizer: + """Create shallow copy""" + return copy.copy(self) + + +class _QuantizeFunc(torch.autograd.Function): + """Cast to FP8 from other dtype""" @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: QuantizedTensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: torch.Tensor, + quantizer: Quantizer, + ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) + return tex.quantize(tensor, quantizer) @staticmethod def backward( @@ -29,27 +198,55 @@ def backward( grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision return grad, None class _IdentityFunc(torch.autograd.Function): - """Autograd function to create quantized tensor with same data""" + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. 
+ + """ @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused + ctx, tensor: QuantizedTensor, + init_kwargs: Optional[Dict[str, Any]] = None, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.detach() + + # Return input tensor if constructor kwargs are not provided + if init_kwargs is None: + return tensor.detach() + + # Construct new tensor if constructor kwargs are provided + ctx.input_dtype = tensor.dtype + kwargs = tensor.get_metadata() + for key, val in init_kwargs.items(): + kwargs[key] = val + return type(tensor)(tensor.shape, tensor.dtype, **kwargs) @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> torch.Tensor: + def backward(ctx, grad_output): # pylint: disable=missing-function-docstring - return grad + grad_input = grad_output + if grad_input.dtype == ctx.input_dtype: + grad_input = grad_input.detach() + else: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input, None + + +def _stride_from_shape(shape: list[int]): + if len(shape) == 0: + return [] + rstride = [1] + for d in reversed(shape[1:]): + rstride.append(rstride[-1] * d) + return list(reversed(rstride)) class QuantizedTensor(torch.Tensor): @@ -62,6 +259,22 @@ class QuantizedTensor(torch.Tensor): """ + def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + # We are assuming only contiguous tensors + stride = _stride_from_shape(shape) + instance = torch.Tensor._make_wrapper_subclass( + cls, + shape, + strides=stride, + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + ) + + return instance + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( @@ -85,24 +298,38 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement detach function" ) - def __repr__(self) -> str: + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """Indicate to the tensor how it is going to be used + + This enables optimizations to memory usage in some cases + where forward and backward passes use the tensor in + different directions. + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_usage function" + ) + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully""" + + def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" def float(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float32) + return self.dequantize(dtype=torch.float32) def bfloat16(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.bfloat16) + return self.dequantize(dtype=torch.bfloat16) def half(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float16) + return self.dequantize(dtype=torch.float16) - def cpu(self) -> torch.Tensor: + def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self).cpu() + return self.dequantize().cpu(memory_format=memory_format) def expand_as(self, other: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -179,3 +406,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement contiguous function" + ) + + def get_metadata(self) -> Dict[str, Any]: + """Get keyword arguments for quantized tensor constructor + + Contains metadata so that the new quantized tensor has the + same underlying quantized data. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_metadata function" + ) + + @classmethod + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. + + """ + shape = shape if shape is not None else tensor.shape + dtype = dtype if dtype is not None else tensor.dtype + kwargs = tensor.get_metadata() + if data is not None: + kwargs["data"] = data + return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) + + def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: + """Create `QuantizedTensor` with given nominal dtype + + The new tensor has the same underlying data. 
+ + """ + return self.__class__.make_like(self, dtype=dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7c3da9a73f..97b1361163 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -267,11 +267,11 @@ def __init__( zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, - ub_bulk_wgrad: bool = True, - ub_bulk_dgrad: bool = True, ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = True, + ub_bulk_wgrad: bool = True, bias: bool = True, activation: str = "gelu", normalization: str = "LayerNorm", diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 63b2f2cfb5..5b1bd82221 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,11 +6,13 @@ from __future__ import annotations import functools import math -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch import transformer_engine.pytorch.cpp_extensions as ext +from .tensor.quantized_tensor import QuantizedTensor + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" @@ -27,12 +29,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. """ - from .float8_tensor import Float8Tensor - for t in tensors: if t is not None: - if isinstance(t, Float8Tensor): - t._data.data = torch.Tensor() + if isinstance(t, QuantizedTensor): + t.clear() else: t.data = torch.Tensor() del t @@ -231,14 +231,15 @@ def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 -def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: - """Assert that tensor dimensions are supported for FP8 TN GEMM""" - # single tensor check so it's clear which tensor is triggering the assertion - assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( - "FP8 execution requires 2D input matrices with " - "height divisible by 8 and width divisible by 16, " - f"but got tensor with dims={list(tensor.size())}" - ) +def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: + """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" + + for tensor in tensors: + assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( + "FP8 execution requires 2D input matrices with " + "height divisible by 8 and width divisible by 16, " + f"but got tensor with dims={list(tensor.size())}" + ) def is_bf16_compatible() -> None: @@ -248,6 +249,13 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +def non_tn_fp8_gemm_supported() -> bool: + """Checks whether the device supports + non-TN layouts for FP8 GEMMs. 
+ """ + return torch.cuda.get_device_capability() >= (10, 0) + + @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" @@ -305,3 +313,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: index2 = torch.cuda.current_device() return index1 == index2 return device1 == device2 + + +@functools.lru_cache +def get_sm_count() -> int: + """Returns the number of streaming multiprocessors in the current device.""" + return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + + +def round_up_to_nearest_multiple(value, multiple): + """Round up `value` to the next mutiple of `multiple`""" + if multiple == 0: + raise ValueError("multiple cannot be zero.") + return ((value + multiple - 1) // multiple) * multiple
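
The `prepare_for_saving`/`restore_from_saved` helpers added to quantized_tensor.py exist because `ctx.save_for_backward` only accepts plain `torch.Tensor`/`torch.nn.Parameter` objects (or None), while the modules also need to stash the internal TensorBase storage classes. A minimal sketch of the round trip, under the assumption that a storage object exposes the same two hooks, is shown below; `TinyQuantizedStorage` is a hypothetical stand-in and is not part of this patch.

    import torch

    # Copied from the quantized_tensor.py hunk above: flatten a mixed list of
    # tensors and storage objects into a tensors-only list plus a metadata list.
    def prepare_for_saving(*tensors):
        tensor_list, tensor_objects_list = [], []
        for tensor in tensors:
            if tensor is None:
                tensor_list.append(None)
                tensor_objects_list.append(None)
            elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
                tensor_list.append(tensor.data)
                tensor_objects_list.append(None)
            else:
                t, t_obj = tensor.prepare_for_saving()
                tensor_list.extend(t)
                tensor_objects_list.append(t_obj)
        return tensor_list, tensor_objects_list

    def restore_from_saved(tensors, saved_tensors):
        tensor_objects = []
        for tensor in tensors:
            if tensor is None:
                tensor_objects.append(saved_tensors[0])
                saved_tensors = saved_tensors[1:]
            else:
                saved_tensors = tensor.restore_from_saved(saved_tensors)
                tensor_objects.append(tensor)
        return tensor_objects

    class TinyQuantizedStorage:
        """Hypothetical stand-in for the internal TensorBase storage classes."""

        def __init__(self, data):
            self.data = data

        def prepare_for_saving(self):
            # Hand the raw tensor to autograd; keep `self` as the metadata object.
            out, self.data = self.data, None
            return [out], self

        def restore_from_saved(self, saved_tensors):
            # Reclaim exactly the tensors handed out; return the remainder.
            self.data = saved_tensors[0]
            return saved_tensors[1:]

    x = torch.ones(2)
    q = TinyQuantizedStorage(torch.zeros(2))
    flat, objs = prepare_for_saving(x, None, q)  # flat holds only tensors / None
    rebuilt = restore_from_saved(objs, flat)     # [x.data, None, q], q.data re-attached

As the docstring above notes, the point is that only raw tensors flow through `save_for_backward`, while the lightweight metadata objects are rebuilt around them in the backward pass.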