Skip to content

Commit

Permalink
Add 64-bit api for potf2, potrf, potrs (#776)
Browse files (browse the repository at this point in the history)
* add 64-bit geqrf

* update changelog and docs

* make new 64-bit APIs optional

* first pass

* address feedback, fix exports

* fix docs

* add potrf_64, potf2_64

* add potrs_64

* update docs and changelog

* fix info type in potf2 impl

* Added potrf_info32

* Updated test

* Add 32-bit info variant for compatibility

---------

Co-authored-by: Troy Alderson <[email protected]>
  • Loading branch information
qjojo and tfalders authored Aug 7, 2024
1 parent 9dd1c22 commit 9f253f2
Show file tree
Hide file tree
Showing 39 changed files with 2,066 additions and 479 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h
- LARFG_64
- GEQR2_64 (with batched and strided\_batched versions)
- GEQRF_64 (with batched and strided\_batched versions)
- POTF2_64 (with batched and strided\_batched versions)
- POTRF_64 (with batched and strided\_batched versions)
- POTRS_64 (with batched and strided\_batched versions)

### Optimized
### Changed
Expand Down
1 change: 1 addition & 0 deletions clients/common/lapack/testing_potf2_potrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ INSTANTIATE(TESTING_POTF2_POTRF,
FOREACH_MATRIX_DATA_LAYOUT,
FOREACH_BLOCKED_VARIANT,
FOREACH_SCALAR_TYPE,
FOREACH_INT_TYPE,
APPLY_STAMP)
102 changes: 50 additions & 52 deletions clients/common/lapack/testing_potf2_potrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@
#include "common/misc/rocsolver_arguments.hpp"
#include "common/misc/rocsolver_test.hpp"

template <bool STRIDED, bool POTRF, typename T, typename U>
template <bool STRIDED, bool POTRF, typename T, typename I, typename U>
void potf2_potrf_checkBadArgs(const rocblas_handle handle,
const rocblas_fill uplo,
const rocblas_int n,
const I n,
T dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
U dinfo,
const rocblas_int bc)
const I bc)
{
// handle
EXPECT_ROCBLAS_STATUS(
Expand Down Expand Up @@ -85,22 +85,22 @@ void potf2_potrf_checkBadArgs(const rocblas_handle handle,
rocblas_status_success);
}

template <bool BATCHED, bool STRIDED, bool POTRF, typename T>
template <bool BATCHED, bool STRIDED, bool POTRF, typename T, typename I>
void testing_potf2_potrf_bad_arg()
{
// safe arguments
rocblas_local_handle handle;
rocblas_fill uplo = rocblas_fill_upper;
rocblas_int n = 1;
rocblas_int lda = 1;
I n = 1;
I lda = 1;
rocblas_stride stA = 1;
rocblas_int bc = 1;
I bc = 1;

if(BATCHED)
{
// memory allocations
device_batch_vector<T> dA(1, 1, 1);
device_strided_batch_vector<rocblas_int> dinfo(1, 1, 1, 1);
device_strided_batch_vector<I> dinfo(1, 1, 1, 1);
CHECK_HIP_ERROR(dA.memcheck());
CHECK_HIP_ERROR(dinfo.memcheck());

Expand All @@ -112,7 +112,7 @@ void testing_potf2_potrf_bad_arg()
{
// memory allocations
device_strided_batch_vector<T> dA(1, 1, 1, 1);
device_strided_batch_vector<rocblas_int> dinfo(1, 1, 1, 1);
device_strided_batch_vector<I> dinfo(1, 1, 1, 1);
CHECK_HIP_ERROR(dA.memcheck());
CHECK_HIP_ERROR(dinfo.memcheck());

Expand All @@ -122,15 +122,15 @@ void testing_potf2_potrf_bad_arg()
}
}

template <bool CPU, bool GPU, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool CPU, bool GPU, typename T, typename I, typename Td, typename Ud, typename Th, typename Uh>
void potf2_potrf_initData(const rocblas_handle handle,
const rocblas_fill uplo,
const rocblas_int n,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dInfo,
const rocblas_int bc,
const I bc,
Th& hA,
Uh& hInfo,
const bool singular)
Expand All @@ -139,10 +139,10 @@ void potf2_potrf_initData(const rocblas_handle handle,
{
rocblas_init<T>(hA, true);

for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
// scale to ensure positive definiteness
for(rocblas_int i = 0; i < n; i++)
for(I i = 0; i < n; i++)
hA[b][i + i * lda] = hA[b][i + i * lda] * sconj(hA[b][i + i * lda]) * 400;

if(singular && (b == bc / 4 || b == bc / 2 || b == bc - 1))
Expand All @@ -151,7 +151,7 @@ void potf2_potrf_initData(const rocblas_handle handle,
// always the same elements for debugging purposes
// the algorithm must detect the lower order of the principal minors <= 0
// in those matrices in the batch that are non positive definite
rocblas_int i = n / 4 + b;
I i = n / 4 + b;
i -= (i / n) * n;
hA[b][i + i * lda] = 0;
i = n / 2 + b;
Expand All @@ -171,19 +171,19 @@ void potf2_potrf_initData(const rocblas_handle handle,
}
}

template <bool STRIDED, bool POTRF, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool STRIDED, bool POTRF, typename T, typename I, typename Td, typename Id, typename Th, typename Ih, typename Uh>
void potf2_potrf_getError(const rocblas_handle handle,
const rocblas_fill uplo,
const rocblas_int n,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dInfo,
const rocblas_int bc,
Id& dInfo,
const I bc,
Th& hA,
Th& hARes,
Uh& hInfo,
Uh& hInfoRes,
Ih& hInfoRes,
double* max_err,
const bool singular)
{
Expand All @@ -199,7 +199,7 @@ void potf2_potrf_getError(const rocblas_handle handle,
CHECK_HIP_ERROR(hInfoRes.transfer_from(dInfo));

// CPU lapack
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
POTRF ? cpu_potrf(uplo, n, hA[b], lda, hInfo[b]) : cpu_potf2(uplo, n, hA[b], lda, hInfo[b]);
}
Expand All @@ -209,9 +209,9 @@ void potf2_potrf_getError(const rocblas_handle handle,
// IT MIGHT BE REVISITED IN THE FUTURE)
// using frobenius norm
double err;
rocblas_int nn;
I nn;
*max_err = 0;
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
nn = hInfoRes[b][0] == 0 ? n : hInfoRes[b][0];
// (TODO: For now, the algorithm is modifying the whole input matrix even when
Expand All @@ -224,7 +224,7 @@ void potf2_potrf_getError(const rocblas_handle handle,

// also check info for non positive definite cases
err = 0;
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
EXPECT_EQ(hInfo[b][0], hInfoRes[b][0]) << "where b = " << b;
if(hInfo[b][0] != hInfoRes[b][0])
Expand All @@ -233,15 +233,15 @@ void potf2_potrf_getError(const rocblas_handle handle,
*max_err += err;
}

template <bool STRIDED, bool POTRF, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool STRIDED, bool POTRF, typename T, typename I, typename Td, typename Id, typename Th, typename Uh>
void potf2_potrf_getPerfData(const rocblas_handle handle,
const rocblas_fill uplo,
const rocblas_int n,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dInfo,
const rocblas_int bc,
Id& dInfo,
const I bc,
Th& hA,
Uh& hInfo,
double* gpu_time_used,
Expand All @@ -259,7 +259,7 @@ void potf2_potrf_getPerfData(const rocblas_handle handle,

// cpu-lapack performance (only if not in perf mode)
*cpu_time_used = get_time_us_no_sync();
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
POTRF ? cpu_potrf(uplo, n, hA[b], lda, hInfo[b])
: cpu_potf2(uplo, n, hA[b], lda, hInfo[b]);
Expand Down Expand Up @@ -307,18 +307,18 @@ void potf2_potrf_getPerfData(const rocblas_handle handle,
*gpu_time_used /= hot_calls;
}

template <bool BATCHED, bool STRIDED, bool POTRF, typename T>
template <bool BATCHED, bool STRIDED, bool POTRF, typename T, typename I>
void testing_potf2_potrf(Arguments& argus)
{
// get arguments
rocblas_local_handle handle;
char uploC = argus.get<char>("uplo");
rocblas_int n = argus.get<rocblas_int>("n");
rocblas_int lda = argus.get<rocblas_int>("lda", n);
I n = argus.get<I>("n");
I lda = argus.get<I>("lda", n);
rocblas_stride stA = argus.get<rocblas_stride>("strideA", lda * n);

rocblas_fill uplo = char2rocblas_fill(uploC);
rocblas_int bc = argus.batch_count;
I bc = argus.batch_count;
rocblas_int hot_calls = argus.iters;

rocblas_stride stARes = (argus.unit_check || argus.norm_check) ? stA : 0;
Expand All @@ -328,12 +328,11 @@ void testing_potf2_potrf(Arguments& argus)
{
if(BATCHED)
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n,
(T* const*)nullptr, lda, stA,
(rocblas_int*)nullptr, bc),
(T* const*)nullptr, lda, stA, (I*)nullptr, bc),
rocblas_status_invalid_value);
else
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n, (T*)nullptr,
lda, stA, (rocblas_int*)nullptr, bc),
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n,
(T*)nullptr, lda, stA, (I*)nullptr, bc),
rocblas_status_invalid_value);

if(argus.timing)
Expand All @@ -354,12 +353,11 @@ void testing_potf2_potrf(Arguments& argus)
{
if(BATCHED)
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n,
(T* const*)nullptr, lda, stA,
(rocblas_int*)nullptr, bc),
(T* const*)nullptr, lda, stA, (I*)nullptr, bc),
rocblas_status_invalid_size);
else
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n, (T*)nullptr,
lda, stA, (rocblas_int*)nullptr, bc),
EXPECT_ROCBLAS_STATUS(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n,
(T*)nullptr, lda, stA, (I*)nullptr, bc),
rocblas_status_invalid_size);

if(argus.timing)
Expand All @@ -374,11 +372,10 @@ void testing_potf2_potrf(Arguments& argus)
CHECK_ROCBLAS_ERROR(rocblas_start_device_memory_size_query(handle));
if(BATCHED)
CHECK_ALLOC_QUERY(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n,
(T* const*)nullptr, lda, stA,
(rocblas_int*)nullptr, bc));
(T* const*)nullptr, lda, stA, (I*)nullptr, bc));
else
CHECK_ALLOC_QUERY(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n, (T*)nullptr,
lda, stA, (rocblas_int*)nullptr, bc));
lda, stA, (I*)nullptr, bc));

