diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
index 62f4482c7a..2a70c84856 100644
--- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
+++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
@@ -972,14 +972,20 @@ compute_row_broadcast_stages() {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_0,_1,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90RowBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrRowType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;
+
   static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
@@ -991,7 +997,7 @@ struct Sm90RowBroadcast {
   };

   struct Arguments {
-    ElementInput const* ptr_row = nullptr;
+    PtrRowType ptr_row = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dRow = {};
   };
@@ -1036,7 +1042,7 @@ struct Sm90RowBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_row[0] == ElementInput(0);
     }
   }
@@ -1183,7 +1189,13 @@ struct Sm90RowBroadcast {
     auto layout_M = make_layout(M, repeat_like(M, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dRow));

-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_row;
+    if constexpr(IsArrayOfPointers) {
+      ptr_row = params.ptr_row[l];
+    } else {
+      ptr_row = params.ptr_row;
+    }
+    Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
     Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));          // (CTA_M, CTA_N)
     Tensor sRow = make_tensor(make_smem_ptr(smem),
         make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
@@ -1220,14 +1232,20 @@ struct Sm90RowBroadcast {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_1,_0,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90ColBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrColType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;
+
   static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
@@ -1238,13 +1256,13 @@ struct Sm90ColBroadcast {
   struct SharedStorage { };

   struct Arguments {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dCol = {};
   };

   struct Params {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementCompute null_default = ElementCompute(0);
     StrideMNL dCol = {};
   };
@@ -1301,7 +1319,7 @@ struct Sm90ColBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_col[0] == ElementInput(0);
     }
   }
@@ -1398,6 +1416,7 @@ struct Sm90ColBroadcast {
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

     auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
     auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
       auto shape_M = get<0>(args.problem_shape_mnkl);
       if constexpr (IsDynamicBroadcast) {
@@ -1416,11 +1435,17 @@ struct Sm90ColBroadcast {
     auto layout_N = make_layout(N, repeat_like(N, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dCol));

-    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_col;
+    if constexpr(IsArrayOfPointers) {
+      ptr_col = params.ptr_col[l];
+    } else {
+      ptr_col = params.ptr_col;
+    }
+    Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L));
     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                      // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

-    Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L));
+    Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L));
     Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>(               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
     Tensor tCrCol = make_tensor_like(tCgCol_static);                                // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
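
The sketch below is not part of the patch: it shows how the reworked `ElementInput_` template parameter selects between the two broadcast modes. Passing a plain element type keeps the classic single-pointer behavior, while passing a pointer type (e.g. for grouped GEMM, where each group carries its own bias vector) flips `IsArrayOfPointers`, so `Arguments::ptr_row` becomes a pointer-to-pointers that `get_consumer_store_callbacks` indexes by the batch coordinate `l`. The CTA tile shape and element types here are illustrative assumptions, not values taken from this diff.

// Illustrative instantiations only; assumes CUTLASS is on the include path.
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"

using namespace cutlass::epilogue::fusion;

// Assumed CTA tile shape for the example.
using CtaTile = cute::Shape<cute::_128, cute::_128, cute::_64>;

// Classic mode: ElementInput_ = half_t, so Arguments::ptr_row has type
// cutlass::half_t const* (one row vector, subject to the dRow strides).
using RowBcast = Sm90RowBroadcast<0, CtaTile, cutlass::half_t, float>;
static_assert(!RowBcast::IsArrayOfPointers, "plain element type selects a single pointer");

// Array-of-pointers mode: ElementInput_ = half_t*, so Arguments::ptr_row has
// type cutlass::half_t const* const*, and the callbacks dereference
// params.ptr_row[l] to pick the row vector for batch/group l.
using GroupedRowBcast = Sm90RowBroadcast<0, CtaTile, cutlass::half_t*, float>;
static_assert(GroupedRowBcast::IsArrayOfPointers, "pointer element type selects a pointer array");

Note that the `null_default` scalar fallback and the dynamic-scalar `is_zero_` check operate on a single vector, which is why the patch also excludes the array-of-pointers case (`&& !IsArrayOfPointers`) from the dynamic non-batched scalar-broadcast branch.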