Skip to content

Commit

Permalink
Add 64-bit version of geqrf, geqr2 (#767)
Browse files Browse the repository at this point in the history
* add 64-bit geqrf

* update changelog and docs

* make new 64-bit APIs optional

* address feedback, fix exports

* fix docs
  • Loading branch information
qjojo authored Jul 31, 2024
1 parent e26176b commit e169341
Show file tree
Hide file tree
Showing 29 changed files with 1,293 additions and 237 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h
- LACGV_64
- LARF_64
- LARFG_64
- GEQR2_64 (with batched and strided\_batched versions)
- GEQRF_64 (with batched and strided\_batched versions)

### Optimized
### Changed
Expand Down
1 change: 1 addition & 0 deletions clients/common/lapack/testing_geqr2_geqrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ INSTANTIATE(TESTING_GEQR2_GEQRF,
FOREACH_MATRIX_DATA_LAYOUT,
FOREACH_BLOCKED_VARIANT,
FOREACH_SCALAR_TYPE,
FOREACH_INT_TYPE,
APPLY_STAMP)
85 changes: 43 additions & 42 deletions clients/common/lapack/testing_geqr2_geqrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@
#include "common/misc/rocsolver_arguments.hpp"
#include "common/misc/rocsolver_test.hpp"

template <bool STRIDED, bool GEQRF, typename T, typename U>
template <bool STRIDED, bool GEQRF, typename T, typename I, typename U>
void geqr2_geqrf_checkBadArgs(const rocblas_handle handle,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
T dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
U dIpiv,
const rocblas_stride stP,
const rocblas_int bc)
const I bc)
{
// handle
EXPECT_ROCBLAS_STATUS(
Expand All @@ -57,7 +57,7 @@ void geqr2_geqrf_checkBadArgs(const rocblas_handle handle,
// sizes (only check batch_count if applicable)
if(STRIDED)
EXPECT_ROCBLAS_STATUS(
rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, n, dA, lda, stA, dIpiv, stP, -1),
rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, n, dA, lda, stA, dIpiv, stP, (I)-1),
rocblas_status_invalid_size);

// pointers
Expand All @@ -69,31 +69,31 @@ void geqr2_geqrf_checkBadArgs(const rocblas_handle handle,
rocblas_status_invalid_pointer);

// quick return with invalid pointers
EXPECT_ROCBLAS_STATUS(rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, 0, n, (T) nullptr, lda, stA,
(U) nullptr, stP, bc),
EXPECT_ROCBLAS_STATUS(rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, (I)0, n, (T) nullptr, lda,
stA, (U) nullptr, stP, bc),
rocblas_status_success);
EXPECT_ROCBLAS_STATUS(rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, 0, (T) nullptr, lda, stA,
(U) nullptr, stP, bc),
EXPECT_ROCBLAS_STATUS(rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, (I)0, (T) nullptr, lda,
stA, (U) nullptr, stP, bc),
rocblas_status_success);

// quick return with zero batch_count if applicable
if(STRIDED)
EXPECT_ROCBLAS_STATUS(
rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, n, dA, lda, stA, dIpiv, stP, 0),
rocsolver_geqr2_geqrf(STRIDED, GEQRF, handle, m, n, dA, lda, stA, dIpiv, stP, (I)0),
rocblas_status_success);
}

template <bool BATCHED, bool STRIDED, bool GEQRF, typename T>
template <bool BATCHED, bool STRIDED, bool GEQRF, typename T, typename I>
void testing_geqr2_geqrf_bad_arg()
{
// safe arguments
rocblas_local_handle handle;
rocblas_int m = 1;
rocblas_int n = 1;
rocblas_int lda = 1;
I m = 1;
I n = 1;
I lda = 1;
rocblas_stride stA = 1;
rocblas_stride stP = 1;
rocblas_int bc = 1;
I bc = 1;

if(BATCHED)
{
Expand Down Expand Up @@ -121,16 +121,16 @@ void testing_geqr2_geqrf_bad_arg()
}
}

