Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update bandwidth and latency calculations, add multi work group support #30

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 88 additions & 42 deletions tests/functional_tests/alltoall_tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
* IN THE SOFTWARE.
*****************************************************************************/

using namespace rocshmem;

/* Declare the template with a generic implementation */
template <typename T>
__device__ void wg_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest,
Expand Down Expand Up @@ -52,30 +50,32 @@ ALLTOALL_DEF_GEN(unsigned int, uint)
ALLTOALL_DEF_GEN(unsigned long, ulong)
ALLTOALL_DEF_GEN(unsigned long long, ulonglong)

rocshmem_team_t team_alltoall_world_dup;

/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
template <typename T1>
__global__ void AlltoallTest(int loop, int skip, uint64_t *timer,
T1 *source_buf, T1 *dest_buf, int size,
ShmemContextType ctx_type, rocshmem_team_t team) {
__global__ void AlltoallTest(int loop, int skip, uint64_t *start_time,
uint64_t *end_time, T1 *source_buf, T1 *dest_buf,
int size, ShmemContextType ctx_type,
rocshmem_team_t *teams) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();

rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);

int n_pes = rocshmem_ctx_n_pes(ctx);

source_buf += wg_id * n_pes * size;
dest_buf += wg_id * n_pes * size;

__syncthreads();

uint64_t start;
for (int i = 0; i < loop + skip; i++) {
if (i == skip && hipThreadIdx_x == 0) {
start = rocshmem_timer();
start_time[wg_id] = wall_clock64();
}
wg_alltoall<T1>(ctx, team,
wg_alltoall<T1>(ctx, teams[wg_id],
dest_buf, // T* dest
source_buf, // const T* source
size); // int nelement
Expand All @@ -84,7 +84,7 @@ __global__ void AlltoallTest(int loop, int skip, uint64_t *timer,
__syncthreads();

if (hipThreadIdx_x == 0) {
timer[hipBlockIdx_x] = rocshmem_timer() - start;
end_time[wg_id] = wall_clock64();
}

rocshmem_wg_ctx_destroy(&ctx);
Expand All @@ -95,69 +95,115 @@ __global__ void AlltoallTest(int loop, int skip, uint64_t *timer,
* HOST TESTER CLASS METHODS
*****************************************************************************/
// Constructor for AlltoallTester<T1>: allocates the source/destination
// buffers with rocshmem_malloc.
// NOTE(review): this listing is a rendered diff view, so two versions of the
// constructor appear back to back. The three-argument form (f1/f2) is
// presumably the removed version and the single-argument form the added
// one -- confirm against the pull request.
template <typename T1>
AlltoallTester<T1>::AlltoallTester(
TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2)
: Tester(args), init_buf{f1}, verify_buf{f2} {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
// Three-argument version: each buffer is max_msg_size * sizeof(T1) * n_pes bytes.
source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes);
dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes);
AlltoallTester<T1>::AlltoallTester(TesterArguments args)
: Tester(args){
// Store this PE's index and the PE count of ROCSHMEM_TEAM_WORLD in members.
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

// Single-argument version: max_msg_size is divided by sizeof(T1) first, then
// scaled by the work-group count and the PE count.
// NOTE(review): both values are computed and stored as int; a large
// max_msg_size * num_wgs * n_pes could overflow -- consider size_t.
int num_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs * n_pes;
int buff_size = num_elems * sizeof(T1);

source_buf = (T1 *)rocshmem_malloc(buff_size);
dest_buf = (T1 *)rocshmem_malloc(buff_size);

// Optional override of num_teams from the environment.
// NOTE(review): the header comment on num_teams says it should equal
// ROCSHMEM_MAX_NUM_TEAMS - 1, but the raw value is stored here without the
// subtraction -- confirm which is intended. atoi also gives no way to
// detect a malformed value.
char* value{nullptr};
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
num_teams = atoi(value);
}

// Allocate an array of num_teams team handles with hipMalloc.
// NOTE(review): preLaunchKernel and postLaunchKernel index this array from
// host code; confirm that hipMalloc memory is host-accessible on every
// target platform (otherwise a host-visible allocation is needed).
CHECK_HIP(hipMalloc(&team_alltoall_world_dup,
sizeof(rocshmem_team_t) * num_teams));
}

// Destructor: releases source_buf and dest_buf with rocshmem_free, then
// releases the team handle array with hipFree (checked via CHECK_HIP).
template <typename T1>
AlltoallTester<T1>::~AlltoallTester() {
rocshmem_free(source_buf);
rocshmem_free(dest_buf);
CHECK_HIP(hipFree(team_alltoall_world_dup));
}

// Pre-launch hook: sets bw_factor and creates the team(s) used by the kernel
// by splitting ROCSHMEM_TEAM_WORLD.
// NOTE(review): this listing is a rendered diff view; the single-team
// statements and the num_teams loop below are presumably the removed and
// added versions of the same body -- confirm against the pull request.
template <typename T1>
void AlltoallTester<T1>::preLaunchKernel() {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
// Single-team version: bandwidth factor includes the element size.
bw_factor = sizeof(T1) * n_pes;

// Single-team version: one team handle, reset and then filled by the split.
team_alltoall_world_dup = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_alltoall_world_dup);
// Multi-team version: bandwidth factor is the PE count only.
bw_factor = n_pes;

// Multi-team version: create num_teams teams, each split from
// ROCSHMEM_TEAM_WORLD with the arguments 0, 1, n_pes (presumably
// start/stride/size as in the OpenSHMEM signature -- confirm).
for (int team_i = 0; team_i < num_teams; team_i++) {
team_alltoall_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_alltoall_world_dup[team_i]);
// The return value of the split is not inspected; failure is detected by
// the handle still holding ROCSHMEM_TEAM_INVALID, which aborts the test.
if (team_alltoall_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
std::cout << "Team " << team_i << " is invalid!" << std::endl;
abort();
}
}
}

// Launches the AlltoallTest<T1> kernel on `stream` and records how many
// messages were issued (num_msgs) and how many were timed (num_timed_msgs).
// NOTE(review): this listing is a rendered diff view; two argument lists and
// two sets of message-count assignments appear, presumably the removed and
// added versions -- confirm against the pull request.
template <typename T1>
void AlltoallTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) {
size_t shared_bytes = 0;

// Multi-work-group version: `size` is divided by sizeof(T1) to get the
// element count passed to the kernel.
int num_elems = size / sizeof(T1);

// NOTE(review): the kernel indexes teams[wg_id]; nothing in this function
// checks the grid size against num_teams -- confirm the bound is enforced
// elsewhere. No launch-error check follows the launch here either.
hipLaunchKernelGGL(AlltoallTest<T1>, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, source_buf, dest_buf, size,
_shmem_context, team_alltoall_world_dup);
stream, loop, args.skip, start_time, end_time,
source_buf, dest_buf, num_elems, _shmem_context,
team_alltoall_world_dup);

// Single-work-group version: counts per launch.
num_msgs = loop + args.skip;
num_timed_msgs = loop;
// Multi-work-group version: counts scaled by the grid's x dimension.
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}

// Post-launch hook: destroys the team(s) created in preLaunchKernel.
// NOTE(review): this listing is a rendered diff view; the single
// rocshmem_team_destroy call and the loop are presumably the removed and
// added versions -- confirm against the pull request.
template <typename T1>
void AlltoallTester<T1>::postLaunchKernel() {
rocshmem_team_destroy(team_alltoall_world_dup);
// Destroy every one of the num_teams handles in the array.
for (int team_i = 0; team_i < num_teams; team_i++) {
rocshmem_team_destroy(team_alltoall_world_dup[team_i]);
}
}

// Re-initialises source_buf and dest_buf before a test run.
// NOTE(review): this listing is a rendered diff view; the init_buf-based
// loops and the wg_id/pe/i loops are presumably the removed and added
// versions of the same body -- confirm against the pull request.
template <typename T1>
void AlltoallTester<T1>::resetBuffers(uint64_t size) {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
// init_buf version: `size` is used directly as the per-PE element count.
for (int i = 0; i < n_pes; i++) {
for (uint64_t j = 0; j < size; j++) {
init_buf(source_buf[i * size + j], dest_buf[i * size + j], (T1)i);

// Type-dispatch version: `size` is divided by sizeof(T1) to get elements.
// NOTE(review): buff_size is an int product; it could overflow for large
// size * num_wgs * n_pes.
int num_elems = size / sizeof(T1);
int buff_size = num_elems * sizeof(T1) * args.num_wgs * n_pes;
int idx = 0;

// Layout: one region per work group, split into one chunk of num_elems
// elements per PE.
for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
for(int pe = 0; pe < n_pes; pe++) {
for(int i = 0; i < num_elems; i++) {
idx = (wg_id * n_pes + pe) * num_elems + i;
// The fill value depends on my_pe + pe + wg_id, which is symmetric in
// my_pe and pe; presumably this is what lets verifyResults compare
// dest_buf[idx] against source_buf[idx] directly -- confirm.
// Types matching none of the branches leave source_buf untouched.
if constexpr (std::is_same<T1, char>::value ||
std::is_same<T1, signed char>::value ||
std::is_same<T1, unsigned char>::value) {
source_buf[idx] = static_cast<T1>('a' + my_pe + pe + wg_id);
}
else if constexpr (std::is_floating_point<T1>::value) {
source_buf[idx] = static_cast<T1>(3.14 + my_pe + pe + wg_id);
}
else if constexpr (std::is_integral<T1>::value) {
source_buf[idx] = static_cast<T1>(my_pe + pe + wg_id);
}
}
}
}

// Fill every byte of the destination buffer with 0xFF via host memset.
// NOTE(review): assumes rocshmem_malloc memory is host-accessible -- confirm.
memset(dest_buf, -1, buff_size);
}

template <typename T1>
void AlltoallTester<T1>::verifyResults(uint64_t size) {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
for (int i = 0; i < n_pes; i++) {
for (uint64_t j = 0; j < size; j++) {
auto r = verify_buf(dest_buf[i * size + j], i);
if (r.first == false) {
fprintf(stderr, "Data validation error at idx %lu\n", j);
fprintf(stderr, "%s.\n", r.second.c_str());
exit(-1);
int num_elems = size / sizeof(T1);
int idx = 0;

for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
for(int pe = 0; pe < n_pes; pe++) {
for(int i = 0; i < num_elems; i++) {
idx = (wg_id * n_pes + pe) * num_elems + i;
if (dest_buf[idx] != source_buf[idx]) {
std::cerr << "Data validation error at idx " << idx << std::endl;
std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
<< ", Expected " << source_buf[idx] << std::endl;
exit(-1);
}
}
}
}
Expand Down
23 changes: 15 additions & 8 deletions tests/functional_tests/alltoall_tester.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@

#include "tester.hpp"

using namespace rocshmem;

/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
template <typename T1>
class AlltoallTester : public Tester {
public:
explicit AlltoallTester(
TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2);
explicit AlltoallTester(TesterArguments args);
virtual ~AlltoallTester();

protected:
Expand All @@ -51,12 +51,19 @@ class AlltoallTester : public Tester {

virtual void verifyResults(uint64_t size) override;

T1 *source_buf;
T1 *dest_buf;
T1 *source_buf = nullptr;
T1 *dest_buf = nullptr;

private:
int my_pe = 0;
int n_pes = 0;

private:
std::function<void(T1 &, T1 &, T1)> init_buf;
std::function<std::pair<bool, std::string>(const T1 &, T1)> verify_buf;
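/**
* Number of duplicate teams to create.
* The default should equal ROCSHMEM_MAX_NUM_TEAMS - 1; the default value for
* the maximum number of teams is 40. The constructor replaces this value
* with the ROCSHMEM_MAX_NUM_TEAMS environment variable when it is set
* (as written, without subtracting 1).
*/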
int num_teams = 39;
rocshmem_team_t *team_alltoall_world_dup;
};

#include "alltoall_tester.cpp"
Expand Down
21 changes: 12 additions & 9 deletions tests/functional_tests/amo_bitwise_tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ using namespace rocshmem;

/* Declare the global kernel template with a generic implementation */
// Generic AMOBitwiseTest kernel template: the body returns immediately and
// does no work.
// NOTE(review): this listing is a rendered diff view; two parameter lists
// appear (one with a single `timer` pointer, one with start_time/end_time),
// presumably the removed and added signatures -- confirm against the pull
// request.
template <typename T>
__global__ void AMOBitwiseTest(int loop, int skip, uint64_t *timer, char *r_buf,
T *s_buf, T *ret_val, TestType type,
__global__ void AMOBitwiseTest(int loop, int skip, uint64_t *start_time,
uint64_t *end_time, char *r_buf, T *s_buf,
T *ret_val, TestType type,
ShmemContextType ctx_type) {
return;
}
Expand Down Expand Up @@ -64,8 +65,8 @@ void AMOBitwiseTester<T>::launchKernel(dim3 gridsize, dim3 blocksize, int loop,
size_t shared_bytes = 0;

hipLaunchKernelGGL(AMOBitwiseTest, gridsize, blocksize, shared_bytes, stream,
loop, args.skip, timer, _r_buf, _s_buf, _ret_val, _type,
_shmem_context);
loop, args.skip, start_time, end_time, _r_buf, _s_buf,
_ret_val, _type, _shmem_context);

_gridSize = gridsize;
num_msgs = (loop + args.skip) * gridsize.x;
Expand Down Expand Up @@ -123,17 +124,19 @@ void AMOBitwiseTester<T>::verifyResults(uint64_t size) {
#define AMO_BITWISE_DEF_GEN(T, TNAME) \
template <> \
__global__ void AMOBitwiseTest<T>( \
int loop, int skip, uint64_t *timer, char *r_buf, T *s_buf, T *ret_val, \
TestType type, ShmemContextType ctx_type) { \
int loop, int skip, uint64_t *start_time, uint64_t *end_time, \
char *r_buf, T *s_buf, T *ret_val, TestType type, \
ShmemContextType ctx_type) { \
__shared__ rocshmem_ctx_t ctx; \
rocshmem_wg_init(); \
rocshmem_wg_ctx_create(ctx_type, &ctx); \
if (hipThreadIdx_x == 0) { \
uint64_t start; \
T ret = 0; \
T cond = 0; \
for (int i = 0; i < loop + skip; i++) { \
if (i == skip) start = rocshmem_timer(); \
if (i == skip) { \
start_time[hipBlockIdx_x] = wall_clock64(); \
} \
switch (type) { \
case AMO_FetchAndTestType: \
ret = rocshmem_ctx_##TNAME##_atomic_fetch_and(ctx, (T *)r_buf, \
Expand Down Expand Up @@ -161,7 +164,7 @@ void AMOBitwiseTester<T>::verifyResults(uint64_t size) {
} \
} \
rocshmem_ctx_quiet(ctx); \
timer[hipBlockIdx_x] = rocshmem_timer() - start; \
end_time[hipBlockIdx_x] = wall_clock64(); \
ret_val[hipBlockIdx_x] = ret; \
rocshmem_ctx_getmem(ctx, &s_buf[hipBlockIdx_x], r_buf, sizeof(T), 1); \
} \
Expand Down
Loading