Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive cholesky #710

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion library/src/include/ideal_sizes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@

/************************** potf2/potrf ***************************************
*******************************************************************************/
/*! \brief Determines the maximum size at which rocSOLVER can use the small-size POTF2
kernel.
\details
POTF2 will attempt to factorize a small symmetric matrix that can fit entirely
within the LDS share memory using compact storage.
The amount of LDS shared memory is assumed to be at least (64 * 1024) bytes. */
#ifndef POTF2_MAX_SMALL_SIZE
#define POTF2_MAX_SMALL_SIZE(T) ((sizeof(T) == 4) ? 180 : (sizeof(T) == 8) ? 127 : 90)
#endif

/*! \brief Determines the size of the leading block that is factorized at each step
when using the blocked algorithm (POTRF). It also applies to the
corresponding batched and strided-batched routines.*/
Expand All @@ -281,11 +291,21 @@

\details POTRF will factorize blocks of POTRF_BLOCKSIZE columns at a time until
the rest of the matrix has no more than POTRF_POTF2_SWITCHSIZE columns; at this point the last block,
if any, will be factorized with the unblocked algorithm (POTF2).*/
if any, will be factorized with the unblocked algorithm (POTF2). */
#ifndef POTRF_POTF2_SWITCHSIZE
#define POTRF_POTF2_SWITCHSIZE(T) POTRF_BLOCKSIZE(T)
#endif

/*! \brief Determines the size at which rocSOLVER switches from
the recursive to the right-looking algorithm when executing POTRF. It also applies to the
corresponding batched and strided-batched routines.
\details POTRF will recursively divide the matrix into submatrices of size n/2 until
n/2 is less than POTRF_RECURSIVE_SWITCHSIZE; at this point the the submatrix will be
factorized with the right-looking algorithm (POTRF or POTF2). */
#ifndef POTRF_RECURSIVE_SWITCHSIZE
#define POTRF_RECURSIVE_SWITCHSIZE(T) ((sizeof(T) == 4) ? 1408 : (sizeof(T) == 8) ? 1024 : 704)
#endif

