Skip to content

Commit

Permalink
Stedc merge optimization (#886)
Browse files Browse the repository at this point in the history
* add padding + external gemm
* add padding + external gemm. Clean code and comments
* update
* refactor padding mechanism
* Clear `tempgemm` to prevent errors due to multiplication with spurious NaNs (#16)
* Apply clang-format
* move tempgemm initialization to template function
* addressed review comments

---------

Co-authored-by: Julio Machado Silva <[email protected]>
  • Loading branch information
jzuniga-amd and jmachado-amd authored Jan 29, 2025
1 parent 0df492a commit c601117
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 75 deletions.
2 changes: 1 addition & 1 deletion clients/common/auxiliary/testing_stedc.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* **************************************************************************
* Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down
2 changes: 1 addition & 1 deletion library/src/auxiliary/rocauxiliary_stedc.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* **************************************************************************
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down
193 changes: 123 additions & 70 deletions library/src/auxiliary/rocauxiliary_stedc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* December 2016
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -46,6 +46,12 @@ ROCSOLVER_BEGIN_NAMESPACE
#define STEDC_BDIM 512 // Number of threads per thread-block used in main stedc kernels
#define MAXITERS 50 // Max number of iterations for root finding method

// TODO: using macro STEDC_EXTERNAL_GEMM = true for now. In the future we can pass
// STEDC_EXTERNAL_GEMM at run time to switch between internal vector updates and
// external gemm-based updates.
#define STEDC_EXTERNAL_GEMM true


typedef enum rocsolver_stedc_mode_
{
rocsolver_stedc_mode_qr,
Expand Down Expand Up @@ -1570,8 +1576,9 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM)
}
}


//--------------------------------------------------------------------------------------//
/** STEDC_MERGEVECTORS_KERNEL merges vectors from the secular equation for
/** STEDC_MERGEVECTORS_KERNEL prepares vectors from the secular equation for
every pair of sub-blocks that need to be merged in a split block. A matrix in the batch
could have many split-blocks, and each split-block could be divided in a maximum of nn sub-blocks.
- Call this kernel with batch_count groups in z, STEDC_NUM_SPLIT_BLKS groups in y,
Expand All @@ -1583,7 +1590,7 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM)
- An upper bound for the number of sub-blocks (nn) can be estimated from
the size n. If a group has an id larger than the actual number of columns n,
it will do nothing. **/
template <rocsolver_stedc_mode MODE, typename S>
template <rocsolver_stedc_mode MODE, bool USEGEMM, typename S>
ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM)
stedc_mergeVectors_kernel(const rocblas_int k,
const rocblas_int n,
Expand Down Expand Up @@ -1742,74 +1749,106 @@ ROCSOLVER_KERNEL void __launch_bounds__(STEDC_BDIM)
}
__syncthreads();

// 3f. Compute vectors corresponding to non-deflated values
// 3f. Prepare vectors corresponding to non-deflated values
/* ----------------------------------------------------------------- */
S temp, nrm;
rocblas_int j = vidb;
bool go = (j < ns[tid] && idd[p2 + j] == 1);
bool go = (j < ns[tid]);
S* putvec = USEGEMM ? vecs : temps;

