
Commit

Added memory alignment check to cast_fp8_1D (#1507)
* Added TMA alignment check to cast_fp8_1D

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use tensor const-ref instead of tensor const-ptr

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
4 people authored Feb 26, 2025
1 parent 8ca2caf commit 5d85857
Showing 3 changed files with 20 additions and 11 deletions.
8 changes: 1 addition & 7 deletions transformer_engine/common/common.cu
@@ -67,11 +67,6 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
   return dtypeMapping.at(dtype);
 }
 
-inline bool isPointerAligned(const void *const ptr, const int alignment) {
-  const uint64_t ptr_as_uint = reinterpret_cast<uint64_t>(ptr);
-  return ptr_as_uint % alignment == 0;
-}
-
 // Set up parameters to create TMA descriptor.
 void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                           const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
@@ -100,8 +95,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
   void *dataPtr =
       reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
 
-  constexpr int TMA_gmem_alignment = 16;  // Alignment of the global memory address
-  NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment),
+  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
              "Tensor data pointer must be 16B aligned");
 
   const int TMA_needed_size = TMA_gmem_alignment / type_size;
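
For context on why the check sits here: create_2D_tensor_map hands TMA a pointer offset from the tensor base by offset_elems * type_size, so even a 16B-aligned allocation can yield a misaligned view. Below is a minimal standalone host-side sketch of that failure mode; it is illustrative only, not TransformerEngine code, and the buffer size and 2-byte element size are made up for the example.

    // Standalone illustration: an element offset into an aligned buffer can
    // violate the 16-byte TMA global-memory alignment requirement.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    constexpr int TMA_gmem_alignment = 16;

    inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
      return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
    }

    int main() {
      // 16B-aligned base buffer holding 2-byte elements (e.g. BF16-sized).
      void *base = std::aligned_alloc(TMA_gmem_alignment, 256);
      const size_t type_size = 2;
      const size_t offset_elems = 4;  // 4 elements * 2 bytes = 8-byte offset

      void *data_ptr = static_cast<uint8_t *>(base) + offset_elems * type_size;
      std::printf("base aligned:   %d\n", is_aligned_ptr(base, TMA_gmem_alignment));      // 1
      std::printf("offset aligned: %d\n", is_aligned_ptr(data_ptr, TMA_gmem_alignment));  // 0
      // The second case is exactly what the NVTE_CHECK above rejects.

      std::free(base);
      return 0;
    }
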
14 changes: 12 additions & 2 deletions transformer_engine/common/common.h
@@ -14,6 +14,7 @@
 #include <cuda_runtime_api.h>
 #include <transformer_engine/transformer_engine.h>
 
+#include <cstdint>
 #include <functional>
 #include <stdexcept>
 #include <string>
@@ -426,6 +427,17 @@ constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
 constexpr size_t scale_tensor_alignment_X_colwise = 128;
 constexpr size_t scale_tensor_alignment_Y_colwise = 4;
 
+// Alignment requirements for the Tensor Memory Accelerator (TMA)
+constexpr int TMA_gmem_alignment = 16;  // global memory address alignment
+
+inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
+  return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
+}
+
+inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
+  return is_aligned_ptr(static_cast<const void *>(t.data.dptr), alignment);
+}
+
 size_t typeToSize(const DType type);
 
 void CheckNoopTensor(const Tensor &t, const std::string &name);
@@ -465,8 +477,6 @@ void checkCuDriverContext(CUstream stream);
 
 CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
 
-inline bool isPointerAligned(const void *const ptr, const int alignment);
-
 // Set up parameters to create TMA descriptor.
 void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                           const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
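
A note on the new helpers: is_aligned_tensor_data only inspects the rowwise data pointer (t.data.dptr). The following is a minimal standalone sketch with stand-in Tensor/SimpleTensor structs (the real definitions in common.h carry many more fields) showing the check that the kernel dispatch in the next file relies on.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    constexpr int TMA_gmem_alignment = 16;

    // Stand-ins for the TE structs; only the field the helper touches is kept.
    struct SimpleTensor { void *dptr = nullptr; };
    struct Tensor { SimpleTensor data; };

    inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
      return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
    }

    inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
      return is_aligned_ptr(static_cast<const void *>(t.data.dptr), alignment);
    }

    int main() {
      alignas(TMA_gmem_alignment) static uint8_t buffer[64];

      Tensor aligned;
      aligned.data.dptr = buffer;         // starts on a 16B boundary
      Tensor misaligned;
      misaligned.data.dptr = buffer + 2;  // 2 bytes in: fails the check

      std::printf("%d %d\n", is_aligned_tensor_data(aligned, TMA_gmem_alignment),
                  is_aligned_tensor_data(misaligned, TMA_gmem_alignment));  // prints "1 0"
      return 0;
    }
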
9 changes: 7 additions & 2 deletions transformer_engine/common/util/cast_kernels.cuh
@@ -1110,15 +1110,20 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
   switch (output->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
       if (!IS_DBIAS && !IS_DACT) {
-        if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) {
+        if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) &&
+            is_aligned_tensor_data(input, TMA_gmem_alignment) &&
+            is_aligned_tensor_data(*output, TMA_gmem_alignment)) {
           // Aligned AND FP8
           cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream);
         } else {
           // Unaligned
           CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
         }
       } else if (!IS_DBIAS && IS_DACT) {
-        if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) {
+        if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
+            is_aligned_tensor_data(input, TMA_gmem_alignment) &&
+            is_aligned_tensor_data(*output, TMA_gmem_alignment) &&
+            is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) {
           // Aligned AND FP8 (+dAct)
           cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
                                                       stream);
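
The shape of this change is a classic fast-path/fallback dispatch: the TMA-based kernels (cast_fp8_1D, cast_fp8_2D) are only eligible when every tensor they touch passes the 16B check, and anything else drops to the alignment-tolerant vectorized launcher. A condensed standalone sketch of that policy follows; the function and launcher names here are hypothetical, not the TE API.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    constexpr int TMA_gmem_alignment = 16;

    inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
      return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
    }

    // Hypothetical stand-ins for the two launch paths.
    void launch_tma_cast(const void *in, void *out) { std::puts("TMA path"); }
    void launch_vectorized_cast(const void *in, void *out) { std::puts("vectorized fallback"); }

    // Take the TMA path only when both the input and output pointers satisfy
    // the 16B requirement; otherwise fall back, mirroring the branches above.
    void dispatch_cast(const void *in, void *out) {
      const bool tma_ok = is_aligned_ptr(in, TMA_gmem_alignment) &&
                          is_aligned_ptr(out, TMA_gmem_alignment);
      if (tma_ok) {
        launch_tma_cast(in, out);         // analogous to cast_fp8_1D / cast_fp8_2D
      } else {
        launch_vectorized_cast(in, out);  // analogous to CastVectorizedUnaryKernelLauncher
      }
    }
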