size_t size;
CHECK_ROCBLAS_ERROR(rocblas_stop_device_memory_size_query(handle, &size));
Expand All @@ -397,9 +394,9 @@ void testing_potf2_potrf(Arguments& argus)
host_batch_vector<T> hA(size_A, 1, bc);
host_batch_vector<T> hARes(size_ARes, 1, bc);
host_strided_batch_vector<rocblas_int> hInfo(1, 1, 1, bc);
host_strided_batch_vector<rocblas_int> hInfoRes(1, 1, 1, bc);
host_strided_batch_vector<I> hInfoRes(1, 1, 1, bc);
device_batch_vector<T> dA(size_A, 1, bc);
device_strided_batch_vector<rocblas_int> dInfo(1, 1, 1, bc);
device_strided_batch_vector<I> dInfo(1, 1, 1, bc);
if(size_A)
CHECK_HIP_ERROR(dA.memcheck());
CHECK_HIP_ERROR(dInfo.memcheck());
Expand Down Expand Up @@ -435,9 +432,9 @@ void testing_potf2_potrf(Arguments& argus)
host_strided_batch_vector<T> hA(size_A, 1, stA, bc);
host_strided_batch_vector<T> hARes(size_ARes, 1, stARes, bc);
host_strided_batch_vector<rocblas_int> hInfo(1, 1, 1, bc);
host_strided_batch_vector<rocblas_int> hInfoRes(1, 1, 1, bc);
host_strided_batch_vector<I> hInfoRes(1, 1, 1, bc);
device_strided_batch_vector<T> dA(size_A, 1, stA, bc);
device_strided_batch_vector<rocblas_int> dInfo(1, 1, 1, bc);
device_strided_batch_vector<I> dInfo(1, 1, 1, bc);
if(size_A)
CHECK_HIP_ERROR(dA.memcheck());
CHECK_HIP_ERROR(dInfo.memcheck());
Expand Down Expand Up @@ -526,4 +523,5 @@ INSTANTIATE(EXTERN_TESTING_POTF2_POTRF,
FOREACH_MATRIX_DATA_LAYOUT,
FOREACH_BLOCKED_VARIANT,
FOREACH_SCALAR_TYPE,
FOREACH_INT_TYPE,
APPLY_STAMP)
2 changes: 1 addition & 1 deletion clients/common/lapack/testing_potrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@

#define TESTING_POTRS(...) template void testing_potrs<__VA_ARGS__>(Arguments&);

INSTANTIATE(TESTING_POTRS, FOREACH_MATRIX_DATA_LAYOUT, FOREACH_SCALAR_TYPE, APPLY_STAMP)
INSTANTIATE(TESTING_POTRS, FOREACH_MATRIX_DATA_LAYOUT, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
Loading

0 comments on commit 9f253f2

Please sign in to comment.