template <bool CPU, bool GPU, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool CPU, bool GPU, typename T, typename I, typename Td, typename Ud, typename Th, typename Uh>
void geqr2_geqrf_initData(const rocblas_handle handle,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dIpiv,
const rocblas_stride stP,
const rocblas_int bc,
const I bc,
Th& hA,
Uh& hIpiv)
{
Expand All @@ -139,11 +139,11 @@ void geqr2_geqrf_initData(const rocblas_handle handle,
rocblas_init<T>(hA, true);

// scale A to avoid singularities
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
for(rocblas_int i = 0; i < m; i++)
for(I i = 0; i < m; i++)
{
for(rocblas_int j = 0; j < n; j++)
for(I j = 0; j < n; j++)
{
if(i == j)
hA[b][i + j * lda] += 400;
Expand All @@ -161,16 +161,16 @@ void geqr2_geqrf_initData(const rocblas_handle handle,
}
}

template <bool STRIDED, bool GEQRF, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool STRIDED, bool GEQRF, typename T, typename I, typename Td, typename Ud, typename Th, typename Uh>
void geqr2_geqrf_getError(const rocblas_handle handle,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dIpiv,
const rocblas_stride stP,
const rocblas_int bc,
const I bc,
Th& hA,
Th& hARes,
Uh& hIpiv,
Expand All @@ -188,7 +188,7 @@ void geqr2_geqrf_getError(const rocblas_handle handle,
CHECK_HIP_ERROR(hARes.transfer_from(dA));

// CPU lapack
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
GEQRF ? cpu_geqrf(m, n, hA[b], lda, hIpiv[b], hW.data(), n)
: cpu_geqr2(m, n, hA[b], lda, hIpiv[b], hW.data());
Expand All @@ -200,23 +200,23 @@ void geqr2_geqrf_getError(const rocblas_handle handle,
// using frobenius norm
double err;
*max_err = 0;
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
err = norm_error('F', m, n, lda, hA[b], hARes[b]);
*max_err = err > *max_err ? err : *max_err;
}
}

template <bool STRIDED, bool GEQRF, typename T, typename Td, typename Ud, typename Th, typename Uh>
template <bool STRIDED, bool GEQRF, typename T, typename I, typename Td, typename Ud, typename Th, typename Uh>
void geqr2_geqrf_getPerfData(const rocblas_handle handle,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dA,
const rocblas_int lda,
const I lda,
const rocblas_stride stA,
Ud& dIpiv,
const rocblas_stride stP,
const rocblas_int bc,
const I bc,
Th& hA,
Uh& hIpiv,
double* gpu_time_used,
Expand All @@ -234,7 +234,7 @@ void geqr2_geqrf_getPerfData(const rocblas_handle handle,

// cpu-lapack performance (only if not in perf mode)
*cpu_time_used = get_time_us_no_sync();
for(rocblas_int b = 0; b < bc; ++b)
for(I b = 0; b < bc; ++b)
{
GEQRF ? cpu_geqrf(m, n, hA[b], lda, hIpiv[b], hW.data(), n)
: cpu_geqr2(m, n, hA[b], lda, hIpiv[b], hW.data());
Expand Down Expand Up @@ -280,18 +280,18 @@ void geqr2_geqrf_getPerfData(const rocblas_handle handle,
*gpu_time_used /= hot_calls;
}

template <bool BATCHED, bool STRIDED, bool GEQRF, typename T>
template <bool BATCHED, bool STRIDED, bool GEQRF, typename T, typename I>
void testing_geqr2_geqrf(Arguments& argus)
{
// get arguments
rocblas_local_handle handle;
rocblas_int m = argus.get<rocblas_int>("m");
rocblas_int n = argus.get<rocblas_int>("n", m);
rocblas_int lda = argus.get<rocblas_int>("lda", m);
I m = argus.get<I>("m");
I n = argus.get<I>("n", m);
I lda = argus.get<I>("lda", m);
rocblas_stride stA = argus.get<rocblas_stride>("strideA", lda * n);
rocblas_stride stP = argus.get<rocblas_stride>("strideP", min(m, n));

rocblas_int bc = argus.batch_count;
I bc = argus.batch_count;
rocblas_int hot_calls = argus.iters;

rocblas_stride stARes = (argus.unit_check || argus.norm_check) ? stA : 0;
Expand Down Expand Up @@ -519,4 +519,5 @@ INSTANTIATE(EXTERN_TESTING_GEQR2_GEQRF,
FOREACH_MATRIX_DATA_LAYOUT,
FOREACH_BLOCKED_VARIANT,
FOREACH_SCALAR_TYPE,
FOREACH_INT_TYPE,
APPLY_STAMP)
146 changes: 146 additions & 0 deletions clients/common/misc/rocsolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6488,6 +6488,152 @@ inline rocblas_status rocsolver_geqr2_geqrf(bool STRIDED,
return GEQRF ? rocsolver_zgeqrf_ptr_batched(handle, m, n, A, lda, ipiv, bc)
: rocblas_status_not_implemented;
}

// normal and strided_batched
// 64-bit index overload for float data (normal and strided_batched).
// Selects one of four rocsolver entry points from the two runtime flags:
//   STRIDED picks the *_strided_batched_64 variant over the plain *_64 one,
//   GEQRF   picks geqrf over geqr2.
// In the non-strided case stA, stP and bc are not forwarded to the callee.
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    float* A,
    int64_t lda,
    rocblas_stride stA,
    float* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    // Non-batched path first: only the matrix arguments are passed on.
    if(!STRIDED)
    {
        if(GEQRF)
            return rocsolver_sgeqrf_64(handle, m, n, A, lda, ipiv);

        return rocsolver_sgeqr2_64(handle, m, n, A, lda, ipiv);
    }

    // Strided-batched path: strides and batch count are passed through unchanged.
    if(GEQRF)
        return rocsolver_sgeqrf_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);

    return rocsolver_sgeqr2_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);
}

// 64-bit index overload for double data (normal and strided_batched).
// STRIDED chooses between the *_strided_batched_64 and plain *_64 entry
// points; GEQRF chooses between geqrf and geqr2.
// In the non-strided case stA, stP and bc are not forwarded to the callee.
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    double* A,
    int64_t lda,
    rocblas_stride stA,
    double* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    // Non-batched path first: only the matrix arguments are passed on.
    if(!STRIDED)
    {
        if(GEQRF)
            return rocsolver_dgeqrf_64(handle, m, n, A, lda, ipiv);

        return rocsolver_dgeqr2_64(handle, m, n, A, lda, ipiv);
    }

    // Strided-batched path: strides and batch count are passed through unchanged.
    if(GEQRF)
        return rocsolver_dgeqrf_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);

    return rocsolver_dgeqr2_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);
}

// 64-bit index overload for rocblas_float_complex data (normal and
// strided_batched). STRIDED chooses between the *_strided_batched_64 and
// plain *_64 entry points; GEQRF chooses between geqrf and geqr2.
// In the non-strided case stA, stP and bc are not forwarded to the callee.
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    rocblas_float_complex* A,
    int64_t lda,
    rocblas_stride stA,
    rocblas_float_complex* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    // Non-batched path first: only the matrix arguments are passed on.
    if(!STRIDED)
    {
        if(GEQRF)
            return rocsolver_cgeqrf_64(handle, m, n, A, lda, ipiv);

        return rocsolver_cgeqr2_64(handle, m, n, A, lda, ipiv);
    }

    // Strided-batched path: strides and batch count are passed through unchanged.
    if(GEQRF)
        return rocsolver_cgeqrf_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);

    return rocsolver_cgeqr2_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);
}