if(go)
{
// compute vectors of rank-1 perturbed system and their norms
nrm = 0;
for(int i = tidb; i < dd; i += dim)
{
valf = zz[i] / temps[i + (p2 + j) * n];
nrm += valf * valf;
temps[i + (p2 + j) * n] = valf;
}
inrms[tidb] = nrm;
__syncthreads();

// reduction (for the norms)
for(int r = dim / 2; r > 0; r /= 2)
if(idd[p2 + j] == 1)
{
if(tidb < r)
// compute vectors of rank-1 perturbed system and their norms
nrm = 0;
for(int i = tidb; i < dd; i += dim)
{
nrm += inrms[tidb + r];
inrms[tidb] = nrm;
valf = zz[i] / temps[i + (p2 + j) * n];
nrm += valf * valf;
putvec[i + (p2 + j) * n] = valf;
}
inrms[tidb] = nrm;
__syncthreads();

// reduction (for the norms)
for(int r = dim / 2; r > 0; r /= 2)
{
if(tidb < r)
{
nrm += inrms[tidb + r];
inrms[tidb] = nrm;
}
__syncthreads();
}
nrm = sqrt(inrms[0]);
}
nrm = sqrt(nrm);

// multiply by C (row by row)
for(int ii = 0; ii < tsz; ++ii)
if(USEGEMM)
{
rocblas_int i = in + ii;

// inner products
temp = 0;
if(ii < sz)
// when using external gemms for the update, we need to
// put vectors in padded matrix 'temps'
// (this is to compute 'vecs = C * temps' using external gemm call)
for(int i = tidb; i < in + sz; i += dim)
{
for(int kk = tidb; kk < dd; kk += dim)
temp += C[i + (per[kk] + in) * ldc] * temps[kk + (p2 + j) * n];
if(i >= in && idd[p2 + j] == 1 && idd[i] == 1)
{
dd = 0;
for(int k = in; k < i; ++k)
{
if(idd[k] == 0)
dd++;
}
temps[pers[i - dd] + in + (p2 + j) * n] = vecs[i - dd - in + (p2 + j) * n] / nrm;
}
else
temps[i + (p2 + j) * n] = 0;
}
inrms[tidb] = temp;
__syncthreads();

// reduction
for(int r = dim / 2; r > 0; r /= 2)
}
else
{
// otherwise, use internal gemm-like procedure to
// multiply by C (row by row)
if(idd[p2 + j] == 1)
{
if(ii < sz && tidb < r)
for(int ii = 0; ii < tsz; ++ii)
{
temp += inrms[tidb + r];
rocblas_int i = in + ii;

// inner products
temp = 0;
if(ii < sz)
{
for(int kk = tidb; kk < dd; kk += dim)
temp += C[i + (per[kk] + in) * ldc] * temps[kk + (p2 + j) * n];
}
inrms[tidb] = temp;
__syncthreads();

// reduction
for(int r = dim / 2; r > 0; r /= 2)
{
if(ii < sz && tidb < r)
{
temp += inrms[tidb + r];
inrms[tidb] = temp;
}
__syncthreads();
}

// result
if(ii < sz && tidb == 0)
vecs[i + (p2 + j) * n] = temp / nrm;
__syncthreads();
}
__syncthreads();
}

// result
if(ii < sz && tidb == 0)
vecs[i + (p2 + j) * n] = temp / nrm;
__syncthreads();
}
}
/* ----------------------------------------------------------------- */
}
}
}