/*! \brief Determines the maximum size at which rocSOLVER can use POTF2
\details
POTF2 will attempt to factorize a small symmetric matrix that can fit entirely
Expand Down
272 changes: 271 additions & 1 deletion library/src/lapack/roclapack_potrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void rocsolver_potrf_getMemorySize(const I n,
*size_iinfo = 0;
*optim_mem = true;
}
else
else if(n <= POTRF_RECURSIVE_SWITCHSIZE(T))
{
I jb = nb;
size_t s1, s2;
Expand All @@ -135,6 +135,276 @@ void rocsolver_potrf_getMemorySize(const I n,

*size_work1 = std::max(s1, s2);
}
else
{
// requirements for recursive POTRF
auto const n2 = n / 2;
auto const n1 = n - n2;

size_t w11 = 0, w12 = 0, w13 = 0;
size_t w21 = 0, w22 = 0, w23 = 0;
size_t w31 = 0, w32 = 0, w33 = 0;
size_t w41 = 0, w42 = 0, w43 = 0;
size_t p1 = 0, p2 = 0;
bool opt1 = false, opt2 = false, opt3 = false;
size_t unused;

// size to store info about positiveness of each subblock
*size_iinfo = sizeof(rocblas_int) * batch_count;

// requirements for calling POTRF recursively on submatrices
rocsolver_potrf_getMemorySize<BATCHED, STRIDED, T>(n1, uplo, batch_count, size_scalars, &w11,
&w21, &w31, &w41, &p1, &unused, &opt1);

rocsolver_potrf_getMemorySize<BATCHED, STRIDED, T>(n2, uplo, batch_count, &unused, &w12,
&w22, &w32, &w42, &p2, &unused, &opt2);

// extra requirements for calling TRSM
if(uplo == rocblas_fill_upper)
{
rocsolver_trsm_mem<BATCHED, STRIDED, T>(rocblas_side_left,
rocblas_operation_conjugate_transpose, n1, n2,
batch_count, &w13, &w23, &w33, &w43, &opt3);
}
else
{
rocsolver_trsm_mem<BATCHED, STRIDED, T>(rocblas_side_right,
rocblas_operation_conjugate_transpose, n2, n1,
batch_count, &w13, &w23, &w33, &w43, &opt3);
}

*size_work1 = std::max({w11, w12, w13});
*size_work2 = std::max({w21, w22, w23});
*size_work3 = std::max({w31, w32, w33});
*size_work4 = std::max({w41, w42, w43});
*size_pivots = std::max(p1, p2);
*optim_mem = opt1 && opt2 && opt3;
}
}

template <bool BATCHED, bool STRIDED, typename T, typename I, typename U>
rocblas_status rocsolver_potrf_recursive_template(rocblas_handle handle,
const rocblas_fill uplo,
const I n,
U A,
const rocblas_stride shiftA,
const I lda,
const rocblas_stride strideA,
I* info,
const I batch_count,
T* scalars,
void* work1,
void* work2,
void* work3,
void* work4,
T* pivots,
I* iinfo,
bool optim_mem,
const I row_offset = 0)
{
ROCSOLVER_ENTER("potrf_recursive", "uplo:", uplo, "n:", n, "shiftA:", shiftA, "lda:", lda,
"bc:", batch_count, "row_offset:", row_offset);
using S = decltype(std::real(T{}));

// quick return
if(n == 0)
return rocblas_status_success;

// -------------------------------------------------
// UNBLOCKED ALGORITHM FOR SMALL MATRICES
// -------------------------------------------------
I nb = POTRF_BLOCKSIZE(T);
if(n <= POTRF_POTF2_SWITCHSIZE(T) && row_offset == 0)
{
// only the first potf2 (when row_offset = 0) may modify info directly,
// others must go through iinfo and the chk_positive kernel
return rocsolver_potf2_template<T>(handle, uplo, n, A, shiftA, lda, strideA, info,
batch_count, scalars, (T*)work1, pivots);
}

hipStream_t stream;
rocblas_get_stream(handle, &stream);

rocblas_int blocksReset = (batch_count - 1) / BS1 + 1;
dim3 gridReset(blocksReset, 1, 1);
dim3 threads(BS1, 1, 1);

// constants for rocblas functions calls
T t_one = 1;
S s_one = 1;
S s_minone = -1;

// (TODO: When the matrix is detected to be non positive definite, we need to
// prevent TRSM and HERK to modify further the input matrix; ideally with no
// synchronizations.)

// -------------------------------------------------
// RIGHT-LOOKING ALGORITHM FOR MEDIUM MATRICES
// -------------------------------------------------
if(n <= POTRF_RECURSIVE_SWITCHSIZE(T))
{
I jb, j = 0;

if(uplo == rocblas_fill_upper)
{
// Compute the Cholesky factorization A = U'*U.
while(j < n - POTRF_POTF2_SWITCHSIZE(T))
{
// Factor diagonal and subdiagonal blocks
jb = std::min(n - j, nb); // number of columns in the block
ROCSOLVER_LAUNCH_KERNEL(reset_info, gridReset, threads, 0, stream, iinfo,
batch_count, 0);
ROCBLAS_CHECK(rocsolver_potf2_template<T>(
handle, uplo, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, iinfo,
batch_count, scalars, (T*)work1, pivots));

// test for non-positive-definiteness.
ROCSOLVER_LAUNCH_KERNEL(chk_positive<U>, gridReset, threads, 0, stream, iinfo, info,
j + row_offset, batch_count);

if(j + jb < n)
{
// update trailing submatrix
ROCBLAS_CHECK(rocsolver_trsm_upper<BATCHED, STRIDED, T>(
handle, rocblas_side_left, rocblas_operation_conjugate_transpose,
rocblas_diagonal_non_unit, jb, (n - j - jb), A, shiftA + idx2D(j, j, lda),
lda, strideA, A, shiftA + idx2D(j, j + jb, lda), lda, strideA, batch_count,
optim_mem, work1, work2, work3, work4));

ROCBLAS_CHECK(rocblasCall_syrk_herk<BATCHED, T>(
handle, uplo, rocblas_operation_conjugate_transpose, n - j - jb, jb,
&s_minone, A, shiftA + idx2D(j, j + jb, lda), lda, strideA, &s_one, A,
shiftA + idx2D(j + jb, j + jb, lda), lda, strideA, batch_count));
}
j += nb;
}
}
else
{
// Compute the Cholesky factorization A = L*L'.
while(j < n - POTRF_POTF2_SWITCHSIZE(T))
{
// Factor diagonal and subdiagonal blocks
jb = std::min(n - j, nb); // number of columns in the block
ROCSOLVER_LAUNCH_KERNEL(reset_info, gridReset, threads, 0, stream, iinfo,
batch_count, 0);
ROCBLAS_CHECK(rocsolver_potf2_template<T>(
handle, uplo, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, iinfo,
batch_count, scalars, (T*)work1, pivots));

// test for non-positive-definiteness.
ROCSOLVER_LAUNCH_KERNEL(chk_positive<U>, gridReset, threads, 0, stream, iinfo, info,
j + row_offset, batch_count);

if(j + jb < n)
{
// update trailing submatrix
ROCBLAS_CHECK(rocsolver_trsm_lower<BATCHED, STRIDED, T>(
handle, rocblas_side_right, rocblas_operation_conjugate_transpose,
rocblas_diagonal_non_unit, (n - j - jb), jb, A, shiftA + idx2D(j, j, lda),
lda, strideA, A, shiftA + idx2D(j + jb, j, lda), lda, strideA, batch_count,
optim_mem, work1, work2, work3, work4));

ROCBLAS_CHECK(rocblasCall_syrk_herk<BATCHED, T>(
handle, uplo, rocblas_operation_none, n - j - jb, jb, &s_minone, A,
shiftA + idx2D(j + jb, j, lda), lda, strideA, &s_one, A,
shiftA + idx2D(j + jb, j + jb, lda), lda, strideA, batch_count));
}
j += nb;
}
}

// factor last block
if(j < n)
{
ROCBLAS_CHECK(rocsolver_potf2_template<T>(handle, uplo, n - j, A,
shiftA + idx2D(j, j, lda), lda, strideA, iinfo,
batch_count, scalars, (T*)work1, pivots));
ROCSOLVER_LAUNCH_KERNEL(chk_positive<U>, gridReset, threads, 0, stream, iinfo, info,
j + row_offset, batch_count);
}

return rocblas_status_success;
}

// -------------------------------------------------
// RECURSIVE ALGORITHM FOR LARGE MATRICES
// -------------------------------------------------
else
{
auto const n2 = n / 2;
auto const n1 = n - n2;

if(uplo == rocblas_fill_upper)
{
// -------------------------------------------------
// A = U' * U
// [A11 A12] = [ U11' 0 ] * [U11 U12]
// [A12' A22] [ U12' U22'] [0 U22]
//
// where A11 is n1 by n1, A22 is n2 by n2, n == (n1 + n2)
// -------------------------------------------------

// find U11 given A11 = U11' * U11
ROCBLAS_CHECK(rocsolver_potrf_recursive_template<BATCHED, STRIDED, T>(
handle, uplo, n1, A, shiftA, lda, strideA, info, batch_count, scalars, work1, work2,
work3, work4, pivots, iinfo, optim_mem));

// find U12 given A12 = U11' * U12
auto const A12_offset = idx2D(0, n1, lda);
ROCBLAS_CHECK(rocsolver_trsm_upper<BATCHED, STRIDED, T>(
handle, rocblas_side_left, rocblas_operation_conjugate_transpose,
rocblas_diagonal_non_unit, n1, n2, A, shiftA, lda, strideA, A, shiftA + A12_offset,
lda, strideA, batch_count, optim_mem, work1, work2, work3, work4));

// update A22 as A22 - U12' * U12
auto const A22_offset = idx2D(n1, n1, lda);
ROCBLAS_CHECK(rocblasCall_syrk_herk<BATCHED, T>(
handle, uplo, rocblas_operation_conjugate_transpose, n2, n1, &s_minone, A,
shiftA + A12_offset, lda, strideA, &s_one, A, shiftA + A22_offset, lda, strideA,
batch_count));

// find U22 given A22 = U22' * U22
ROCBLAS_CHECK(rocsolver_potrf_recursive_template<BATCHED, STRIDED, T>(
handle, uplo, n2, A, shiftA + A22_offset, lda, strideA, info, batch_count, scalars,
work1, work2, work3, work4, pivots, iinfo, optim_mem, n1));
}
else
{
// ------------------------------------------------
// A = L * L'
// [A11 A21'] = [L11 0 ] * [L11' L21']
// [A21 A22 ] [L21 L22] [0 L22']
//
// where A11 is n1 by n1, A22 is n2 by n2, n == (n1 + n2)
// ------------------------------------------------

// find L11 given A11 = L11 * L11'
ROCBLAS_CHECK(rocsolver_potrf_recursive_template<BATCHED, STRIDED, T>(
handle, uplo, n1, A, shiftA, lda, strideA, info, batch_count, scalars, work1, work2,
work3, work4, pivots, iinfo, optim_mem));

// find L21 given A21 = L21 * L11'
auto const A21_offset = idx2D(n1, 0, lda);
ROCBLAS_CHECK(rocsolver_trsm_lower<BATCHED, STRIDED, T>(
handle, rocblas_side_right, rocblas_operation_conjugate_transpose,
rocblas_diagonal_non_unit, n2, n1, A, shiftA, lda, strideA, A, shiftA + A21_offset,
lda, strideA, batch_count, optim_mem, work1, work2, work3, work4));

// update A22 as A22 - L21 * L21'
auto const A22_offset = idx2D(n1, n1, lda);
ROCBLAS_CHECK(rocblasCall_syrk_herk<BATCHED, T>(
handle, uplo, rocblas_operation_none, n2, n1, &s_minone, A, shiftA + A21_offset,
lda, strideA, &s_one, A, shiftA + A22_offset, lda, strideA, batch_count));

// find L22 given A22 = L22 * L22'
ROCBLAS_CHECK(rocsolver_potrf_recursive_template<BATCHED, STRIDED, T>(
handle, uplo, n2, A, shiftA + A22_offset, lda, strideA, info, batch_count, scalars,
work1, work2, work3, work4, pivots, iinfo, optim_mem, n1));
}

return rocblas_status_success;
}
}

template <bool BATCHED, bool STRIDED, typename T, typename I, typename INFO, typename S, typename U>
Expand Down
18 changes: 10 additions & 8 deletions library/src/specialized/roclapack_potf2_specialized_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,13 @@ __device__ static void potf2_simple(bool const is_upper, I const n, T* const A,
// ------------------------------------------------------------

auto const conj_lkk = conj(lkk);
auto const inv_conj_lkk = 1.0 / conj_lkk;
for(I j0 = (kcol + 1) + j0_start; j0 < n; j0 += j0_inc)
{
auto const j0k = idx_lower(j0, kcol, lda);

A[j0k] = (A[j0k] / conj_lkk);
// A[j0k] = (A[j0k] / conj_lkk);
A[j0k] = (A[j0k] * inv_conj_lkk);
}

__syncthreads();
Expand All @@ -172,10 +174,9 @@ __device__ static void potf2_simple(bool const is_upper, I const n, T* const A,
for(I j = (kcol + 1) + j_start; j < n; j += j_inc)
{
auto const vj = A[idx_lower(j, kcol, lda)];
for(I i = (kcol + 1) + i_start; i < n; i += i_inc)
for(I i = j + i_start; i < n; i += i_inc)
{
bool const lower_part = (i >= j);
if(lower_part)
assert(i >= j);
{
auto const vi = A[idx_lower(i, kcol, lda)];
auto const ij = idx_lower(i, j, lda);
Expand Down Expand Up @@ -232,11 +233,13 @@ __device__ static void potf2_simple(bool const is_upper, I const n, T* const A,
// ----------------------------------------------
// (2) vU12' * u11 = vA12', or u11' * vU12 = vA12
// ----------------------------------------------
auto const inv_ukk = 1.0 / ukk;
for(I j0 = (kcol + 1) + j0_start; j0 < n; j0 += j0_inc)
{
auto const kj0 = idx_upper(kcol, j0, lda);

A[kj0] = A[kj0] / ukk;
// A[kj0] = A[kj0] / ukk;
A[kj0] = A[kj0] * inv_ukk;
}

__syncthreads();
Expand All @@ -249,10 +252,9 @@ __device__ static void potf2_simple(bool const is_upper, I const n, T* const A,
for(I j = (kcol + 1) + j_start; j < n; j += j_inc)
{
auto const vj = A[idx_upper(kcol, j, lda)];
for(I i = (kcol + 1) + i_start; i < n; i += i_inc)
for(I i = (kcol + 1) + i_start; i <= j; i += i_inc)
{
bool const upper_part = (i <= j);
if(upper_part)
assert(i <= j);
{
auto const vi = A[idx_upper(kcol, i, lda)];
auto const ij = idx_upper(i, j, lda);
Expand Down