Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Support arbitrary M & K for swizzle-A #1521

Open
wants to merge 7 commits into
base: stable3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions clients/include/TensorDataManipulation.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (C) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include <cassert>
#include <functional>
#include <stdexcept>
Expand Down Expand Up @@ -118,6 +143,24 @@ namespace Tensor
return flattenSize() == newDesc.flattenSize();
}

bool canShapePadTo(const Shape& shape) const
{
if(this->shape.size() != shape.size())
{
return false;
}

for(size_t i = 0; i < this->shape.size(); ++i)
{
if(this->shape.at(i) > shape.at(i))
{
return false;
}
}

return true;
}

private:
Shape shape;
Strides strides;
Expand Down Expand Up @@ -196,7 +239,8 @@ namespace Tensor
return elementSize;
}

size_t getNumBytes() const {
size_t getNumBytes() const
{
return getDesc().flattenSize() * getElementSize();
}

Expand Down Expand Up @@ -280,6 +324,46 @@ namespace Tensor
return permuted;
}

/// Returns a copy of `src` grown to `newShape`, with every element outside
/// the original extent set to `padVal`.
///
/// The destination is first filled entirely with `padVal`, then the source
/// values are overlaid at their original indices.
template <typename T>
Tensor pad(const Tensor& src, const Shape& newShape, T padVal)
{
    assert(src.getDesc().canShapePadTo(newShape) && "Invalid shape for padding");
    Tensor padded(newShape, sizeof(T));
    Indices walker(src.getDesc().numDims(), 0);

    // Pass 1: fill the whole destination with the padding value.
    iterate(padded.getDesc().getShape(), 0, walker, [&padded, &padVal](const Indices& idx) {
        padded.setValue<T>(idx, padVal);
    });

    // Pass 2: overlay the source values. NOTE(review): `walker` is reused for
    // both passes exactly as in the original code — presumably `iterate`
    // leaves it in a reusable state; confirm against its implementation.
    iterate(src.getDesc().getShape(), 0, walker, [&padded, &src](const Indices& idx) {
        auto&& element = src.getValue<T>(idx);
        padded.setValue<T>(idx, element);
    });
    return padded;
}

/// Type-erased padding entry point: dispatches to pad<T> based on the
/// element byte size (1/2/4/8), reinterpreting `padValPtr` as an unsigned
/// integer of that width.
///
/// \param tensor     source tensor to pad
/// \param newShape   target shape (each dim >= the source dim)
/// \param padValPtr  pointer to the raw pad value, `padValSize` bytes wide
/// \param padValSize element size in bytes; must be 1, 2, 4, or 8
/// \throws std::invalid_argument for an unsupported element size
///         (previously the NDEBUG build silently returned an empty tensor)
Tensor pad(const Tensor& tensor,
           const Shape& newShape,
           const void* padValPtr,
           size_t padValSize)
{
    switch(padValSize)
    {
    case 1:
        return pad<uint8_t>(tensor, newShape, *static_cast<const uint8_t*>(padValPtr));
    case 2:
        return pad<uint16_t>(tensor, newShape, *static_cast<const uint16_t*>(padValPtr));
    case 4:
        return pad<uint32_t>(tensor, newShape, *static_cast<const uint32_t*>(padValPtr));
    case 8:
        return pad<uint64_t>(tensor, newShape, *static_cast<const uint64_t*>(padValPtr));
    default:
        assert(false && "Unsupported element size");
        // Fail loudly in release builds instead of returning an empty tensor.
        throw std::invalid_argument("Tensor::pad: unsupported element size");
    }
}

