Skip to content

Commit

Permalink
Add 64-bit versions of lacgv, larf, and larfg (#747)
Browse files Browse the repository at this point in the history
* Added 64-bit larf APIs

* Testing for 64-bit larf API

* Updated documentation

* Add larfg_64

* Add lacgv_64

* add lacgv_64

* simplify template parameters

* modify test coverage, remove unused code

---------

Co-authored-by: Troy Alderson <[email protected]>
  • Loading branch information
qjojo and tfalders authored Jul 17, 2024
1 parent 791dfe4 commit 05343f6
Show file tree
Hide file tree
Showing 20 changed files with 685 additions and 221 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h

## (Unreleased) rocSOLVER
### Added
- 64-bit APIs for existing functions:
    - LACGV_64
    - LARF_64
    - LARFG_64

### Optimized
### Changed
### Deprecated
Expand Down
2 changes: 1 addition & 1 deletion clients/common/auxiliary/testing_lacgv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@

#define TESTING_LACGV(...) template void testing_lacgv<__VA_ARGS__>(Arguments&);

INSTANTIATE(TESTING_LACGV, FOREACH_COMPLEX_TYPE, APPLY_STAMP)
INSTANTIATE(TESTING_LACGV, FOREACH_COMPLEX_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
40 changes: 18 additions & 22 deletions clients/common/auxiliary/testing_lacgv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
#include "common/misc/rocsolver_arguments.hpp"
#include "common/misc/rocsolver_test.hpp"

template <typename T>
void lacgv_checkBadArgs(const rocblas_handle handle, const rocblas_int n, T dA, const rocblas_int inc)
template <typename T, typename I>
void lacgv_checkBadArgs(const rocblas_handle handle, const I n, T dA, const I inc)
{
// handle
EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(nullptr, n, dA, inc), rocblas_status_invalid_handle);
Expand All @@ -49,16 +49,16 @@ void lacgv_checkBadArgs(const rocblas_handle handle, const rocblas_int n, T dA,
rocblas_status_invalid_pointer);

// quick return with invalid pointers
EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(handle, 0, (T) nullptr, inc), rocblas_status_success);
EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(handle, (I)0, (T) nullptr, inc), rocblas_status_success);
}

template <typename T>
template <typename T, typename I>
void testing_lacgv_bad_arg()
{
// safe arguments
rocblas_local_handle handle;
rocblas_int n = 1;
rocblas_int inc = 1;
I n = 1;
I inc = 1;

// memory allocation
device_strided_batch_vector<T> dA(1, 1, 1, 1);
Expand All @@ -68,12 +68,8 @@ void testing_lacgv_bad_arg()
lacgv_checkBadArgs(handle, n, dA.data(), inc);
}

template <bool CPU, bool GPU, typename T, typename Td, typename Th>
void lacgv_initData(const rocblas_handle handle,
const rocblas_int n,
Td& dA,
const rocblas_int inc,
Th& hA)
template <bool CPU, bool GPU, typename T, typename I, typename Td, typename Th>
void lacgv_initData(const rocblas_handle handle, const I n, Td& dA, const I inc, Th& hA)
{
if(CPU)
{
Expand All @@ -87,11 +83,11 @@ void lacgv_initData(const rocblas_handle handle,
}
}

template <typename T, typename Td, typename Th>
template <typename T, typename Td, typename I, typename Th>
void lacgv_getError(const rocblas_handle handle,
const rocblas_int n,
const I n,
Td& dA,
const rocblas_int inc,
const I inc,
Th& hA,
Th& hAr,
double* max_err)
Expand All @@ -117,11 +113,11 @@ void lacgv_getError(const rocblas_handle handle,
}
}

template <typename T, typename Td, typename Th>
template <typename T, typename I, typename Td, typename Th>
void lacgv_getPerfData(const rocblas_handle handle,
const rocblas_int n,
const I n,
Td& dA,
const rocblas_int inc,
const I inc,
Th& hA,
double* gpu_time_used,
double* cpu_time_used,
Expand Down Expand Up @@ -176,13 +172,13 @@ void lacgv_getPerfData(const rocblas_handle handle,
*gpu_time_used /= hot_calls;
}

template <typename T>
template <typename T, typename I>
void testing_lacgv(Arguments& argus)
{
// get arguments
rocblas_local_handle handle;
rocblas_int n = argus.get<rocblas_int>("n");
rocblas_int inc = argus.get<rocblas_int>("incx");
I n = argus.get<I>("n");
I inc = argus.get<I>("incx");

rocblas_int hot_calls = argus.iters;

Expand Down Expand Up @@ -294,4 +290,4 @@ void testing_lacgv(Arguments& argus)

#define EXTERN_TESTING_LACGV(...) extern template void testing_lacgv<__VA_ARGS__>(Arguments&);

INSTANTIATE(EXTERN_TESTING_LACGV, FOREACH_COMPLEX_TYPE, APPLY_STAMP)
INSTANTIATE(EXTERN_TESTING_LACGV, FOREACH_COMPLEX_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
2 changes: 1 addition & 1 deletion clients/common/auxiliary/testing_larf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@

#define TESTING_LARF(...) template void testing_larf<__VA_ARGS__>(Arguments&);

INSTANTIATE(TESTING_LARF, FOREACH_SCALAR_TYPE, APPLY_STAMP)
INSTANTIATE(TESTING_LARF, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
70 changes: 35 additions & 35 deletions clients/common/auxiliary/testing_larf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@
#include "common/misc/rocsolver_arguments.hpp"
#include "common/misc/rocsolver_test.hpp"

template <typename T>
template <typename T, typename I>
void larf_checkBadArgs(const rocblas_handle handle,
const rocblas_side side,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
T dx,
const rocblas_int inc,
const I inc,
T dt,
T dA,
const rocblas_int lda)
const I lda)
{
// handle
EXPECT_ROCBLAS_STATUS(rocsolver_larf(nullptr, side, m, n, dx, inc, dt, dA, lda),
Expand All @@ -63,24 +63,24 @@ void larf_checkBadArgs(const rocblas_handle handle,
rocblas_status_invalid_pointer);

// quick return with invalid pointers
EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_left, 0, n, (T) nullptr, inc,
EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_left, (I)0, n, (T) nullptr, inc,
(T) nullptr, (T) nullptr, lda),
rocblas_status_success);
EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_right, m, 0, (T) nullptr, inc,
EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_right, m, (I)0, (T) nullptr, inc,
(T) nullptr, (T) nullptr, lda),
rocblas_status_success);
}

