[EVT] Add support for Row/Col broadcast PtrArray #2033

Open · wants to merge 1 commit into main
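This PR teaches the Sm90RowBroadcast and Sm90ColBroadcast EVT nodes to accept a pointer type as ElementInput: writing float* instead of float switches the node into pointer-array mode, where ptr_row/ptr_col hold one bias pointer per batch or group. A minimal standalone sketch of the detection idiom, mirroring the diff (the struct name here is hypothetical):

#include <type_traits>

// Hypothetical reproduction of the type machinery this PR adds.
template <class ElementInput>
struct RowBroadcastSketch {
  // Strip one level of pointer to recover the element type itself.
  using ElementInput_ = std::remove_pointer_t<ElementInput>;
  // Pointer-array mode iff ElementInput was given as a pointer type.
  static constexpr bool is_ptr_array_ =
      std::is_same_v<ElementInput_*, ElementInput>;
  // Scalar mode: one row pointer. Ptr-array mode: one pointer per batch.
  using PtrRowType = std::conditional_t<is_ptr_array_,
                                        ElementInput_ const* const*,
                                        ElementInput_ const*>;
};

static_assert(!RowBroadcastSketch<float>::is_ptr_array_);
static_assert(RowBroadcastSketch<float*>::is_ptr_array_);
static_assert(std::is_same_v<RowBroadcastSketch<float*>::PtrRowType,
                             float const* const*>);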
@@ -973,26 +973,32 @@ template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
- class ElementCompute = ElementInput,
+ class ElementCompute = std::remove_pointer_t<ElementInput>,
class StrideMNL_ = Stride<_0,_1,_0>,
- int Alignment = 128 / sizeof_bits_v<ElementInput>,
+ int Alignment = 128 / sizeof_bits_v<std::remove_pointer_t<ElementInput>>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90RowBroadcast {
using StrideMNL = StrideMNL_;
+ // Get base element input type.
+ using ElementInput_ = std::remove_pointer_t<ElementInput>;
+ // Check if input is an array of pointers.
+ static constexpr bool is_ptr_array_ = is_same_v<ElementInput_*, ElementInput>;
+ using PtrRowType = std::conditional_t<is_ptr_array_, ElementInput_ const* const*, ElementInput_ const*>;

static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");

static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast);

struct SharedStorage {
- array_aligned<ElementInput, size<1>(CtaTileShapeMNK{})> smem;
+ array_aligned<ElementInput_, size<1>(CtaTileShapeMNK{})> smem;
};

struct Arguments {
- ElementInput const* ptr_row = nullptr;
- ElementInput null_default = ElementInput(0);
+ PtrRowType ptr_row = nullptr;
+ ElementInput_ null_default = ElementInput_(0);
StrideMNL dRow = {};
};

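For orientation, here is a hedged sketch of how the new Arguments might be populated in pointer-array mode; every name below (d_group_bias, d_ptr_array, num_groups) is an assumption for illustration, not part of this PR:

#include <cuda_runtime.h>
#include <vector>

// Gather one device-side row-bias pointer per group into a device-resident
// pointer array; that array becomes Arguments::ptr_row when ElementInput
// is a pointer type (e.g. float*).
void set_ptr_array_args(std::vector<float const*> const& d_group_bias, // device ptrs
                        float const** d_ptr_array,                     // device buffer
                        int num_groups) {
  cudaMemcpy(d_ptr_array, d_group_bias.data(),
             num_groups * sizeof(float const*), cudaMemcpyHostToDevice);
  // With ElementInput = float*, Arguments::ptr_row has type
  // float const* const*, so a plausible setup would be:
  //   typename RowBcast::Arguments args{d_ptr_array, /*null_default=*/0.f, /*dRow=*/{}};
}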
@@ -1029,21 +1035,23 @@ struct Sm90RowBroadcast {
CUTLASS_HOST_DEVICE
Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params), is_zero_(false),
- smem(const_cast<ElementInput*>(shared_storage.smem.data())) {
+ smem(const_cast<ElementInput_*>(shared_storage.smem.data())) {
auto const& [stride_M, stride_N, stride_L] = params.dRow;
// Nullptr default
if (EnableNullptr && params.ptr_row == nullptr) {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
- else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
- is_zero_ = params.ptr_row[0] == ElementInput(0);
+ if constexpr(!is_ptr_array_) {
+ if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+ is_zero_ = params.ptr_row[0] == ElementInput_(0);
+ }
}
}

Params params;
bool is_zero_ = false;
- ElementInput *smem = nullptr;
+ ElementInput_ *smem = nullptr;

CUTLASS_DEVICE bool
is_producer_load_needed() const {
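A note on the constructor hunk above: the new if constexpr guard is likely not just an optimization. In pointer-array mode params.ptr_row[0] is itself a pointer, so comparing it against ElementInput_(0) would be ill-formed; the scalar-broadcast zero check has to be discarded at compile time rather than skipped at runtime. A minimal illustration with a hypothetical helper:

#include <type_traits>

// For a pointer array, ptr[0] is a pointer, and pointer == Element(0)
// does not compile, so a runtime `if` alone would not be enough.
template <class Element, class Ptr>
bool first_entry_is_zero(Ptr ptr) {
  if constexpr (std::is_same_v<Ptr, Element const*>) {
    return ptr[0] == Element(0);  // element comparison, well-formed
  } else {
    return false;                 // ptr-array mode: branch discarded
  }
}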
@@ -1125,13 +1133,13 @@ struct Sm90RowBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
- Tensor tSR_rRow_flt = make_tensor_like<ElementInput>(tSR_sRow_flt);
+ Tensor tSR_rRow_flt = make_tensor_like<ElementInput_>(tSR_sRow_flt);
copy_aligned(tSR_sRow_flt, tSR_rRow_flt);

constexpr int FrgSize = size(tSR_rRow_flt);
- using FrgInput = Array<ElementInput, FrgSize>;
+ using FrgInput = Array<ElementInput_, FrgSize>;
using FrgCompute = Array<ElementCompute, FrgSize>;
- using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FrgSize>;
+ using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput_, FrgSize>;

Tensor tSR_rRow_input_frg = recast<FrgInput>(coalesce(tSR_rRow_flt));
Tensor tSR_rRow_compute_frg = recast<FrgCompute>(filter(tSR_rRow));
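The conversion pattern in begin_loop is standard CUTLASS: stage fragments of the input type, then convert whole fragments to the compute type. An isolated sketch of the ConvertInput alias, using half_t to float and a fragment size of 4 as example choices:

#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>

constexpr int kFrgSize = 4;  // example fragment size
using FrgInput   = cutlass::Array<cutlass::half_t, kFrgSize>;
using FrgCompute = cutlass::Array<float, kFrgSize>;
using ConvertInput =
    cutlass::NumericArrayConverter<float, cutlass::half_t, kFrgSize>;

FrgCompute convert(FrgInput const& in) {
  ConvertInput convert_input{};
  return convert_input(in);  // converts all kFrgSize elements at once
}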
@@ -1183,12 +1191,18 @@ struct Sm90RowBroadcast {

auto layout_M = make_layout(M, repeat_like(M, _0{}));
auto layout_L = make_layout(L, get<2>(params.dRow));
- Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L));
+ ElementInput_ const* ptr_row;
+ if constexpr(is_ptr_array_) {
+ ptr_row = params.ptr_row[l];
+ } else {
+ ptr_row = params.ptr_row;
+ }
+ Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
- auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput>{},
+ auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput_>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
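This hunk contains the functional core of the change: in pointer-array mode the l-th batch's row pointer is looked up before the gmem tensor is built, and everything downstream is unchanged. In generic form (hypothetical helper):

#include <type_traits>

// Pick the l-th pointer in ptr-array mode, or the single shared pointer
// otherwise; tensor construction proceeds identically either way.
template <class Element, class Ptr>
Element const* select_batch_ptr(Ptr ptr, int l) {
  if constexpr (std::is_same_v<Ptr, Element const* const*>) {
    return ptr[l];   // one pointer per batch/group
  } else {
    return ptr;      // same pointer for all batches
  }
}

Presumably the batch stride get<2>(dRow) would be zero in pointer-array mode, since batching is expressed through the pointer array rather than through an L-stride, but the diff leaves that to the caller.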
@@ -1221,13 +1235,19 @@ template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
- class ElementCompute = ElementInput,
+ class ElementCompute = std::remove_pointer_t<ElementInput>,
class StrideMNL_ = Stride<_1,_0,_0>,
- int Alignment = 128 / sizeof_bits_v<ElementInput>,
+ int Alignment = 128 / sizeof_bits_v<std::remove_pointer_t<ElementInput>>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90ColBroadcast {
using StrideMNL = StrideMNL_;
+ // Get base element input type.
+ using ElementInput_ = std::remove_pointer_t<ElementInput>;
+ // Check if input is an array of pointers.
+ static constexpr bool is_ptr_array_ = is_same_v<ElementInput_*, ElementInput>;
+ using PtrColType = std::conditional_t<is_ptr_array_, ElementInput_ const* const*, ElementInput_ const*>;

static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");

static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
@@ -1238,13 +1258,13 @@ struct Sm90ColBroadcast {
struct SharedStorage { };

struct Arguments {
- ElementInput const* ptr_col = nullptr;
- ElementInput null_default = ElementInput(0);
+ PtrColType ptr_col = nullptr;
+ ElementInput_ null_default = ElementInput_(0);
StrideMNL dCol = {};
};

struct Params {
- ElementInput const* ptr_col = nullptr;
+ PtrColType ptr_col = nullptr;
ElementCompute null_default = ElementCompute(0);
StrideMNL dCol = {};
};
@@ -1301,8 +1321,10 @@ struct Sm90ColBroadcast {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
- else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
- is_zero_ = params.ptr_col[0] == ElementInput(0);
+ if constexpr(!is_ptr_array_) {
+ if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+ is_zero_ = params.ptr_col[0] == ElementInput_(0);
+ }
}
}

@@ -1344,13 +1366,13 @@ struct Sm90ColBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
Tensor tCgCol_flt = filter_zeros(tCgCol);
- Tensor tCrCol_flt = make_tensor_like<ElementInput>(filter_zeros(tCrCol));
+ Tensor tCrCol_flt = make_tensor_like<ElementInput_>(filter_zeros(tCrCol));
Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride());

constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){};
constexpr int V = cute::min(Alignment, size(MCL));
if constexpr (V > 1) {
- using VecType = uint_bit_t<V * sizeof_bits_v<ElementInput>>;
+ using VecType = uint_bit_t<V * sizeof_bits_v<ElementInput_>>;
Tensor tCgCol_vec = recast<VecType>(coalesce(tCgCol_flt));
Tensor tCrCol_vec = recast<VecType>(coalesce(tCrCol_flt));
Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int<V>{})));
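The vectorized-copy logic above only swaps ElementInput for ElementInput_; the mechanism itself is unchanged. For readers new to it, a sketch of the vector-type selection, with V = 4 chosen as an example value of min(Alignment, size(MCL)):

#include <cute/tensor.hpp>

// With V elements per copy and 32-bit floats, VecType is a 128-bit
// unsigned integer type that the copy tensors are recast through.
constexpr int V = 4;
using VecType = cute::uint_bit_t<V * cute::sizeof_bits_v<float>>;
static_assert(cute::sizeof_bits_v<VecType> == 128);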
@@ -1363,9 +1385,9 @@
}

constexpr int FrgSize = size(tCrCol_flt);
- using FrgInput = Array<ElementInput, FrgSize>;
+ using FrgInput = Array<ElementInput_, FrgSize>;
using FrgCompute = Array<ElementCompute, FrgSize>;
- using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FrgSize>;
+ using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput_, FrgSize>;

Tensor tCrCol_input_frg = recast<FrgInput>(coalesce(tCrCol_flt));
Tensor tCrCol_compute_frg = recast<FrgCompute>(filter(tCrCol));
@@ -1398,6 +1420,7 @@ struct Sm90ColBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

auto [M, N, K, L] = args.problem_shape_mnkl;
+ auto [m, n, k, l] = args.tile_coord_mnkl;
auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
auto shape_M = get<0>(args.problem_shape_mnkl);
if constexpr (IsDynamicBroadcast) {
@@ -1416,11 +1439,17 @@

auto layout_N = make_layout(N, repeat_like(N, _0{}));
auto layout_L = make_layout(L, get<2>(params.dCol));
- Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L));
+ ElementInput_ const* ptr_col;
+ if constexpr(is_ptr_array_) {
+ ptr_col = params.ptr_col[l];
+ } else {
+ ptr_col = params.ptr_col;
+ }
+ Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L));
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

- Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L));
+ Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L));
Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
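Putting the two halves together, a hedged sketch of instantiating both nodes in pointer-array mode; the tile shape is an example choice and the surrounding epilogue plumbing is elided:

#include <cute/tensor.hpp>

using CtaTile = cute::Shape<cute::_128, cute::_128, cute::_64>;

// Per-group row bias (length-N vector, broadcast across M):
// Arguments::ptr_row has type float const* const*.
using RowBcastPtr = Sm90RowBroadcast</*Stages=*/0, CtaTile, /*ElementInput=*/float*>;

// Per-group column bias (length-M vector, broadcast across N):
// Arguments::ptr_col has type float const* const*.
using ColBcastPtr = Sm90ColBroadcast</*Stages=*/0, CtaTile, /*ElementInput=*/float*>;

// Both specializations should report pointer-array mode.
static_assert(RowBcastPtr::is_ptr_array_ && ColBcastPtr::is_ptr_array_);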