Tensor permute(const Tensor& tensor, const Permutation& perm)
{
Shape newShape = permute(tensor.getDesc().getShape(), perm);
Expand All @@ -292,7 +376,6 @@ namespace Tensor
case 2:
permute<uint16_t>(permuted, tensor, perm);
break;
break;
case 4:
permute<uint32_t>(permuted, tensor, perm);
break;
Expand Down Expand Up @@ -331,7 +414,9 @@ namespace Tensor
tensor.getDesc().getShape(),
0,
indices,
[&os, &tensor](const Indices& idx) { os << float(tensor.getValue<T>(idx)) << ", "; },
[&os, &tensor](const Indices& idx) {
os << float(tensor.getValue<T>(idx)) << ", ";
},
[&os](size_t dim) { os << "["; },
[&os, &tensor](size_t dim) {
os << "], ";
Expand Down
29 changes: 22 additions & 7 deletions clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,15 @@ void swizzle_tensor(T *dst, const T *src, size_t b, size_t m, size_t k, bool col
tmpTensor = permute(orgTensor, {0, 2, 1});
}

tmpTensor.reshape({b, m / MiM, MiM, k / (MiK * PackK), MiK / MiKv , MiKv * PackK});
Tensor permuted = permute(tmpTensor, {0, 1, 3, 4, 2, 5});
memcpy(dst, permuted. template as<void>(), numElements * sizeof(T));
constexpr auto MultipleM = MiM;
constexpr auto MultipleK = MiK * PackK;
const auto paddedM = (m / MultipleM + !!(m % MultipleM)) * MultipleM;
const auto paddedK = (k / MultipleK + !!(k % MultipleK)) * MultipleK;
::Tensor::Manipulation::Shape paddedShape{b, paddedM, paddedK};
auto paddedTensor = ::Tensor::Manipulation::pad(tmpTensor, paddedShape, T(0));
paddedTensor.reshape({b, paddedM / MiM, MiM, paddedK / (MiK * PackK), MiK / MiKv , MiKv * PackK});
Tensor permuted = permute(paddedTensor, {0, 1, 3, 4, 2, 5});
memcpy(dst, permuted. template as<void>(), b * paddedM * paddedK * sizeof(T));
}

inline void pre_gpu_time(bool use_gpu_timer,
Expand Down Expand Up @@ -1136,8 +1142,17 @@ void testing_matmul_with_bias(const Arguments& arg,
stride_d[i] = do_batched[i] ? arg.stride_c[i] : ldd[i] * N[i];
stride_e[i] = do_batched[i] ? arg.stride_e[i] : lde[i] * N[i];

size_A[i]
= stride_a[i] == 0 ? lda[i] * A_col[i] * num_batches[i] : stride_a[i] * num_batches[i];
if(arg.swizzle_a)
{
//TODO: support different swizzle type
size_A[i] = num_batches[i] * ((M[i] + 15) / 16) * 16 * ((K[i] + 31) / 32) * 32;
}
else
{
size_A[i]
= stride_a[i] == 0 ? lda[i] * A_col[i] * num_batches[i] : stride_a[i] * num_batches[i];
}

size_B[i]
= stride_b[i] == 0 ? ldb[i] * B_col[i] * num_batches[i] : stride_b[i] * num_batches[i];
size_C[i]
Expand Down Expand Up @@ -1219,7 +1234,7 @@ void testing_matmul_with_bias(const Arguments& arg,

if(arg.swizzle_a && TiA == HIP_R_16F)
{
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_ROW16_32C_8;
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_COL16_4R8;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(matA[i], HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
}

Expand Down Expand Up @@ -1516,7 +1531,7 @@ void testing_matmul_with_bias(const Arguments& arg,

if(arg.swizzle_a && TiA == HIP_R_16F)
{
HipHostBuffer tmp(TiA, num_batches[i] * M[i] * K[i]);
HipHostBuffer tmp(TiA, size_A[i]);
swizzle_tensor(tmp.as<hipblasLtHalf>(), hA[i].as<hipblasLtHalf>(), num_batches[i], M[i], K[i], false);
CHECK_HIP_ERROR(synchronize(dA[i], tmp, block_count));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,15 @@ void swizzleTensor(T *dst, const T *src, size_t m, size_t k, bool colMaj)
memcpy(orgTensor. template as<void>(), src, m * k * sizeof(T));
tmpTensor = permute(orgTensor, {1, 0});
}

tmpTensor.reshape({m / MiM, MiM, k / (MiK * PackK), MiK / MiKv , MiKv * PackK});
Tensor permuted = permute(tmpTensor, {0, 2, 3, 1, 4});
memcpy(dst, permuted. template as<void>(), m * k * sizeof(T));
constexpr auto MultipleM = MiM;
constexpr auto MultipleK = MiK * PackK;
const auto paddedM = (m / MultipleM + !!(m % MultipleM)) * MultipleM;
const auto paddedK = (k / MultipleK + !!(k % MultipleK)) * MultipleK;
::Tensor::Manipulation::Shape paddedShape{paddedM, paddedK};
auto paddedTensor = ::Tensor::Manipulation::pad(tmpTensor, paddedShape, T(0));
paddedTensor.reshape({paddedM / MiM, MiM, paddedK / (MiK * PackK), MiK / MiKv , MiKv * PackK});
Tensor permuted = permute(paddedTensor, {0, 2, 3, 1, 4});
memcpy(dst, permuted. template as<void>(), paddedM * paddedK * sizeof(T));
}

void swizzleGemmEpilogueBiasVecExt(hipblasLtHandle_t handle,
Expand Down Expand Up @@ -135,7 +140,7 @@ void swizzleGemmEpilogueBiasVecExt(hipblasLtHandle_t handle,

if(swizzleA)
{
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_ROW16_32C_8;
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_COL16_4R8;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(matA, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
std::vector<hipblasLtHalf> src(m * k, 0);
std::vector<hipblasLtHalf> dst(m * k, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ void simpleGemm(hipblasLtHandle_t handle,

if(swizzleA)
{
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_ROW16_32C_8;
hipblasLtOrder_t orderA = HIPBLASLT_ORDER_COL16_4R8;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(matA, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
std::vector<hipblasLtHalf> src(m * k, 0);
std::vector<hipblasLtHalf> dst(m * k, 0);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (C) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <array>
#include <cstddef>
#include <iostream>
#include "TensorDataManipulation.hpp"

// Demonstrates padding an (m x k) integer weight tensor up to the
// swizzle-friendly multiples (MiM rows, MiK*PackK columns) and applying the
// swizzle permutation, printing the tensor at each stage.
int main()
{
    constexpr size_t m{18};
    constexpr size_t k{34};
    auto weight = Tensor::Manipulation::Tensor::create<int>({m, k});

    // Fill with row-major linear indices so the swizzled layout is easy to read.
    for(size_t i = 0; i < m; ++i)
    {
        for(size_t j = 0; j < k; ++j)
        {
            weight.setValue<int>({i, j}, static_cast<int>(i * k + j));
        }
    }

    std::cout << "Original weight:\n";
    Tensor::Manipulation::printTensorDataMultiDims<int>(std::cout, weight);

    // MFMA-style tile geometry: MiM x MiK tile, K-vector width MiKv, PackK packing.
    constexpr size_t MiM = 16;
    constexpr size_t MiK = 16;
    constexpr size_t MiKv = 4;
    constexpr size_t PackK = 2;
    constexpr auto MultipleM = MiM;
    constexpr auto MultipleK = MiK * PackK;

    // Round m and k up to the nearest required multiple (ceiling division).
    const auto paddedM = (m + MultipleM - 1) / MultipleM * MultipleM;
    const auto paddedK = (k + MultipleK - 1) / MultipleK * MultipleK;

    Tensor::Manipulation::Shape paddedShape{paddedM, paddedK};
    auto paddedWeight = ::Tensor::Manipulation::pad(weight, paddedShape, 0);
    std::cout << "Padded weight:\n";
    Tensor::Manipulation::printTensorDataMultiDims<int>(std::cout, paddedWeight);

    // Split into tiles and permute into the swizzled memory order.
    paddedWeight.reshape({paddedM / MiM, MiM, paddedK / (MiK * PackK), MiK / MiKv, MiKv * PackK});
    Tensor::Manipulation::Tensor permuted = permute(paddedWeight, {0, 2, 3, 1, 4});
    std::cout << "Swizzle weight:\n";
    Tensor::Manipulation::printTensorDataMultiDims<int>(std::cout, permuted);
    return 0;
}
4 changes: 3 additions & 1 deletion clients/samples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ add_executable( sample_hipblaslt_ext_op_amax_with_scale 24_amax_with_scale_ext/s
add_executable( sample_hipblaslt_gemm_with_TF32 25_gemm_with_TF32/sample_hipblaslt_gemm_with_TF32.cpp)
add_executable( sample_hipblaslt_gemm_swizzle_a 26_gemm_swizzle_a/sample_hipblaslt_gemm_swizzle_a.cpp)
add_executable( sample_hipblaslt_gemm_bias_swizzle_a_ext 26_gemm_swizzle_a/sample_hipblaslt_gemm_bias_swizzle_a_ext.cpp)
add_executable( sample_hipblaslt_weight_swizzle_padding 26_gemm_swizzle_a/sample_hipblaslt_weight_swizzle_padding.cpp)


set(samples sample_hipblaslt_gemm
Expand Down Expand Up @@ -101,7 +102,8 @@ set(samples sample_hipblaslt_gemm
sample_hipblaslt_ext_op_amax_with_scale
sample_hipblaslt_gemm_with_TF32
sample_hipblaslt_gemm_swizzle_a
sample_hipblaslt_gemm_bias_swizzle_a_ext)
sample_hipblaslt_gemm_bias_swizzle_a_ext
sample_hipblaslt_weight_swizzle_padding)

set( sample_list_all ${samples})

Expand Down
14 changes: 14 additions & 0 deletions docs/api-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,20 @@ For FP8 type Matmul, hipBLASLt supports the type combinations shown in the follo
| | | BF8 | BF8 | | | | FP32, FP16 | FP16 |
+-------+-------+-------+-------+-------------+----------+----------+------------+-----------+

To use the FP16-specific data ordering HIPBLASLT_ORDER_COL16_4R8 on the gfx94x architecture, the valid combinations of transposes and orders for the input and output matrices are:

+-------+-------+-------+-------+-----------------------------+-----------------------------+---------------------+---------------------+
| Atype | Btype | opA | opB | orderA | orderB | orderC | orderD |
+=======+=======+=======+=======+=============================+=============================+=====================+=====================+
| FP16 | FP16 | T | N | HIPBLASLT_ORDER_COL16_4R8 | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL |
+-------+-------+-------+-------+-----------------------------+-----------------------------+---------------------+---------------------+
| FP16 | FP16 | T | T | HIPBLASLT_ORDER_COL16_4R8 | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL |
+-------+-------+-------+-------+-----------------------------+-----------------------------+---------------------+---------------------+
| FP16 | FP16 | N | N | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL16_4R8 | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL |
+-------+-------+-------+-------+-----------------------------+-----------------------------+---------------------+---------------------+
| FP16 | FP16 | T | N | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL16_4R8 | HIPBLASLT_ORDER_COL | HIPBLASLT_ORDER_COL |
+-------+-------+-------+-------+-----------------------------+-----------------------------+---------------------+---------------------+

hipblasLtMatrixTransformDescCreate()
------------------------------------------
.. doxygenfunction:: hipblasLtMatrixTransformDescCreate
Expand Down
9 changes: 8 additions & 1 deletion library/include/hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,14 @@ typedef enum {
* Leading dimension is the stride (in elements) to the beginning of next row in memory.
*/
HIPBLASLT_ORDER_ROW = 1,
HIPBLASLT_ORDER_ROW16_32C_8 = 2
/**
* Data is ordered in column-major composite tiles with 16 columns and 32 rows in total.
* A tile is composed of 4 inner tiles in column-major with total 8 rows and 16 columns.
* Element offset within the tile is calculated as row%8+8*col+(row/8)*16*8.
* Note that for this order, the number of columns (rows) of the tensor has to be a multiple of 16 (32),
* or the tensor must be pre-padded to a multiple of 16 (32).
*/
HIPBLASLT_ORDER_COL16_4R8 = 100
} hipblasLtOrder_t;

/** Matrix transform descriptor attributes to define details of the operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle
rocblaslt_compute_type compute_type;
void * bias = nullptr, *scaleAlphaVec = nullptr, *e = nullptr;
bool gradient = false;
bool swizzleA = matA->order == HIPBLASLT_ORDER_ROW16_32C_8;
bool swizzleA = matA->order == HIPBLASLT_ORDER_COL16_4R8;
bool swizzleB = matB->order == HIPBLASLT_ORDER_COL16_4R8;
rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr,
dummy_ptr,
dummy_ptr,
Expand Down Expand Up @@ -302,8 +303,7 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle
gradient,
compute_type,
swizzleA,
/*TODO: Currently we don't support swizzle B */
false);
swizzleB);
if(isValid != rocblaslt_status_continue)
{
m = 0;
Expand Down Expand Up @@ -386,8 +386,7 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle
nullptr,
handle->Synchronizer,
swizzleA,
/*TODO: Currently we don't support swizzle B */
false};
swizzleB};

return problem;
}
Expand Down
Loading