diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 964e71fa8c..4be7a30a86 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -73,23 +73,3 @@ jobs:
MAX_JOBS: 1
- name: 'Sanity check'
run: python tests/jax/test_sanity_import.py
- paddle:
- name: 'PaddlePaddle'
- runs-on: ubuntu-latest
- container:
- image: nvcr.io/nvidia/paddlepaddle:24.10-py3
- options: --user root
- steps:
- - name: 'Checkout'
- uses: actions/checkout@v3
- with:
- submodules: recursive
- - name: 'Build'
- run: |
- apt-get update
- apt-get install -y libgoogle-glog-dev
- pip install . -v
- env:
- NVTE_FRAMEWORK: paddle
- - name: 'Sanity check'
- run: python tests/paddle/test_sanity_import.py
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index f98fc9aa3a..ee6433d484 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -61,30 +61,3 @@ jobs:
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh
- paddle_cpplint:
- name: 'PaddlePaddle C++'
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v3
- - name: 'Lint'
- run: |
- sudo apt-get update
- sudo apt-get install pip -y
- export CPP_ONLY=1
- export TE_PATH=.
- bash ./qa/L0_paddle_lint/test.sh
- paddle_pylint:
- name: 'PaddlePaddle Python'
- runs-on: ubuntu-latest
- steps:
- - name: 'Checkout'
- uses: actions/checkout@v3
- - name: 'Lint'
- run: |
- sudo apt-get update
- sudo apt-get install pip -y
- pip install paddlepaddle-gpu
- export PYTHON_ONLY=1
- export TE_PATH=.
- bash ./qa/L0_paddle_lint/test.sh
diff --git a/.gitignore b/.gitignore
index 9b61454e21..f491b21f43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,6 @@
*.nsys-rep
*.ncu-rep
*.sqlite
-*.onnx
*.eggs
build/
*.so
diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index cc5632eda7..91b7532f33 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699
+Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a
diff --git a/README.rst b/README.rst
index fbcf05f3c9..8fea8c9d94 100644
--- a/README.rst
+++ b/README.rst
@@ -174,7 +174,7 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
-This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
+This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g.
@@ -182,7 +182,7 @@ Alternatively, the package can be directly installed from `Transformer Engine's
pip install transformer_engine[pytorch]
-To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
+To obtain the necessary Python bindings for Transformer Engine, the required frameworks must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
From source
^^^^^^^^^^^
diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index 809a0327d8..eb5820cd2d 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-1.14.0.dev0
+2.1.0.dev0
diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py
index 5744439c1b..a3243d087b 100644
--- a/build_tools/build_ext.py
+++ b/build_tools/build_ext.py
@@ -129,63 +129,6 @@ def run(self) -> None:
super().run()
self.extensions = all_extensions
- paddle_ext = None
- if "paddle" in get_frameworks():
- for ext in self.extensions:
- if "paddle" in ext.name:
- paddle_ext = ext
- break
-
- # Manually write stub file for Paddle extension
- if paddle_ext is not None:
- # Load libtransformer_engine.so to avoid linker errors
- if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
- # Source compilation from top-level (--editable)
- search_paths = list(Path(__file__).resolve().parent.parent.iterdir())
- # Source compilation from top-level
- search_paths.extend(list(Path(self.build_lib).iterdir()))
-
- # Dynamically load required_libs.
- from transformer_engine.common import _load_cudnn, _load_nvrtc
-
- _load_cudnn()
- _load_nvrtc()
- else:
- # Only during release bdist build for paddlepaddle.
- import transformer_engine
-
- search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
- del transformer_engine
-
- common_so_path = ""
- for path in search_paths:
- if path.name.startswith("libtransformer_engine."):
- common_so_path = str(path)
- assert common_so_path, "Could not find libtransformer_engine"
- ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL)
-
- # Figure out stub file path
- module_name = paddle_ext.name
- assert module_name.endswith(
- "_pd_"
- ), "Expected Paddle extension module to end with '_pd_'"
- stub_name = module_name[:-4] # remove '_pd_'
- stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
- Path(stub_path).parent.mkdir(exist_ok=True, parents=True)
-
- # Figure out library name
- # Note: This library doesn't actually exist. Paddle
- # internally reinserts the '_pd_' suffix.
- so_path = self.get_ext_fullpath(module_name)
- _, so_ext = os.path.splitext(so_path)
- lib_name = stub_name + so_ext
-
- # Write stub file
- print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
- from paddle.utils.cpp_extension.extension_utils import custom_write_stub
-
- custom_write_stub(lib_name, stub_path)
-
# Ensure that binaries are not in global package space.
target_dir = install_dir / "transformer_engine"
target_dir.mkdir(exist_ok=True, parents=True)
@@ -194,16 +137,10 @@ def run(self) -> None:
self.copy_file(ext, target_dir)
os.remove(ext)
- # For paddle, the stub file needs to be copied to the install location.
- if paddle_ext is not None:
- stub_path = Path(self.build_lib) / "transformer_engine"
- for stub in stub_path.glob("transformer_engine_paddle.py"):
- self.copy_file(stub, target_dir)
-
def build_extensions(self):
- # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly
+ # BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
- if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks():
+ if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
diff --git a/build_tools/paddle.py b/build_tools/paddle.py
deleted file mode 100644
index f0fcdb8f25..0000000000
--- a/build_tools/paddle.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-"""Paddle-paddle related extensions."""
-from pathlib import Path
-
-import setuptools
-import os
-
-from .utils import cuda_version
-
-import paddle
-
-paddle_version = paddle.__version__.replace(".", "")
-
-
-def setup_paddle_extension(
- csrc_source_files,
- csrc_header_files,
- common_header_files,
-) -> setuptools.Extension:
- """Setup CUDA extension for Paddle support"""
-
- # Source files
- csrc_source_files = Path(csrc_source_files)
- sources = [
- csrc_source_files / "extensions.cpp",
- csrc_source_files / "common.cpp",
- csrc_source_files / "custom_ops.cu",
- ]
-
- # Header files
- include_dirs = [
- common_header_files,
- common_header_files / "common",
- common_header_files / "common" / "include",
- csrc_header_files,
- ]
-
- # Compiler flags
- cxx_flags = ["-O3"]
- nvcc_flags = [
- "-O3",
- "-gencode",
- "arch=compute_70,code=sm_70",
- "-U__CUDA_NO_HALF_OPERATORS__",
- "-U__CUDA_NO_HALF_CONVERSIONS__",
- "-U__CUDA_NO_BFLOAT16_OPERATORS__",
- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
- "-U__CUDA_NO_BFLOAT162_OPERATORS__",
- "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
- f"-DPADDLE_VERSION={paddle_version}",
- "--expt-relaxed-constexpr",
- "--expt-extended-lambda",
- "--use_fast_math",
- ]
-
- # Version-dependent CUDA options
- try:
- version = cuda_version()
- except FileNotFoundError:
- print("Could not determine CUDA Toolkit version")
- else:
- if version < (12, 0):
- raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
- nvcc_flags.extend(
- (
- "--threads",
- os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
- "-gencode",
- "arch=compute_80,code=sm_80",
- "-gencode",
- "arch=compute_90,code=sm_90",
- )
- )
-
- # Construct Paddle CUDA extension
- sources = [str(path) for path in sources]
- include_dirs = [str(path) for path in include_dirs]
- from paddle.utils.cpp_extension import CUDAExtension
-
- ext = CUDAExtension(
- sources=sources,
- include_dirs=include_dirs,
- extra_compile_args={
- "cxx": cxx_flags,
- "nvcc": nvcc_flags,
- },
- )
- ext.name = "transformer_engine_paddle_pd_"
- return ext
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index f060e99dff..b8501e1008 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -27,7 +27,6 @@ def setup_pytorch_extension(
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
- csrc_source_files / "ts_fp8_op.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
diff --git a/build_tools/utils.py b/build_tools/utils.py
index f2a4200685..723f2f200c 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
- return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
+ version = cuda_version()
+ if os.getenv("NVTE_CUDA_ARCHS") is None:
+ os.environ["NVTE_CUDA_ARCHS"] = (
+ "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90"
+ )
+ return os.getenv("NVTE_CUDA_ARCHS")
def cuda_version() -> Tuple[int, ...]:
@@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]:
def get_frameworks() -> List[str]:
"""DL frameworks to build support for"""
_frameworks: List[str] = []
- supported_frameworks = ["pytorch", "jax", "paddle"]
+ supported_frameworks = ["pytorch", "jax"]
# Check environment variable
if os.getenv("NVTE_FRAMEWORK"):
@@ -237,12 +242,6 @@ def get_frameworks() -> List[str]:
pass
else:
_frameworks.append("jax")
- try:
- import paddle
- except ImportError:
- pass
- else:
- _frameworks.append("paddle")
# Special framework names
if "all" in _frameworks:
@@ -311,7 +310,6 @@ def uninstall_te_wheel_packages():
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
- "transformer_engine_paddle",
"transformer_engine_jax",
]
)
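Note on the cuda_archs() change above: when NVTE_CUDA_ARCHS is unset, the default architecture list now depends on the detected CUDA Toolkit version, adding Blackwell targets (sm_100/sm_120) only for CUDA 12.8 or newer. A minimal Python sketch of that selection logic, with a hypothetical probe_cuda_version() standing in for the repo's cuda_version() helper:

import os
from typing import Tuple

def probe_cuda_version() -> Tuple[int, ...]:
    """Hypothetical stand-in for build_tools.utils.cuda_version(); assume CUDA 12.8 here."""
    return (12, 8)

def pick_cuda_archs() -> str:
    """Default NVTE_CUDA_ARCHS based on the CUDA version, honoring a user override."""
    if os.getenv("NVTE_CUDA_ARCHS") is None:
        os.environ["NVTE_CUDA_ARCHS"] = (
            "70;80;89;90;100;120" if probe_cuda_version() >= (12, 8) else "70;80;89;90"
        )
    return os.environ["NVTE_CUDA_ARCHS"]

print(pick_cuda_archs())  # "70;80;89;90;100;120" under the assumed CUDA 12.8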
diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh
index ceebe626f4..9acb22aee6 100644
--- a/build_tools/wheel_utils/build_wheels.sh
+++ b/build_tools/wheel_utils/build_wheels.sh
@@ -9,7 +9,6 @@ BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true}
-BUILD_PADDLE=${6:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
@@ -63,38 +62,3 @@ if $BUILD_JAX ; then
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
-
-if $BUILD_PADDLE ; then
- if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then
- dnf -y remove --allowerasing cudnn9-cuda-12
- dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
- cd /TransformerEngine/transformer_engine/paddle
-
- /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
- /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
- /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
- /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
- /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
- /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- mv dist/* /wheelhouse/
- fi
-fi
diff --git a/docs/api/common.rst b/docs/api/common.rst
index 85201aee5d..5e0a660ae6 100644
--- a/docs/api/common.rst
+++ b/docs/api/common.rst
@@ -8,4 +8,4 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.Format
-.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
+.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)
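For reference, a hedged usage sketch matching the documented DelayedScaling signature above (override_linear_precision is gone from the 2.x signature); the argument values are illustrative, not recommendations:

from transformer_engine.common import recipe

# Construct the FP8 recipe using the arguments listed in the autoapiclass line above.
fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=1024,
    amax_compute_algo="max",
)
# The recipe is then passed to the framework-specific fp8_autocast context manager.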
diff --git a/docs/api/framework.rst b/docs/api/framework.rst
index acd54fe3b1..0ac1a0e34e 100644
--- a/docs/api/framework.rst
+++ b/docs/api/framework.rst
@@ -10,4 +10,3 @@ Framework-specific API
pytorch
jax
- paddle
diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst
deleted file mode 100644
index 3b3ecf55c6..0000000000
--- a/docs/api/paddle.rst
+++ /dev/null
@@ -1,34 +0,0 @@
-..
- Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
- See LICENSE for license information.
-
-paddle
-======
-
-.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
-
-.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
- :members: forward
-
-.. autoapifunction:: transformer_engine.paddle.fp8_autocast
-
-.. autoapifunction:: transformer_engine.paddle.recompute
diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index 43001feeb3..6d5fe6761d 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -42,8 +42,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint
-.. autoapifunction:: transformer_engine.pytorch.onnx_export
-
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index 27017b4773..16a3b05466 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -14,11 +14,10 @@
" Figure 1: Dot product attention. \n",
"\n",
"\n",
- "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n",
+ "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n",
"\n",
"- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
- "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
- "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
+ "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)"
]
},
{
@@ -56,15 +55,6 @@
"
\n",
" JAX-native attention (`_UnfusedDotProductAttention`) | \n",
"
\n",
- " \n",
- " PaddlePaddle | \n",
- " cuDNN attention (`_te_forward`) | \n",
- " [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n",
- " | \n",
- "
\n",
- " \n",
- " PaddlePaddle-native attention (`_pd_forward`) | \n",
- "
\n",
" \n",
""
]
@@ -87,7 +77,7 @@
"\n",
"Note: \n",
" \n",
- "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
+ "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n",
"
\n"
]
},
@@ -102,13 +92,13 @@
"\n",
"The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
"\n",
- "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
+ "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"\n",
"To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n",
"### 1.3 cuDNN Attention\n",
"\n",
- "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
+ "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"\n",
"\n",
" \n",
@@ -153,9 +143,9 @@
"
\n",
"
\n",
"\n",
- "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
+ "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"\n",
- "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
+ "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
"- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
@@ -244,10 +234,6 @@
" JAX | \n",
" cuDNN attention > JAX-native attention | \n",
" \n",
- " \n",
- " PaddlePaddle | \n",
- " cuDNN attention > PaddlePaddle-native attention | \n",
- "
\n",
""
]
},
@@ -266,7 +252,7 @@
"\n",
"Note:\n",
" \n",
- "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n",
+ "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n",
"
"
]
},
@@ -382,7 +368,7 @@
"\n",
"Note\n",
" \n",
- "Environment variables NVTE_FLASH_ATTN
, NVTE_FUSED_ATTN
, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
and NVTE_ALLOW_NONDETERMINISTIC_ALGO
are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
+ "Environment variables NVTE_FLASH_ATTN
, NVTE_FUSED_ATTN
, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
and NVTE_ALLOW_NONDETERMINISTIC_ALGO
are only supported in PyTorch, and will be added to JAX in the future.\n",
"
\n",
"\n",
"### 2.3 Example Tests\n",
@@ -399,7 +385,7 @@
"source": [
"## 3. Backend Support\n",
"\n",
- "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n",
+ "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n",
"\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
@@ -442,7 +428,7 @@
"**qkv_layout=thd_thd_thd:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n",
- "As of v1.10, Transformer Engine has the following support matrix.\n",
+ "As of v2.0, Transformer Engine has the following support matrix.\n",
"\n",
"\n",
" \n",
@@ -462,13 +448,13 @@
"
\n",
" \n",
" \n",
- " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
+ " JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
" | \n",
"
\n",
" \n",
" Framework-native attention | \n",
" `bshd`, `sbhd` | \n",
- " PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts | \n",
+ " PyTorch, JAX: 2 formats, i.e. 10 layouts | \n",
"
\n",
"
\n",
"\n",
@@ -492,7 +478,7 @@
"\n",
"- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
"\n",
- "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n",
+ "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n",
"\n",
"\n",
" \n",
@@ -512,21 +498,21 @@
"
\n",
" \n",
" Framework-native attention | \n",
- " All (PyTorch)`no_mask`, `causal`, `padding` (Jax, PaddlePaddle) | \n",
+ " All (PyTorch)`no_mask`, `causal`, `padding` (Jax) | \n",
"
\n",
" \n",
" | \n",
"
\n",
"
\n",
"\n",
- "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
+ "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n",
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
- "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
+ "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
@@ -566,7 +552,7 @@
"\n",
"### 3.3 Attention Bias\n",
"\n",
- "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n",
+ "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n",
"\n",
"\n",
" \n",
@@ -591,7 +577,7 @@
" cuDNN 8.9.6+: sm90 | \n",
"
\n",
" \n",
- " JAX, PaddlePaddle: `no_bias`, `post_scale_bias` | \n",
+ " JAX: `no_bias`, `post_scale_bias` | \n",
" ALiBi slopes: FP32 | \n",
" cuDNN 9.0+: sm80+ | \n",
"
\n",
@@ -620,7 +606,7 @@
"\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n",
- "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
+ "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n",
diff --git a/docs/installation.rst b/docs/installation.rst
index fae01c64fa..ee7afa9006 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -37,7 +37,7 @@ Transformer Engine can be directly installed from `our PyPI desired_test_accuracy
-
- @unittest.skipIf(
- paddle.device.cuda.get_device_capability() < (8, 0),
- "BF16 MNIST example requires Ampere+ GPU",
- )
- def test_te_bf16(self):
- """Test Transformer Engine with BF16"""
- self.args.use_te = True
- self.args.use_fp8 = False
- self.args.save_model = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
- @unittest.skipIf(not gpu_has_fp8, reason)
- def test_te_fp8(self):
- """Test Transformer Engine with FP8"""
- self.args.use_te = True
- self.args.use_fp8 = True
- self.args.save_model = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
- @unittest.skipIf(not gpu_has_fp8, reason)
- def test_te_fp8_calibration(self):
- """Test Transformer Engine with FP8 calibration"""
- self.args.use_te = True
- self.args.use_fp8 = False
- self.args.use_fp8_infer = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
-
-if __name__ == "__main__":
- train_and_evaluate(mnist_parser(None))
diff --git a/pylintrc b/pylintrc
index b80679d72c..4af0c6b427 100644
--- a/pylintrc
+++ b/pylintrc
@@ -2,7 +2,6 @@
extension-pkg-whitelist=flash_attn_2_cuda,
torch,
transformer_engine_torch,
- transformer_engine_paddle,
transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh
index 6eff047721..8e2e540293 100644
--- a/qa/L0_jax_unittest/test.sh
+++ b/qa/L0_jax_unittest/test.sh
@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2"
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh
deleted file mode 100644
index 1c26bd265b..0000000000
--- a/qa/L0_paddle_lint/test.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: "${TE_PATH:=/opt/transformerengine}"
-
-pip install cpplint==1.6.0 pylint==3.3.1
-if [ -z "${PYTHON_ONLY}" ]
-then
- cd $TE_PATH
- echo "Checking common API headers"
- cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
- echo "Checking C++ files"
- cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
- cpplint --recursive transformer_engine/paddle
-fi
-if [ -z "${CPP_ONLY}" ]
-then
- cd $TE_PATH
- echo "Checking Python files"
- pylint --recursive=y transformer_engine/common transformer_engine/paddle
-fi
diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh
deleted file mode 100644
index 9312f22ba4..0000000000
--- a/qa/L0_paddle_unittest/test.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -xe
-
-pip install pytest==8.2.1
-: ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/paddle
-pytest -Wignore -v $TE_PATH/examples/paddle/mnist
diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh
deleted file mode 100644
index 5116bdb5cf..0000000000
--- a/qa/L0_paddle_wheel/test.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: "${TE_PATH:=/opt/transformerengine}"
-
-# Install dependencies
-# Note: Need to install wheel locally since PaddlePaddle container
-# already contains APT install.
-pip install pydantic
-pip install --user wheel==0.44.0
-
-cd $TE_PATH
-pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
-
-VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
-WHL_BASE="transformer_engine-${VERSION}"
-
-# Core wheel.
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
-python -m wheel unpack dist/*
-sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
-python -m wheel pack ${WHL_BASE}
-rm dist/*.whl
-mv *.whl dist/
-NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
-pip install dist/*.whl --no-deps
-
-cd transformer_engine/paddle
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
-pip install dist/*
-
-python $TE_PATH/tests/paddle/test_sanity_import.py
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 793fa47259..dd7f95bce0 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
+NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
-pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index ee7c28ca5f..8ee0be1af5 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -8,8 +8,8 @@ set -e
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
+# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh
deleted file mode 100644
index 8e4ef03b8e..0000000000
--- a/qa/L1_pytorch_onnx_test/test.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: ${TE_PATH:=/opt/transformerengine}
-
-pip install pytest==8.2.1 onnxruntime==1.19.2
-
-# Build custom ONNX Runtime operators
-export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
-bash $CUSTOM_ORT_OPS_PATH/build.sh
-
-# Run tests
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh
index e63ba358a5..8ed3002214 100644
--- a/qa/L3_pytorch_FA_versions_test/test.sh
+++ b/qa/L3_pytorch_FA_versions_test/test.sh
@@ -12,7 +12,14 @@ pip install pytest==8.2.1
export MAX_JOBS=4
# Iterate over Flash Attention versions
-FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1)
+sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
+if [ $sm_arch -gt 90 ]
+then
+ FA_versions=(2.7.3)
+else
+ FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+fi
+
for fa_version in "${FA_versions[@]}"
do
@@ -21,10 +28,10 @@ do
then
pip install flash-attn==${fa_version}
else
- pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
- wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
+ wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
fi
# Run tests
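The hunk above now picks the flash-attn versions to test based on the GPU's SM architecture, probed with an inline Python one-liner. The same probe as a standalone sketch (assumes a CUDA-enabled torch build and a visible GPU):

import torch

major, minor = torch.cuda.get_device_capability(0)
sm_arch = major * 10 + minor

if sm_arch > 90:
    fa_versions = ["2.7.3"]                                        # newer than Hopper
else:
    fa_versions = ["2.1.1", "2.3.0", "2.4.1", "2.5.7", "2.7.3", "3.0.0b1"]
print(sm_arch, fa_versions)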
diff --git a/setup.py b/setup.py
index 643dd7a908..1d9818458e 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@
"""Installation script."""
import os
+import sys
import time
from pathlib import Path
from typing import List, Tuple
@@ -35,14 +36,13 @@
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
-elif "paddle" in frameworks:
- from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension
CMakeBuildExtension = get_build_ext(BuildExtension)
+archs = cuda_archs()
class TimedBdist(bdist_wheel):
@@ -57,7 +57,7 @@ def run(self):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
- cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
+ cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
- test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
+ test_reqs.extend(["numpy", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
- test_reqs.extend(["numpy", "praxis"])
- if "paddle" in frameworks:
- install_reqs.append("paddlepaddle-gpu")
- test_reqs.append("numpy")
+ # test_reqs.extend(["numpy", "praxis"])
+ test_reqs.extend(["numpy"])
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
@@ -135,7 +133,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
- "paddle": [f"transformer_engine_paddle=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
@@ -169,16 +166,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
current_file_path / "transformer_engine",
)
)
- if "paddle" in frameworks:
- from build_tools.paddle import setup_paddle_extension
-
- ext_modules.append(
- setup_paddle_extension(
- "transformer_engine/paddle/csrc",
- current_file_path / "transformer_engine" / "paddle" / "csrc",
- current_file_path / "transformer_engine",
- )
- )
# Configure package
setuptools.setup(
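In the release-build branch above, framework extras now map only to the PyTorch and JAX metapackages. A rough sketch of that mapping (with __version__ as a placeholder taken from build_tools/VERSION.txt):

__version__ = "2.1.0.dev0"

extras_require = {
    "pytorch": [f"transformer_engine_torch=={__version__}"],
    "jax": [f"transformer_engine_jax=={__version__}"],
}
# `pip install transformer_engine[pytorch]` then pulls in the matching
# transformer_engine_torch wheel; the paddle extra no longer exists.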
diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt
index d8c8d99fac..081cd14eb4 100644
--- a/tests/cpp/CMakeLists.txt
+++ b/tests/cpp/CMakeLists.txt
@@ -5,7 +5,11 @@
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
- set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+ set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+ else ()
+ set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
+ endif()
endif()
diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index 178dc5e8dd..ce78fcaae2 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -3,23 +3,33 @@
# See LICENSE for license information.
add_executable(test_operator
+ test_cast.cu
+ test_cast_dbias.cu
+ test_cast_dbias_dgelu.cu
+ test_cast_gated_swiglu.cu
+ test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
- test_cast_transpose.cu
+ test_cast_mxfp8.cu
+ test_dequantize_mxfp8.cu
test_transpose.cu
+ test_cast_transpose.cu
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_act.cu
test_normalization.cu
+ test_normalization_mxfp8.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
+ test_swizzle.cu
../test_common.cu)
+find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
-target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
-target_compile_options(test_operator PRIVATE -O2)
+target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
+target_compile_options(test_operator PRIVATE -O2 -fopenmp)
include(GoogleTest)
-gtest_discover_tests(test_operator)
+gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu
index cec997d078..4224f199f4 100644
--- a/tests/cpp/operator/test_act.cu
+++ b/tests/cpp/operator/test_act.cu
@@ -21,58 +21,6 @@
using namespace transformer_engine;
-namespace {
-
-// forward
-
-float gelu(const float x) {
- return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x)));
-}
-
-float silu(const float x) {
- return x / (1 + expf(-x));
-}
-
-float relu(const float x) {
- return x > 0 ? x : 0;
-}
-
-float srelu(const float x) {
- return x > 0 ? x * x : 0;
-}
-
-float qgelu(const float x) {
- return x / (1 + expf(-1.702f * x));
-}
-
-// backward
-
-float dgelu(const float x) {
- const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
- return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) +
- 0.5f * (1.f + tanh_out);
-}
-
-float dsilu(const float x) {
- const float sigmoid = 1.f / (1 + expf(-x));
- return x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-float drelu(const float x) {
- return x > 0.f ? 1.f : 0.f;
-}
-
-float dsrelu(const float x) {
- return fmaxf(2.f * x, 0.f);
-}
-
-float dqgelu(const float x) {
- const float sigmoid = 1.f / (1 + expf(-1.702f * x));
- return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-} // namespace
-
template
void compute_ref_act_cast(const IT *input_h,
OT *output_h,
@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h,
const size_t H) {
CT amax = 0.;
+ #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h,
const size_t N,
const size_t H) {
using CT = float;
+ #pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C
const int col = H * 2;
+ #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h
const int col = H * 2;
using CT = float;
+ #pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad = static_cast<CT>(grad_h[i * H + j]);
@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo::dtype;
DType otype = TypeInfo::dtype;
- Tensor input({ N, H }, itype);
- Tensor output({ N, H }, otype);
- Tensor igrad({ N, H }, itype);
- Tensor ograd({ N, H }, itype);
+ Tensor input("input", { N, H }, itype);
+ Tensor output("output", { N, H }, otype);
+ Tensor igrad("igrad", { N, H }, itype);
+ Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
- compute_ref_act_cast(input.cpu_dptr(), ref_output.get(),
+ compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) {
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
- compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(),
+ compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) {
DType itype = TypeInfo::dtype;
DType otype = TypeInfo::dtype;
- Tensor input({N, H * 2}, itype);
- Tensor output({N, H}, otype);
- Tensor igrad({ N, H * 2 }, itype);
- Tensor ograd({ N, H }, itype);
+ Tensor input("input", {N, H * 2}, itype);
+ Tensor output("output", {N, H}, otype);
+ Tensor igrad("igrad", { N, H * 2 }, itype);
+ Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
- compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(),
+ compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
- auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
- compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
+ auto [atol, rtol] = getTolerances(DType::kFloat32);
+ compareResults("amax", output.amax(), ref_amax, atol, rtol);
+ if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+ const float ref_scale = 1.f / output.scale();
+ compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol);
+ }
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
- compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(),
+ compute_ref_dglu_act_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu
new file mode 100644
index 0000000000..f57d1f035d
--- /dev/null
+++ b/tests/cpp/operator/test_cast.cu
@@ -0,0 +1,130 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+template <typename InputType, typename OutputType>
+void compute_ref(const InputType *data, OutputType *output_c,
+ const size_t size,
+ float *amax, float scale) {
+ using compute_t = float;
+ compute_t current_max = -1e100;
+ for (size_t i = 0; i < size; ++i) {
+ compute_t current = static_cast<compute_t>(data[i]);
+ current_max = fmaxf(current_max, fabsf(current));
+ output_c[i] = OutputType(scale * current);
+ }
+ *amax = current_max;
+}
+
+template <typename InputType, typename OutputType>
+void performTest(const std::vector<size_t>& shape) {
+ using namespace test;
+
+ const size_t full_size = product(shape);
+
+ DType itype = TypeInfo<InputType>::dtype;
+ DType otype = TypeInfo<OutputType>::dtype;
+
+ Tensor input("input", shape, itype);
+ Tensor output_c("output_c", shape, otype);
+
+ std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
+
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+ nvte_quantize(input.data(), output_c.data(), 0);
+
+ float ref_amax;
+ compute_ref(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
+ full_size, &ref_amax, output_c.scale());
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+}
+
+std::vector<std::vector<size_t>> test_cases = {
+ {16},
+ {16000},
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+} // namespace
+
+class CastTestSuite : public ::testing::TestWithParam<std::tuple<DType, DType, std::vector<size_t>>> {};
+
+TEST_P(CastTestSuite, TestCast) {
+ using namespace transformer_engine;
+ using namespace test;
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest<InputType, OutputType>(size);
+ );
+ );
+}
+
+
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo<CastTestSuite::ParamType>& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu
new file mode 100644
index 0000000000..1f0a9305d8
--- /dev/null
+++ b/tests/cpp/operator/test_cast_dbias.cu
@@ -0,0 +1,181 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+template <typename IT, typename OT, typename CT>
+void compute_ref_cast_dbias(const IT *input_h,
+ const CT scale,
+ OT *output_c_h,
+ CT *amax_h,
+ IT *dbias_h,
+ const size_t N,
+ const size_t H) {
+ CT amax = 0.;
+
+ std::vector<CT> acc_dbias(H, 0.);
+
+ for (size_t i = 0; i < N; i++) {
+ for (size_t j = 0; j < H; j++) {
+ CT elt = static_cast<CT>(input_h[i * H + j]);
+
+ // update amax
+ amax = std::abs(elt) > amax ? std::abs(elt) : amax;
+
+ output_c_h[i * H + j] = static_cast<OT>(scale * elt);
+
+ // dbias
+ acc_dbias[j] += elt;
+ }
+ }
+
+ *amax_h = amax;
+
+ for (size_t i = 0; i < H; i++) {
+ dbias_h[i] = static_cast<IT>(acc_dbias[i]);
+ }
+}
+
+template <typename InputType, typename OutputType>
+void performTest(const std::vector<size_t>& shape) {
+ using namespace test;
+ using CType = fp32;
+
+ DType itype = TypeInfo<InputType>::dtype;
+ DType otype = TypeInfo<OutputType>::dtype;
+
+ const size_t N = first_dimension(shape);
+ const size_t H = last_dimension(shape);
+
+ Tensor input("input", shape, itype);
+
+ Tensor output_c("output_c", shape, otype);
+ // dbias has the same data type as the output gradient
+ Tensor dbias("dbias", {H}, itype);
+
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+ std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N*H);
+ std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(H);
+
+ CType ref_amax;
+ compute_ref_cast_dbias(input.rowwise_cpu_dptr<InputType>(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ ref_output_dbias.get(),
+ N, H);
+
+ Tensor workspace;
+
+ nvte_quantize_dbias(input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ rtol_dbias *= 4;
+ compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+}
+
+std::vector<std::vector<size_t>> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+
+} // namespace;
+
+
+class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<DType, DType, std::vector<size_t>>> {};
+
+TEST_P(CastDBiasTestSuite, TestCastDBias) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest<InputType, OutputType>(size);
+ );
+ );
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastDBiasTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo<CastDBiasTestSuite::ParamType>& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu
new file mode 100644
index 0000000000..20ea5c31f1
--- /dev/null
+++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu
@@ -0,0 +1,196 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template <typename IT, typename OT, typename CT>
+void compute_ref_cast_dbias_dgelu(const IT *input,
+ const IT *grad,
+ const CT scale,
+ OT *output_c,
+ CT *amax_h,
+ IT *dbias,
+ const size_t N,
+ const size_t H) {
+ CT amax = 0.;
+
+  std::vector<CT> acc_dbias(H, 0.);
+
+ for (size_t i = 0; i < N; i++) {
+ for (size_t j = 0; j < H; j++) {
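+      // dGeLU backward: the incoming gradient is scaled by the GeLU derivative
+      // evaluated at the forward input, then cast to the quantized output below.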
+      CT in_elt = static_cast<CT>(input[i * H + j]);
+      const CT in_grad = static_cast<CT>(grad[i * H + j]);
+
+      const CT elt = in_grad * static_cast<CT>(dgelu(static_cast<float>(in_elt)));
+ const CT elt_abs = std::abs(elt);
+
+ // update amax
+ if (elt_abs > amax) {
+ amax = elt_abs;
+ }
+
+      output_c[i * H + j] = static_cast<OT>(scale * elt);
+
+ // dbias
+ acc_dbias[j] += elt;
+ }
+ }
+
+ *amax_h = amax;
+
+ for (size_t i = 0; i < H; i++) {
+    dbias[i] = static_cast<IT>(acc_dbias[i]);
+ }
+}
+
+template <typename IType, typename OType>
+void performTest(const std::vector<size_t>& shape) {
+  using namespace test;
+  using CType = fp32;
+
+  DType itype = TypeInfo<IType>::dtype;
+  DType otype = TypeInfo<OType>::dtype;
+
+ const size_t N = first_dimension(shape);
+ const size_t H = last_dimension(shape);
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+
+ Tensor output_c("output_c", shape, otype);
+  // dbias has the same data type as the input gradient
+ Tensor dbias("dbias", {H}, itype);
+
+ fillUniform(&input);
+ fillUniform(&grad);
+ setRandomScale(&output_c);
+
+  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
+  std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
+
+ CType ref_amax;
+  compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
+                               grad.rowwise_cpu_dptr<IType>(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ ref_output_dbias.get(),
+ N, H);
+
+ Tensor workspace;
+
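+  // First call (empty workspace tensor) only queries the workspace size; the second
+  // call performs the fused dGeLU + cast + dbias reduction.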
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ rtol_dbias *= 4;
+ compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+}
+
+std::vector<std::vector<size_t>> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+
+}  // namespace
+
+
+class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
+                                                                           transformer_engine::DType,
+                                                                           std::vector<size_t>>> {};
+
+TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+      performTest<InputType, OutputType>(size);
+ );
+ );
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastDBiasDGeluTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+  [](const testing::TestParamInfo<CastDBiasDGeluTestSuite::ParamType>& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu
new file mode 100644
index 0000000000..35ae462106
--- /dev/null
+++ b/tests/cpp/operator/test_cast_gated_swiglu.cu
@@ -0,0 +1,165 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template <typename IType, typename OType>
+void compute_ref_cast_dgated_swiglu(const IType * const grad,
+ const IType * const input,
+ const float scale,
+ OType * const output,
+ float * const amax_ptr,
+ const size_t rows,
+ const size_t cols) {
+ float amax = 0;
+ const size_t stride = cols * 2;
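+  // The gated input packs the activation half and the gate half side by side in each
+  // row, so rows of `input` and `output` are 2 * cols wide.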
+
+ #pragma omp parallel for reduction(max: amax) proc_bind(spread)
+ for (size_t i = 0; i < rows; i++) {
+ for (size_t j = 0; j < cols; j++) {
+      float grad_elt = static_cast<float>(grad[i * cols + j]);
+      float silu_elt = static_cast<float>(input[i * stride + j]);
+      float gate_elt = static_cast<float>(input[i * stride + cols + j]);
+
+ float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ float after_dgate = grad_elt * silu(silu_elt);
+
+ if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); }
+ if (abs(after_dgate) > amax) { amax = abs(after_dgate); }
+
+      output[i * stride + j] = static_cast<OType>(scale * after_dsilu);
+      output[i * stride + cols + j] = static_cast<OType>(scale * after_dgate);
+ }
+ }
+
+ *amax_ptr = amax;
+}
+
+template <typename IType, typename OType>
+void performTest(const std::vector<size_t>& shape) {
+  using namespace test;
+
+  DType itype = TypeInfo<IType>::dtype;
+  DType otype = TypeInfo<OType>::dtype;
+
+  std::vector<size_t> input_shape = shape;
+ input_shape[input_shape.size() - 1] *= 2;
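+  // The SwiGLU input (and the dgrad output) stores the activation and gate halves,
+  // so its last dimension is twice that of the incoming gradient.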
+
+ const size_t input_size = product(input_shape);
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ Tensor grad("grad", shape, itype);
+ Tensor input("input", input_shape, itype);
+ Tensor output_c("output_c", input_shape, otype);
+
+ fillUniform(&grad);
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(input_size);
+
+ nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
+ cudaDeviceSynchronize();
+
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ float ref_amax;
+  compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr<IType>(),
+                                 input.rowwise_cpu_dptr<IType>(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ rows,
+ cols);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+}
+
+std::vector<std::vector<size_t>> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {217, 256},
+ {1296},
+ {5, 4, 3, 160},
+};
+
+} // namespace
+
+class CastSwiGLUTestSuite
+    : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
+                                                  transformer_engine::DType,
+                                                  std::vector<size_t>>> {};
+
+TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ if (size.back() % 32 != 0) {
+ GTEST_SKIP();
+ }
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+          output_type, OutputType, performTest<InputType, OutputType>(size););
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest, CastSwiGLUTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+    [](const testing::TestParamInfo<CastSwiGLUTestSuite::ParamType> &info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
new file mode 100644
index 0000000000..cb38a5a74a
--- /dev/null
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -0,0 +1,636 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "../test_common.h"
+#include "transformer_engine/transformer_engine.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+enum ProcessingMethod {
+ CAST_ONLY,
+ CAST_DBIAS,
+ CAST_DBIAS_DACT,
+ CAST_DACT,
+ CAST_ACT
+};
+
+enum ActivationType {
+ Identity,
+ GeLU,
+ SiLU,
+ ReLU,
+ QGeLU,
+ SReLU
+};
+
+template <typename InputType, typename OutputType, float (*OP)(const float)>
+void scale_block(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_c,
+ float* dbias,
+ fp8e8m0* output_scales,
+ const size_t scale_idx,
+ const size_t i_min,
+ const size_t i_max,
+ const size_t j_min,
+ const size_t j_max,
+ const size_t cols) {
+ float amax = 0.0f;
+
+ // Find the absolute maximum value in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+      float elt = static_cast<float>(input[idx]);
+      if (processing_method == ProcessingMethod::CAST_DBIAS) {
+        // grad is the input
+        elt = static_cast<float>(grad[idx]);
+ }
+ if (processing_method != ProcessingMethod::CAST_ONLY
+ && processing_method != ProcessingMethod::CAST_DBIAS) {
+ elt = OP(elt);
+ }
+ if (processing_method == ProcessingMethod::CAST_DACT ||
+ processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+        elt *= static_cast<float>(grad[idx]);
+ }
+ dbias[j] += elt;
+ if (isinf(elt) || isnan(elt)) {
+ continue;
+ }
+ amax = std::max(amax, std::abs(elt));
+ }
+ }
+
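+  // MX block scaling: the block scale is a power of two (E8M0 biased exponent) chosen
+  // so that the block amax maps into the FP8 representable range; each element is then
+  // multiplied by the reciprocal of that scale before being cast.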
+  const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal());
+ const float scale_reciprocal = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx] = biased_exponent;
+
+ // Quantize elements in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+      float elt = static_cast<float>(input[idx]);
+      if (processing_method == ProcessingMethod::CAST_DBIAS) {
+        // grad is the input
+        elt = static_cast<float>(grad[idx]);
+ }
+ if (processing_method != ProcessingMethod::CAST_ONLY
+ && processing_method != ProcessingMethod::CAST_DBIAS) {
+ elt = OP(elt);
+ }
+ if (processing_method == ProcessingMethod::CAST_DACT ||
+ processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+        elt *= static_cast<float>(grad[idx]);
+      }
+      output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
+ }
+ }
+}
+
+template <typename InputType, typename OutputType, float (*OP)(const float)>
+void compute_ref_x1(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_c,
+ fp8e8m0* output_scales,
+ InputType* output_dbias,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride)
+{
+  std::vector<float> output_dbias_fp32(cols, 0);
+
+ const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
+ const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
+
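+  // Walk the tensor block by block; callers pass 1 x N blocks for row-wise scaling
+  // and N x 1 blocks for column-wise scaling.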
+ for (size_t ii = 0; ii < blocks_Y; ++ii) {
+ const size_t i_min = ii * block_size_Y;
+ const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
+ for (size_t jj = 0; jj < blocks_X; ++jj) {
+ const size_t j_min = jj * block_size_X;
+ const size_t j_max = std::min((jj + 1) * block_size_X, cols);
+ const size_t scale_idx = ii * scales_stride + jj;
+      scale_block<InputType, OutputType, OP>(
+ processing_method, input, grad, output_c, output_dbias_fp32.data(),
+ output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
+ }
+ }
+ for (size_t j = 0; j < cols; ++j) {
+    output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
+ }
+}
+
+template <typename InputType, typename OutputType, float (*OP)(const float)>
+void compute_ref_x2(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_rowwise,
+ OutputType* output_colwise,
+ fp8e8m0* scales_rowwise,
+ fp8e8m0* scales_colwise,
+ InputType* output_dbias,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride_rowwise,
+ const size_t scales_stride_colwise) {
+  compute_ref_x1<InputType, OutputType, OP>(
+      processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
+      rows, cols, 1, block_size_X, scales_stride_rowwise);
+  compute_ref_x1<InputType, OutputType, OP>(
+ processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
+ rows, cols, block_size_Y, 1, scales_stride_colwise);
+}
+
+/**
+ * Scaling along single dimension (either rows or columns)
+ * Produces one set of output data and the corresponding data of the fused operation (dbias):
+ * 1) Scaled rows + row-wise scaling factors
+ * OR
+ * 2) Scaled columns + column-wise scaling factors
+ */
+
+template <typename InputType, typename OutputType, float (*OP)(const float)>
+void performTest_x1(const ProcessingMethod processing_method,
+                    const std::vector<size_t>& shape,
+ const bool rowwise,
+ const bool colwise,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+  DType itype = TypeInfo<InputType>::dtype;
+  DType otype = TypeInfo<OutputType>::dtype;
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ if (shape.size() < 2 && colwise) {
+ GTEST_SKIP();
+ }
+
+ const size_t block_size_rows = rowwise ? 1 : 32;
+ const size_t block_size_cols = colwise ? 1 : 32;
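+  // Row-wise MX scaling uses 1 x 32 blocks; column-wise scaling uses 32 x 1 blocks.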
+
+  const std::array<size_t, 4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
+                                                                 block_size_cols);
+
+ const size_t unpadded_blocks_Y = scale_dims[0];
+ const size_t unpadded_blocks_X = scale_dims[1];
+ const size_t blocks_Y = scale_dims[2];
+ const size_t blocks_X = scale_dims[3];
+ const size_t scales_stride = blocks_X;
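+  // scale_dims reports both the logical (unpadded) block grid and the padded grid of
+  // the allocated scaling-factor tensor; the checks below only compare the unpadded
+  // region, stepping through rows with the padded width as the stride.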
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+ Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
+ Tensor output_dbias("output_dbias", { cols }, itype);
+
+  std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(rows * cols);
+  std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
+  std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
+
+  fillCase<EncodingType>(&input, fill_case);
+ fillUniform(&grad);
+
+ Tensor workspace;
+ switch (processing_method) {
+ case ProcessingMethod::CAST_ONLY: {
+ nvte_quantize(input.data(), output_c.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS: {
+ nvte_quantize_dbias(grad.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(grad.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS_DACT: {
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DACT: {
+ nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_ACT: {
+ nvte_gelu(input.data(), output_c.data(), 0);
+ break;
+ }
+ }
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  compute_ref_x1<InputType, OutputType, OP>(processing_method,
+                                            input.rowwise_cpu_dptr<InputType>(),
+                                            grad.rowwise_cpu_dptr<InputType>(),
+ ref_output_c.get(),
+ ref_output_scales.get(),
+ ref_output_dbias.get(),
+ rows,
+ cols,
+ block_size_rows,
+ block_size_cols,
+ scales_stride);
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
+
+  const uint8_t * const gpu_scales_ptr = rowwise
+                                         ? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
+                                         : output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
+
+ compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
+
+ if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ if (itype == DType::kFloat32) {
+ atol_dbias = 1e-4;
+      rtol_dbias *= sqrt(static_cast<float>(rows));
+ } else {
+ rtol_dbias *= 4;
+ }
+ compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+ }
+}
+
+/**
+ * Scaling along both dimensions (rows and columns)
+ * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
+ * 1) Scaled rows + row-wise scaling factors
+ * AND
+ * 2) Scaled columns + column-wise scaling factors
+ */
+template <typename InputType, typename OutputType, float (*OP)(const float)>
+void performTest_x2(const ProcessingMethod processing_method,
+                    const std::vector<size_t>& shape,
+ const size_t block_size_rows,
+ const size_t block_size_cols,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+  DType itype = TypeInfo<InputType>::dtype;
+  DType otype = TypeInfo<OutputType>::dtype;
+
+ if (shape.size() < 2) {
+ GTEST_SKIP();
+ }
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+  const std::array<size_t, 4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
+  const std::array<size_t, 4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);
+
+ const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
+ const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
+ const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
+ const size_t blocks_X_rowwise = scale_dims_rowwise[3];
+ const size_t scales_stride_rowwise = blocks_X_rowwise;
+
+ const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
+ const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
+ const size_t blocks_Y_colwise = scale_dims_colwise[2];
+ const size_t blocks_X_colwise = scale_dims_colwise[3];
+ const size_t scales_stride_colwise = blocks_X_colwise;
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+ Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
+ Tensor output_dbias("output_dbias", { cols }, itype);
+
+  std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
+  std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
+  std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
+  std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
+  std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
+
+  fillCase<EncodingType>(&input, fill_case);
+ fillUniform(&grad);
+
+ Tensor workspace;
+ switch (processing_method) {
+ case ProcessingMethod::CAST_ONLY: {
+ nvte_quantize(input.data(), output.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS: {
+ nvte_quantize_dbias(grad.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(grad.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS_DACT: {
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DACT: {
+ nvte_dgelu(grad.data(), input.data(), output.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_ACT: {
+ nvte_gelu(input.data(), output.data(), 0);
+ break;
+ }
+ }
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  compute_ref_x2<InputType, OutputType, OP>(processing_method,
+                                            input.rowwise_cpu_dptr<InputType>(),
+                                            grad.rowwise_cpu_dptr<InputType>(),
+ ref_output_c_rowwise.get(),
+ ref_output_c_colwise.get(),
+ ref_scales_rowwise.get(),
+ ref_scales_colwise.get(),
+ ref_output_dbias.get(),
+ rows,
+ cols,
+ block_size_rows,
+ block_size_cols,
+ scales_stride_rowwise,
+ scales_stride_colwise);
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
+ compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
+ compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
+ ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+ unpadded_blocks_X_rowwise, scales_stride_rowwise);
+ compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
+ ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+ unpadded_blocks_X_colwise, scales_stride_colwise);
+
+ if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ if (itype == DType::kFloat32) {
+ atol_dbias = 1e-4;
+      rtol_dbias *= sqrt(static_cast<float>(rows));
+ } else {
+ rtol_dbias *= 4;
+ }
+ compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+ }
+}
+
+std::vector<std::vector<size_t>> matrix_sizes = {
+ {1, 16},
+ {16, 48},
+ {65, 96},
+ {128, 128},
+ {256, 256},
+ {993, 512},
+ {256, 65536},
+ {2048, 6144},
+ {16384, 128},
+ {32768, 160},
+ {4096, 1632},
+ {1024},
+ {8, 32, 1024},
+ {16, 8, 4, 512},
+};
+
+std::vector<std::pair<size_t, size_t>> block_sizes = {
+ {1, 32},
+ {32, 1},
+ {32, 32},
+};
+
+std::vector<InputsFillCase> input_scenarios = {
+ InputsFillCase::uniform,
+ // InputsFillCase::zeros,
+ // InputsFillCase::zero_to_minNorm,
+ // InputsFillCase::minNorm_to_maxNorm,
+ // InputsFillCase::maxNorm_to_inf
+};
+
+std::vector<ProcessingMethod> processing_methods = {
+ ProcessingMethod::CAST_ONLY,
+ ProcessingMethod::CAST_DBIAS,
+ ProcessingMethod::CAST_DBIAS_DACT,
+ ProcessingMethod::CAST_DACT,
+ ProcessingMethod::CAST_ACT,
+};
+
+// Only Identity and GeLU activations are currently tested
+std::vector<ActivationType> Activation_types = {
+ ActivationType::Identity,
+ ActivationType::GeLU,
+ // ActivationType::SiLU,
+ // ActivationType::ReLU,
+ // ActivationType::QGeLU,
+ // ActivationType::SReLU,
+};
+
+} // namespace
+
+class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
+    <std::tuple<ProcessingMethod, ActivationType, std::vector<size_t>,
+                std::pair<size_t, size_t>,
+                transformer_engine::DType,
+                transformer_engine::DType,
+                InputsFillCase>> {};
+
+#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
+switch (OP_FUNC_TYPE) { \
+ case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
+ case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
+ case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
+ case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
+}
+
+#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
+switch (OP_FUNC_TYPE) { \
+ case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
+ case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
+ case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
+ case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
+}
+
+TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ using namespace transformer_engine;
+ using namespace test;
+
+ const ProcessingMethod processing_method = std::get<0>(GetParam());
+ const ActivationType Act_type = std::get<1>(GetParam());
+ const auto matrix_size = std::get<2>(GetParam());
+ const auto block_size = std::get<3>(GetParam());
+ const DType input_type = std::get<4>(GetParam());
+ const DType output_type = std::get<5>(GetParam());
+ const InputsFillCase fill_case = std::get<6>(GetParam());
+
+  // Non-activation methods (CAST_ONLY, CAST_DBIAS) are only run with the Identity activation
+ if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
+ && Act_type != ActivationType::Identity) {
+ GTEST_SKIP();
+ }
+  // Activation methods are skipped for the Identity activation
+ if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
+ || processing_method == ProcessingMethod::CAST_DACT
+ || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
+ GTEST_SKIP();
+ }
+
+ const bool rowwise = block_size.second != 1;
+ const bool colwise = block_size.first != 1;
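+  // Block size (1, 32) selects the row-wise path, (32, 1) the column-wise path,
+  // and (32, 32) exercises both outputs at once via performTest_x2.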
+ if (processing_method == ProcessingMethod::CAST_ACT) {
+ // Forward activations
+ ACT_FUNC_SWITCH(Act_type, OP,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
+ if (block_size.first == 1 || block_size.second == 1) {
+            performTest_x1<InputType, OutputType, OP>(
+ processing_method, matrix_size,
+ rowwise, colwise, fill_case);
+ } else {
+            performTest_x2<InputType, OutputType, OP>(
+ processing_method, matrix_size,
+ block_size.first, block_size.second, fill_case);
+ }
+ );
+ );
+ );
+ } else {
+ DACT_FUNC_SWITCH(Act_type, OP,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
+ if (block_size.first == 1 || block_size.second == 1) {
+            performTest_x1<InputType, OutputType, OP>(
+ processing_method, matrix_size,
+ rowwise, colwise, fill_case);
+ } else {
+            performTest_x2<InputType, OutputType, OP>(
+ processing_method, matrix_size,
+ block_size.first, block_size.second, fill_case);
+ }
+ );
+ );
+ );
+ }
+}
+
+std::string to_string(const ProcessingMethod method) {
+ switch (method) {
+ case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
+ case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
+ case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
+ case ProcessingMethod::CAST_DACT: return "CAST_DACT";
+ case ProcessingMethod::CAST_ACT: return "CAST_ACT";
+ default: return "";
+ }
+}
+
+std::string to_string(const ActivationType Act_type) {
+ switch (Act_type) {
+ case ActivationType::Identity: return "Identity";
+ case ActivationType::GeLU: return "GeLU";
+ case ActivationType::SiLU: return "SiLU";
+ case ActivationType::ReLU: return "ReLU";
+ case ActivationType::QGeLU: return "QGeLU";
+ case ActivationType::SReLU: return "SReLU";
+ default: return "";
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ FusedCastMXFP8TestSuite,
+ ::testing::Combine(
+ ::testing::ValuesIn(processing_methods),
+ ::testing::ValuesIn(Activation_types),
+ ::testing::ValuesIn(matrix_sizes),
+ ::testing::ValuesIn(block_sizes),
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(input_scenarios)),
+  [](const testing::TestParamInfo<FusedCastMXFP8TestSuite::ParamType>& info) {
+ std::string name = to_string(std::get<0>(info.param)) + "X" +
+ to_string(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ name += "X" + std::to_string(std::get<3>(info.param).first) +
+ "X" + std::to_string(std::get<3>(info.param).second) +
+ "X" + test::typeName(std::get<4>(info.param)) +
+ "X" + test::typeName(std::get<5>(info.param)) +
+ "X" + test::caseName(std::get<6>(info.param));
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
new file mode 100644
index 0000000000..6acbdefeab
--- /dev/null
+++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -0,0 +1,470 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+#include "transformer_engine/transformer_engine.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template <bool IS_DGATED, typename IType, typename OType>
+void scale_block(const IType* grad,
+ const IType* input,
+ OType* output,
+ fp8e8m0* output_scales,
+ const size_t scale_idx,
+ const size_t scale_idx_gate,
+ float& thread_amax,
+ const size_t i_min,
+ const size_t i_max,
+ const size_t j_min,
+ const size_t j_max,
+ const size_t cols) {
+
+ float block_amax = 0.0f;
+ float block_amax_gate = 0.0f;
+ const size_t stride = cols * 2;
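+  // In the dgated case the activation half and the gate half of each block get
+  // independent E8M0 scales (scale_idx vs. scale_idx_gate); the forward (gated) case
+  // produces a single half-width output with one scale per block.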
+
+ // Find the absolute maximum value in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+      float silu_elt = static_cast<float>(input[i * stride + j]);
+      float gate_elt = static_cast<float>(input[i * stride + cols + j]);
+ float gated_amax_act = 0;
+ float gated_amax_gate = 0;
+
+ if constexpr (IS_DGATED) {
+        const float grad_elt = static_cast<float>(grad[i * cols + j]);
+ const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ const float after_dgate = silu(silu_elt) * grad_elt;
+ gated_amax_act = abs(after_dsilu);
+ gated_amax_gate = abs(after_dgate);
+ } else {
+ const float after_silu = silu(silu_elt) * gate_elt;
+ gated_amax_act = abs(after_silu);
+ }
+
+ if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
+ if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
+ }
+ }
+
+  const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
+                                                Quantized_Limits<OType>::max_reciprocal());
+ const float scale_reciprocal = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx] = biased_exponent;
+ float scale_reciprocal_gate = 1;
+ if constexpr (IS_DGATED) {
+    const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
+                                                  Quantized_Limits<OType>::max_reciprocal());
+ scale_reciprocal_gate = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx_gate] = biased_exponent;
+ }
+
+
+ // Quantize elements in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+      float silu_elt = static_cast<float>(input[i * stride + j]);
+      float gate_elt = static_cast<float>(input[i * stride + cols + j]);
+
+ if constexpr (IS_DGATED) {
+        const float grad_elt = static_cast<float>(grad[i * cols + j]);
+ const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ const float after_dgate = silu(silu_elt) * grad_elt;
+        output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
+        output[i * stride + cols + j] = static_cast<OType>(after_dgate *
+                                                           scale_reciprocal_gate);
+ } else {
+ const float after_silu = silu(silu_elt) * gate_elt;
+        output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
+ }
+
+ }
+ }
+ thread_amax = std::max(thread_amax, block_amax);
+ thread_amax = std::max(thread_amax, block_amax_gate);
+}
+
+template <bool IS_DGATED, typename IType, typename OType>
+void compute_ref_x1(const IType* grad,
+ const IType* input,
+ OType* output,
+ fp8e8m0* output_scales,
+ float& ref_amax,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride) {
+ const size_t tile_size_Y = std::max(32lu, block_size_Y);
+ const size_t tile_size_X = std::max(64lu, block_size_X);
+ const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
+ const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
+ const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
+ const size_t blocks_per_tile_X = tile_size_X / block_size_X;
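+  // The reference splits the tensor into tiles so OpenMP threads process disjoint
+  // tiles, each tracking a thread-local amax that is reduced at the end.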
+
+ float amax = 0;
+ #pragma omp parallel reduction(max: amax) proc_bind(spread)
+ {
+ float thread_amax = 0;
+ #pragma omp for schedule(static)
+ for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
+ const size_t tile_Y = t / tiles_num_X;
+ const size_t tile_X = t % tiles_num_X;
+ const size_t tile_offset_Y = tile_Y * tile_size_Y;
+ const size_t tile_offset_X = tile_X * tile_size_X;
+
+ for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
+ const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
+ const size_t block_offset_Y = ii * block_size_Y;
+ const size_t i_min = tile_offset_Y + block_offset_Y;
+ if (i_min >= rows) continue;
+ const size_t i_max = std::min(i_min + block_size_Y, rows);
+
+ for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
+ const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
+ const size_t block_offset_X = jj * block_size_X;
+ const size_t j_min = tile_offset_X + block_offset_X;
+ if (j_min >= cols) continue;
+ const size_t j_max = std::min(j_min + block_size_X, cols);
+
+ const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
+ const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
+ cols / block_size_X;
+          scale_block<IS_DGATED, IType, OType>(
+ grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
+ thread_amax, i_min, i_max, j_min, j_max, cols);
+ }
+ }
+ }
+ if (thread_amax > amax) {
+ amax = thread_amax;
+ }
+ }
+ ref_amax = amax;
+}
+
+template <bool IS_DGATED, typename IType, typename OType>
+void compute_ref_x2(const IType* grad,
+ const IType* input,
+ OType* output_rowwise,
+ OType* output_colwise,
+ fp8e8m0* scales_rowwise,
+ fp8e8m0* scales_colwise,
+ float& ref_amax,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride_rowwise,
+ const size_t scales_stride_colwise) {
+  compute_ref_x1<IS_DGATED, IType, OType>(
+      grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
+  compute_ref_x1<IS_DGATED, IType, OType>(
+ grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
+}
+
+/**
+ * Scaling along single dimension (either rows or columns)
+ * Produces one set of output data:
+ * 1) Scaled rows + row-wise scaling factors
+ * OR
+ * 2) Scaled columns + column-wise scaling factors
+ */
+template <bool IS_DGATED, typename IType, typename OType>
+void performTest_x1(const size_t rows,
+ const size_t cols,
+ const size_t block_size_rows,
+ const size_t block_size_cols,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+  DType itype = TypeInfo<IType>::dtype;
+  DType otype = TypeInfo<OType>::dtype;
+
+ const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
+ const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
+ NVTE_CHECK(rowwise || colwise);
+
+ // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
+ // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
+ // std::cout << "blocks_Y: " << blocks_Y << std::endl;
+ // std::cout << "blocks_X: " << blocks_X << std::endl;
+ // std::cout << "scales_stride: " << scales_stride << std::endl;
+
+ Tensor grad("grad", { rows, cols }, itype);
+ Tensor input("input", { rows, cols * 2 }, itype);
+
+ const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
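+  // Forward SwiGLU collapses the [activation | gate] pair into a half-width output,
+  // while the backward (dgated) pass writes gradients for both halves.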
+
+  const std::array<size_t, 4> scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
+                                                                 block_size_cols);
+
+ const size_t unpadded_blocks_Y = scale_dims[0];
+ const size_t unpadded_blocks_X = scale_dims[1];
+ const size_t blocks_Y = scale_dims[2];
+ const size_t blocks_X = scale_dims[3];
+ const size_t scales_stride = blocks_X;
+
+ Tensor output("output", std::vector{ rows, output_cols }, otype,
+ rowwise, colwise, NVTE_MXFP8_1D_SCALING);
+
+  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(rows * output_cols);
+  std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
+
+ for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
+ ref_output_scales[i] = 0;
+ }
+
+ // fillCase(&grad, fill_case);
+ if constexpr (IS_DGATED) {
+ fillUniform(&grad);
+ }
+ fillUniform(&input);
+
+ if constexpr (IS_DGATED) {
+ nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
+ } else {
+ nvte_swiglu(input.data(), output.data(), 0);
+ }
+ cudaDeviceSynchronize();
+
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ float ref_amax = 0;
+ compute_ref_x1