//--------------------------------------------------------------------------------------//
/** STEDC_MERGEUPDATE_KERNEL updates vectors after merges are done. A matrix in the batch
could have many split-blocks, and each split-block could be divided in a maximum of nn sub-blocks.
Expand Down Expand Up @@ -2050,15 +2089,10 @@ void local_gemm(rocblas_handle handle,
const rocblas_int batch_count,
S** workArr)
{
// Execute A*B -> temp -> A

// everything must be executed with scalars on the host
rocblas_pointer_mode old_mode;
rocblas_get_pointer_mode(handle, &old_mode);
rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);
S one = 1.0;
S zero = 0.0;

// Execute A*B -> temp -> A
// temp = A*B
rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, n, n, n, &one, A, shiftA,
lda, strideA, B, shiftV, ldv, strideV, &zero, temp, shiftV, ldv, strideV,
Expand All @@ -2070,8 +2104,6 @@ void local_gemm(rocblas_handle handle,
rocblas_int blocks = (n - 1) / BS2 + 1;
ROCSOLVER_LAUNCH_KERNEL(copy_mat<T>, dim3(blocks, blocks, batch_count), dim3(BS2, BS2), 0,
stream, copymat_from_buffer, n, n, A, shiftA, lda, strideA, temp);

rocblas_set_pointer_mode(handle, old_mode);
}

template <bool BATCHED,
Expand All @@ -2095,15 +2127,11 @@ void local_gemm(rocblas_handle handle,
const rocblas_int batch_count,
S** workArr)
{
// Execute A -> work; work*B -> temp -> A

// everything must be executed with scalars on the host
rocblas_pointer_mode old_mode;
rocblas_get_pointer_mode(handle, &old_mode);
rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);
S one = 1.0;
S zero = 0.0;

// Execute A -> work; work*B -> temp -> A

// work = real(A)
hipStream_t stream;
rocblas_get_stream(handle, &stream);
Expand Down Expand Up @@ -2134,8 +2162,6 @@ void local_gemm(rocblas_handle handle,
ROCSOLVER_LAUNCH_KERNEL((copy_mat<T, S, false>), dim3(blocks, blocks, batch_count),
dim3(BS2, BS2), 0, stream, copymat_from_buffer, n, n, A, shiftA, lda,
strideA, temp);

rocblas_set_pointer_mode(handle, old_mode);
}

//--------------------------------------------------------------------------------------//
Expand Down Expand Up @@ -2316,6 +2342,17 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
// otherwise use divide and conquer algorithm:
else
{
// initialize temporary array for vector updates
size_t size_tempgemm = sizeof(S) * 2 * n * n * batch_count;
HIP_CHECK(hipMemsetAsync((void*)tempgemm, 0, size_tempgemm, stream));

// everything must be executed with scalars on the host
rocblas_pointer_mode old_mode;
rocblas_get_pointer_mode(handle, &old_mode);
rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host);
S one = 1.0;
S zero = 0.0;

// constants
S eps = get_epsilon<S>();
S ssfmin = get_safemin<S>();
Expand All @@ -2324,6 +2361,10 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
ssfmax = sqrt(ssfmax) / S(3.0);
rocblas_int blocksn = (n - 1) / BS2 + 1;

// find max number of sub-blocks to consider during the divide phase
rocblas_int maxlevs = stedc_num_levels<rocsolver_stedc_mode_qr>(n);
rocblas_int maxblks = 1 << maxlevs;

// initialize identity matrix in V
// if evect is tridiagonal we can store V directly in C
// otherwise, they must be kept separate to compute C*V
Expand All @@ -2339,10 +2380,6 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
ROCSOLVER_LAUNCH_KERNEL(init_ident<S>, dim3(blocksn, blocksn, batch_count), dim3(BS2, BS2),
0, stream, n, n, V, 0, ldv, strideV);

// find max number of sub-blocks to consider during the divide phase
rocblas_int maxlevs = stedc_num_levels<rocsolver_stedc_mode_qr>(n);
rocblas_int maxblks = 1 << maxlevs;

// find independent split blocks in matrix
ROCSOLVER_LAUNCH_KERNEL(stedc_split, dim3(batch_count), dim3(1), 0, stream, n, D + shiftD,
strideD, E + shiftE, strideE, splits_map, eps);
Expand All @@ -2367,9 +2404,11 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
rocblas_int numgrps3 = ((n - 1) / maxblks + 1) * maxblks;

// launch merge for level k
/** TODO: using max number of levels for now. Kernels return immediately when passing
the actual number of levels in the split block. We should explore if synchronizing
to copy back the actual number of levels makes any difference **/
// TODO: using max number of levels for now. Kernels return immediately when surpassing
// the actual number of levels in the split block. We should explore if synchronizing
// to copy back the actual number of levels makes any difference.
// TODO: the code computing the context of the level (first part of each kernel) could be
// reused.
for(rocblas_int k = 0; k < maxlevs; ++k)
{
// a. prepare secular equations
Expand All @@ -2387,12 +2426,24 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
E + shiftE, strideE, tmpz, tempgemm, splits, eps, ssfmin, ssfmax);

// c. find merged eigen vectors
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_qr, S>),
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_qr, STEDC_EXTERNAL_GEMM, S>),
dim3(numgrps3, STEDC_NUM_SPLIT_BLKS, batch_count),
dim3(STEDC_BDIM), lmemsize3, stream, k, n, D + shiftD, strideD,
E + shiftE, strideE, V, 0, ldv, strideV, tmpz, tempgemm, splits);

