Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reinterpreted structs to match #1596

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion library/src/amd_detail/rocblaslt/include/rocblaslt-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#include <stdint.h>
#include <vector>

#include <hipblaslt-ext.hpp>

#define ROCBLASLT_KERNEL __global__
#define ROCBLASLT_DEVICE_ILF __device__

Expand Down Expand Up @@ -369,13 +371,34 @@ typedef enum rocblaslt_matmul_preference_attributes_
/********************************************************************************
* \brief rocblaslt_matmul_algo holds the description of the matrix
* multiplication algorithm.
*******************************************************************************/
// NOTE(review): packed variant of the algorithm descriptor. The unmatched
// comment terminator on the line right after this definition suggests it is
// retained only as commented-out history, superseded by the explicitly padded
// definition further down - confirm against the full file.
//
// `packed` removes inter-member padding, so max_workspace_bytes is placed
// directly after fallback; `aligned(8)` raises the alignment of the struct as
// a whole to 8 bytes.
typedef struct __attribute__((packed, aligned(8))) _rocblaslt_matmul_algo
{
uint8_t data[8] = {0}; // opaque algorithm payload, zero-initialized
bool fallback = false; // defaults to false
size_t max_workspace_bytes = 0; // defaults to 0
} rocblaslt_matmul_algo;
*******************************************************************************/

/********************************************************************************
* \brief rocblaslt_matmul_algo holds the description of the matrix
* multiplication algorithm.
*******************************************************************************/
typedef struct _rocblaslt_matmul_algo
{
    // The member layout must match hipblasLtMatmulAlgo_t. The first three
    // members (8 + 1 + 7 bytes) together occupy the 16 bytes that
    // hipblasLtMatmulAlgo_t declares as a single `uint8_t data[16]`.
#ifdef __cplusplus
    // C++ translation units get a fully zeroed descriptor by default.
    uint8_t data[8]{};              // opaque algorithm payload
    bool    fallback{false};        // stored in the byte right after data
    uint8_t data_pad[7]{};          // explicit filler up to the 16-byte mark
    size_t  max_workspace_bytes{0}; // follows the 16-byte data region
#else
    // Plain C has no default member initializers; same members, same order.
    uint8_t data[8];
    bool    fallback;
    uint8_t data_pad[7];
    size_t  max_workspace_bytes;
#endif
} rocblaslt_matmul_algo;

// rocblaslt_matmul_algo is required to match the layout of
// hipblasLtMatmulAlgo_t. Equal size alone does not rule out a differing
// alignment requirement, so both properties are verified at compile time.
static_assert(sizeof(rocblaslt_matmul_algo) == sizeof(hipblasLtMatmulAlgo_t),
              "rocblaslt_matmul_algo struct does not match size of hipblasLtMatmulAlgo_t");
static_assert(alignof(rocblaslt_matmul_algo) == alignof(hipblasLtMatmulAlgo_t),
              "rocblaslt_matmul_algo struct does not match alignment of hipblasLtMatmulAlgo_t");

/********************************************************************************
* \brief rocblaslt_matmul_heuristic holds the configured matrix
Expand Down Expand Up @@ -448,6 +471,9 @@ namespace rocblaslt
int aux_stride = 0;
};

// RocGemmEpilogue is expected to mirror hipblaslt_ext::GemmEpilogue. A size
// match alone would not catch a differing alignment requirement, so the
// alignment is verified at compile time as well.
static_assert(sizeof(RocGemmEpilogue) == sizeof(hipblaslt_ext::GemmEpilogue),
              "RocGemmEpilogue struct does not match size of hipblaslt_ext::GemmEpilogue");
static_assert(alignof(RocGemmEpilogue) == alignof(hipblaslt_ext::GemmEpilogue),
              "RocGemmEpilogue struct does not match alignment of hipblaslt_ext::GemmEpilogue");

class RocGemmEpilogueV2
{
public:
Expand Down Expand Up @@ -491,6 +517,9 @@ namespace rocblaslt
void* aux = nullptr;
};

// RocGemmInputs is expected to mirror hipblaslt_ext::GemmInputs. A size match
// alone would not catch a differing alignment requirement, so the alignment
// is verified at compile time as well.
static_assert(sizeof(RocGemmInputs) == sizeof(hipblaslt_ext::GemmInputs),
              "RocGemmInputs struct does not match size of hipblaslt_ext::GemmInputs");
static_assert(alignof(RocGemmInputs) == alignof(hipblaslt_ext::GemmInputs),
              "RocGemmInputs struct does not match alignment of hipblaslt_ext::GemmInputs");

struct RocGemmInputsV2
{
void* a = nullptr;
Expand Down