// 64-bit index overload for rocblas_double_complex data (normal and
// strided_batched). STRIDED chooses between the *_strided_batched_64 and
// plain *_64 entry points; GEQRF chooses between geqrf and geqr2.
// In the non-strided case stA, stP and bc are not forwarded to the callee.
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    rocblas_double_complex* A,
    int64_t lda,
    rocblas_stride stA,
    rocblas_double_complex* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    // Non-batched path first: only the matrix arguments are passed on.
    if(!STRIDED)
    {
        if(GEQRF)
            return rocsolver_zgeqrf_64(handle, m, n, A, lda, ipiv);

        return rocsolver_zgeqr2_64(handle, m, n, A, lda, ipiv);
    }

    // Strided-batched path: strides and batch count are passed through unchanged.
    if(GEQRF)
        return rocsolver_zgeqrf_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);

    return rocsolver_zgeqr2_strided_batched_64(handle, m, n, A, lda, stA, ipiv, stP, bc);
}

// batched
// 64-bit index overload for an array of float matrix pointers (batched).
// GEQRF selects rocsolver_sgeqrf_batched_64; otherwise
// rocsolver_sgeqr2_batched_64 is called with the same arguments.
// STRIDED and stA are accepted but not used by this overload
// (presumably so every overload shares one call signature - confirm with callers).
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    float* const A[],
    int64_t lda,
    rocblas_stride stA,
    float* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    if(GEQRF)
        return rocsolver_sgeqrf_batched_64(handle, m, n, A, lda, ipiv, stP, bc);

    return rocsolver_sgeqr2_batched_64(handle, m, n, A, lda, ipiv, stP, bc);
}