// c. update level
if(STEDC_EXTERNAL_GEMM)
{
// using external gemms with padded matrices to do the vector update
// One single full gemm of size n x n x n merges all the blocks in the level
// TODO: using macro STEDC_EXTERNAL_GEMM = true for now. In the future we can pass
// STEDC_EXTERNAL_GEMM at run time to switch between internal vector updates and
// external gemm-based updates.
rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_none, n, n, n,
&one, V, 0, ldv, strideV, tempgemm, n*n, n, 2*n*n, &zero, tempgemm, 0, n, 2*n*n,
batch_count, workArr);
}

// d. update level
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeUpdate_kernel<rocsolver_stedc_mode_qr, S>),
dim3(numgrps3, STEDC_NUM_SPLIT_BLKS, batch_count),
dim3(STEDC_BDIM), lmemsize3, stream, k, n, D + shiftD, strideD,
Expand Down Expand Up @@ -2430,6 +2481,8 @@ rocblas_status rocsolver_stedc_template(rocblas_handle handle,
ROCSOLVER_LAUNCH_KERNEL((stedc_sort<T>), dim3(1, 1, nblocks), dim3(BS1), 0, stream, n,
D + shiftD, strideD, C, shiftC, ldc, strideC, batch_count,
splits_map);

rocblas_set_pointer_mode(handle, old_mode);
}

return rocblas_status_success;
Expand Down
6 changes: 5 additions & 1 deletion library/src/auxiliary/rocauxiliary_stedcj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ ROCSOLVER_BEGIN_NAMESPACE

#define MAXSWEEPS 20 // Max number of sweeps for Jacobi solver (when used)

// TODO: using macro STEDCJ_EXTERNAL_GEMM = false for now. We can enable the use of
// external gemm updates once the development is completed for stedc.
#define STEDCJ_EXTERNAL_GEMM false

/***************** Device auxiliary functions *****************************************/
/**************************************************************************************/

Expand Down Expand Up @@ -450,7 +454,7 @@ rocblas_status rocsolver_stedcj_template(rocblas_handle handle,
eps, ssfmin, ssfmax);

// c. find merged eigen vectors
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_jacobi, S>),
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_jacobi, STEDCJ_EXTERNAL_GEMM, S>),
dim3(numgrps3, STEDC_NUM_SPLIT_BLKS, batch_count), dim3(STEDC_BDIM),
lmemsize3, stream, k, n, D, strideD, E, strideE, tempvect, 0, ldt,
strideT, tmpz, tempgemm, splits_map);
Expand Down
9 changes: 7 additions & 2 deletions library/src/auxiliary/rocauxiliary_stedcx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@

ROCSOLVER_BEGIN_NAMESPACE

// TODO: using macro STEDCX_EXTERNAL_GEMM = false for now. We can enable the use of
// external gemm updates once the development is completed for stedc.
#define STEDCX_EXTERNAL_GEMM false


/***************** Device auxiliary functions *****************************************/
/**************************************************************************************/

Expand Down Expand Up @@ -578,11 +583,11 @@ rocblas_status rocsolver_stedcx_template(rocblas_handle handle,
eps, ssfmin, ssfmax);

// c. find merged eigen vectors
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_bisection, S>),
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeVectors_kernel<rocsolver_stedc_mode_bisection, STEDCX_EXTERNAL_GEMM, S>),
dim3(numgrps3, STEDC_NUM_SPLIT_BLKS, batch_count), dim3(STEDC_BDIM),
lmemsize3, stream, k, n, D, strideD, E, strideE, tempvect, 0, ldt,
strideT, tmpz, tempgemm, splits);

// d. update level
ROCSOLVER_LAUNCH_KERNEL((stedc_mergeUpdate_kernel<rocsolver_stedc_mode_bisection, S>),
dim3(numgrps3, STEDC_NUM_SPLIT_BLKS, batch_count), dim3(STEDC_BDIM),
Expand Down

0 comments on commit c601117

Please sign in to comment.