template <typename T>
template <typename T, typename I>
void testing_larf_bad_arg()
{
// safe arguments
rocblas_local_handle handle;
rocblas_side side = rocblas_side_left;
rocblas_int m = 1;
rocblas_int n = 1;
rocblas_int inc = 1;
rocblas_int lda = 1;
I m = 1;
I n = 1;
I inc = 1;
I lda = 1;

// memory allocation
device_strided_batch_vector<T> dA(1, 1, 1, 1);
Expand All @@ -94,32 +94,32 @@ void testing_larf_bad_arg()
larf_checkBadArgs(handle, side, m, n, dx.data(), inc, dt.data(), dA.data(), lda);
}

template <bool CPU, bool GPU, typename T, typename Td, typename Th>
template <bool CPU, bool GPU, typename T, typename I, typename Td, typename Th>
void larf_initData(const rocblas_handle handle,
const rocblas_side side,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dx,
const rocblas_int inc,
const I inc,
Td& dt,
Td& dA,
const rocblas_int lda,
const I lda,
Th& xx,
Th& hx,
Th& ht,
Th& hA)
{
if(CPU)
{
rocblas_int order = xx.n();
I order = xx.n();

rocblas_init<T>(hA, true);
rocblas_init<T>(xx, true);

// compute householder reflector
cpu_larfg(order, xx[0], xx[0] + abs(inc), abs(inc), ht[0]);
xx[0][0] = 1;
for(rocblas_int i = 0; i < order; i++)
for(I i = 0; i < order; i++)
{
if(inc < 0)
hx[0][i * abs(inc)] = xx[0][(order - 1 - i) * abs(inc)];
Expand All @@ -137,16 +137,16 @@ void larf_initData(const rocblas_handle handle,
}
}

template <typename T, typename Td, typename Th>
template <typename T, typename I, typename Td, typename Th>
void larf_getError(const rocblas_handle handle,
const rocblas_side side,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dx,
const rocblas_int inc,
const I inc,
Td& dt,
Td& dA,
const rocblas_int lda,
const I lda,
Th& xx,
Th& hx,
Th& ht,
Expand Down Expand Up @@ -175,16 +175,16 @@ void larf_getError(const rocblas_handle handle,
*max_err = norm_error('F', m, n, lda, hA[0], hAr[0]);
}

template <typename T, typename Td, typename Th>
template <typename T, typename I, typename Td, typename Th>
void larf_getPerfData(const rocblas_handle handle,
const rocblas_side side,
const rocblas_int m,
const rocblas_int n,
const I m,
const I n,
Td& dx,
const rocblas_int inc,
const I inc,
Td& dt,
Td& dA,
const rocblas_int lda,
const I lda,
Th& xx,
Th& hx,
Th& ht,
Expand Down Expand Up @@ -246,16 +246,16 @@ void larf_getPerfData(const rocblas_handle handle,
*gpu_time_used /= hot_calls;
}

template <typename T>
template <typename T, typename I>
void testing_larf(Arguments& argus)
{
// get arguments
rocblas_local_handle handle;
char sideC = argus.get<char>("side");
rocblas_int m = argus.get<rocblas_int>("m");
rocblas_int n = argus.get<rocblas_int>("n", m);
rocblas_int inc = argus.get<rocblas_int>("incx");
rocblas_int lda = argus.get<rocblas_int>("lda", m);
I m = argus.get<I>("m");
I n = argus.get<I>("n", m);
I inc = argus.get<I>("incx");
I lda = argus.get<I>("lda", m);

rocblas_side side = char2rocblas_side(sideC);
rocblas_int hot_calls = argus.iters;
Expand Down Expand Up @@ -394,4 +394,4 @@ void testing_larf(Arguments& argus)

#define EXTERN_TESTING_LARF(...) extern template void testing_larf<__VA_ARGS__>(Arguments&);

INSTANTIATE(EXTERN_TESTING_LARF, FOREACH_SCALAR_TYPE, APPLY_STAMP)
INSTANTIATE(EXTERN_TESTING_LARF, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
2 changes: 1 addition & 1 deletion clients/common/auxiliary/testing_larfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@

#define TESTING_LARFG(...) template void testing_larfg<__VA_ARGS__>(Arguments&);

INSTANTIATE(TESTING_LARFG, FOREACH_SCALAR_TYPE, APPLY_STAMP)
INSTANTIATE(TESTING_LARFG, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
Loading

0 comments on commit 05343f6

Please sign in to comment.