// 64-bit index overload for an array of double matrix pointers (batched).
// GEQRF selects rocsolver_dgeqrf_batched_64; otherwise
// rocsolver_dgeqr2_batched_64 is called with the same arguments.
// STRIDED and stA are accepted but not used by this overload
// (presumably so every overload shares one call signature - confirm with callers).
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    double* const A[],
    int64_t lda,
    rocblas_stride stA,
    double* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    if(GEQRF)
        return rocsolver_dgeqrf_batched_64(handle, m, n, A, lda, ipiv, stP, bc);

    return rocsolver_dgeqr2_batched_64(handle, m, n, A, lda, ipiv, stP, bc);
}

// 64-bit index overload for an array of rocblas_float_complex matrix
// pointers (batched). GEQRF selects rocsolver_cgeqrf_batched_64; otherwise
// rocsolver_cgeqr2_batched_64 is called with the same arguments.
// STRIDED and stA are accepted but not used by this overload
// (presumably so every overload shares one call signature - confirm with callers).
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    rocblas_float_complex* const A[],
    int64_t lda,
    rocblas_stride stA,
    rocblas_float_complex* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    if(GEQRF)
        return rocsolver_cgeqrf_batched_64(handle, m, n, A, lda, ipiv, stP, bc);

    return rocsolver_cgeqr2_batched_64(handle, m, n, A, lda, ipiv, stP, bc);
}

// 64-bit index overload for an array of rocblas_double_complex matrix
// pointers (batched). GEQRF selects rocsolver_zgeqrf_batched_64; otherwise
// rocsolver_zgeqr2_batched_64 is called with the same arguments.
// STRIDED and stA are accepted but not used by this overload
// (presumably so every overload shares one call signature - confirm with callers).
inline rocblas_status rocsolver_geqr2_geqrf(
    bool STRIDED,
    bool GEQRF,
    rocblas_handle handle,
    int64_t m,
    int64_t n,
    rocblas_double_complex* const A[],
    int64_t lda,
    rocblas_stride stA,
    rocblas_double_complex* ipiv,
    rocblas_stride stP,
    int64_t bc)
{
    if(GEQRF)
        return rocsolver_zgeqrf_batched_64(handle, m, n, A, lda, ipiv, stP, bc);

    return rocsolver_zgeqr2_batched_64(handle, m, n, A, lda, ipiv, stP, bc);
}
/********************************************************/

/******************** GERQ2_GERQF ********************/
Expand Down
Loading

0 comments on commit e169341

Please sign in to comment.