From 05343f6f2debccffcd6d53adb688f39d821cb0e6 Mon Sep 17 00:00:00 2001 From: Jonah Quist Date: Wed, 17 Jul 2024 11:26:17 -0600 Subject: [PATCH] Add 64-bit versions of lacgv, larf, and larfg (#747) * Added 64-bit larf APIs * Testing for 64-bit larf API * Updated documentation * Add larfg_64 * Add lacgv_64 * add lacgv_64 * simplify template parameters * modify test coverage, remove unused code --------- Co-authored-by: Troy Alderson <58866654+tfalders@users.noreply.github.com> --- CHANGELOG.md | 5 + clients/common/auxiliary/testing_lacgv.cpp | 2 +- clients/common/auxiliary/testing_lacgv.hpp | 40 ++++---- clients/common/auxiliary/testing_larf.cpp | 2 +- clients/common/auxiliary/testing_larf.hpp | 70 +++++++------- clients/common/auxiliary/testing_larfg.cpp | 2 +- clients/common/auxiliary/testing_larfg.hpp | 43 ++++----- clients/common/misc/rocsolver.hpp | 96 +++++++++++++++++++ clients/common/misc/rocsolver_dispatcher.hpp | 9 +- clients/gtest/auxiliary/lacgv_gtest.cpp | 58 +++++++++-- clients/gtest/auxiliary/larf_gtest.cpp | 91 +++++++++++++++--- clients/gtest/auxiliary/larfg_gtest.cpp | 90 ++++++++++++++--- docs/reference/auxiliary.rst | 28 +++++- .../include/rocsolver/rocsolver-functions.h | 78 +++++++++++++++ library/src/auxiliary/rocauxiliary_lacgv.cpp | 25 ++++- library/src/auxiliary/rocauxiliary_lacgv.hpp | 50 +++++----- library/src/auxiliary/rocauxiliary_larf.cpp | 70 ++++++++++++-- library/src/auxiliary/rocauxiliary_larf.hpp | 44 ++++----- library/src/auxiliary/rocauxiliary_larfg.cpp | 56 +++++++++-- library/src/auxiliary/rocauxiliary_larfg.hpp | 47 ++++----- 20 files changed, 685 insertions(+), 221 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7bfa5e70..525c4af70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h ## (Unreleased) rocSOLVER ### Added +- 64-bit APIs for existing functions: + - LACGV_64 + - LARF_64 + - LARFG_64 + ### Optimized ### Changed 
### Deprecated diff --git a/clients/common/auxiliary/testing_lacgv.cpp b/clients/common/auxiliary/testing_lacgv.cpp index 0ca300410..f8af6b5d4 100644 --- a/clients/common/auxiliary/testing_lacgv.cpp +++ b/clients/common/auxiliary/testing_lacgv.cpp @@ -29,4 +29,4 @@ #define TESTING_LACGV(...) template void testing_lacgv<__VA_ARGS__>(Arguments&); -INSTANTIATE(TESTING_LACGV, FOREACH_COMPLEX_TYPE, APPLY_STAMP) +INSTANTIATE(TESTING_LACGV, FOREACH_COMPLEX_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/auxiliary/testing_lacgv.hpp b/clients/common/auxiliary/testing_lacgv.hpp index dde40e168..281b957ba 100644 --- a/clients/common/auxiliary/testing_lacgv.hpp +++ b/clients/common/auxiliary/testing_lacgv.hpp @@ -35,8 +35,8 @@ #include "common/misc/rocsolver_arguments.hpp" #include "common/misc/rocsolver_test.hpp" -template -void lacgv_checkBadArgs(const rocblas_handle handle, const rocblas_int n, T dA, const rocblas_int inc) +template +void lacgv_checkBadArgs(const rocblas_handle handle, const I n, T dA, const I inc) { // handle EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(nullptr, n, dA, inc), rocblas_status_invalid_handle); @@ -49,16 +49,16 @@ void lacgv_checkBadArgs(const rocblas_handle handle, const rocblas_int n, T dA, rocblas_status_invalid_pointer); // quick return with invalid pointers - EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(handle, 0, (T) nullptr, inc), rocblas_status_success); + EXPECT_ROCBLAS_STATUS(rocsolver_lacgv(handle, (I)0, (T) nullptr, inc), rocblas_status_success); } -template +template void testing_lacgv_bad_arg() { // safe arguments rocblas_local_handle handle; - rocblas_int n = 1; - rocblas_int inc = 1; + I n = 1; + I inc = 1; // memory allocation device_strided_batch_vector dA(1, 1, 1, 1); @@ -68,12 +68,8 @@ void testing_lacgv_bad_arg() lacgv_checkBadArgs(handle, n, dA.data(), inc); } -template -void lacgv_initData(const rocblas_handle handle, - const rocblas_int n, - Td& dA, - const rocblas_int inc, - Th& hA) +template +void 
lacgv_initData(const rocblas_handle handle, const I n, Td& dA, const I inc, Th& hA) { if(CPU) { @@ -87,11 +83,11 @@ void lacgv_initData(const rocblas_handle handle, } } -template +template void lacgv_getError(const rocblas_handle handle, - const rocblas_int n, + const I n, Td& dA, - const rocblas_int inc, + const I inc, Th& hA, Th& hAr, double* max_err) @@ -117,11 +113,11 @@ void lacgv_getError(const rocblas_handle handle, } } -template +template void lacgv_getPerfData(const rocblas_handle handle, - const rocblas_int n, + const I n, Td& dA, - const rocblas_int inc, + const I inc, Th& hA, double* gpu_time_used, double* cpu_time_used, @@ -176,13 +172,13 @@ void lacgv_getPerfData(const rocblas_handle handle, *gpu_time_used /= hot_calls; } -template +template void testing_lacgv(Arguments& argus) { // get arguments rocblas_local_handle handle; - rocblas_int n = argus.get("n"); - rocblas_int inc = argus.get("incx"); + I n = argus.get("n"); + I inc = argus.get("incx"); rocblas_int hot_calls = argus.iters; @@ -294,4 +290,4 @@ void testing_lacgv(Arguments& argus) #define EXTERN_TESTING_LACGV(...) extern template void testing_lacgv<__VA_ARGS__>(Arguments&); -INSTANTIATE(EXTERN_TESTING_LACGV, FOREACH_COMPLEX_TYPE, APPLY_STAMP) +INSTANTIATE(EXTERN_TESTING_LACGV, FOREACH_COMPLEX_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/auxiliary/testing_larf.cpp b/clients/common/auxiliary/testing_larf.cpp index 8e4eda0b2..1e67a982a 100644 --- a/clients/common/auxiliary/testing_larf.cpp +++ b/clients/common/auxiliary/testing_larf.cpp @@ -29,4 +29,4 @@ #define TESTING_LARF(...) 
template void testing_larf<__VA_ARGS__>(Arguments&); -INSTANTIATE(TESTING_LARF, FOREACH_SCALAR_TYPE, APPLY_STAMP) +INSTANTIATE(TESTING_LARF, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/auxiliary/testing_larf.hpp b/clients/common/auxiliary/testing_larf.hpp index 608314441..ad76e8ef6 100644 --- a/clients/common/auxiliary/testing_larf.hpp +++ b/clients/common/auxiliary/testing_larf.hpp @@ -35,16 +35,16 @@ #include "common/misc/rocsolver_arguments.hpp" #include "common/misc/rocsolver_test.hpp" -template +template void larf_checkBadArgs(const rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, T dx, - const rocblas_int inc, + const I inc, T dt, T dA, - const rocblas_int lda) + const I lda) { // handle EXPECT_ROCBLAS_STATUS(rocsolver_larf(nullptr, side, m, n, dx, inc, dt, dA, lda), @@ -63,24 +63,24 @@ void larf_checkBadArgs(const rocblas_handle handle, rocblas_status_invalid_pointer); // quick return with invalid pointers - EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_left, 0, n, (T) nullptr, inc, + EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_left, (I)0, n, (T) nullptr, inc, (T) nullptr, (T) nullptr, lda), rocblas_status_success); - EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_right, m, 0, (T) nullptr, inc, + EXPECT_ROCBLAS_STATUS(rocsolver_larf(handle, rocblas_side_right, m, (I)0, (T) nullptr, inc, (T) nullptr, (T) nullptr, lda), rocblas_status_success); } -template +template void testing_larf_bad_arg() { // safe arguments rocblas_local_handle handle; rocblas_side side = rocblas_side_left; - rocblas_int m = 1; - rocblas_int n = 1; - rocblas_int inc = 1; - rocblas_int lda = 1; + I m = 1; + I n = 1; + I inc = 1; + I lda = 1; // memory allocation device_strided_batch_vector dA(1, 1, 1, 1); @@ -94,16 +94,16 @@ void testing_larf_bad_arg() larf_checkBadArgs(handle, side, m, n, dx.data(), inc, dt.data(), dA.data(), lda); } -template 
+template void larf_initData(const rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Td& dA, - const rocblas_int lda, + const I lda, Th& xx, Th& hx, Th& ht, @@ -111,7 +111,7 @@ void larf_initData(const rocblas_handle handle, { if(CPU) { - rocblas_int order = xx.n(); + I order = xx.n(); rocblas_init(hA, true); rocblas_init(xx, true); @@ -119,7 +119,7 @@ void larf_initData(const rocblas_handle handle, // compute householder reflector cpu_larfg(order, xx[0], xx[0] + abs(inc), abs(inc), ht[0]); xx[0][0] = 1; - for(rocblas_int i = 0; i < order; i++) + for(I i = 0; i < order; i++) { if(inc < 0) hx[0][i * abs(inc)] = xx[0][(order - 1 - i) * abs(inc)]; @@ -137,16 +137,16 @@ void larf_initData(const rocblas_handle handle, } } -template +template void larf_getError(const rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Td& dA, - const rocblas_int lda, + const I lda, Th& xx, Th& hx, Th& ht, @@ -175,16 +175,16 @@ void larf_getError(const rocblas_handle handle, *max_err = norm_error('F', m, n, lda, hA[0], hAr[0]); } -template +template void larf_getPerfData(const rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Td& dA, - const rocblas_int lda, + const I lda, Th& xx, Th& hx, Th& ht, @@ -246,16 +246,16 @@ void larf_getPerfData(const rocblas_handle handle, *gpu_time_used /= hot_calls; } -template +template void testing_larf(Arguments& argus) { // get arguments rocblas_local_handle handle; char sideC = argus.get("side"); - rocblas_int m = argus.get("m"); - rocblas_int n = argus.get("n", m); - rocblas_int inc = argus.get("incx"); - rocblas_int lda = argus.get("lda", m); + I m = argus.get("m"); + I n = 
argus.get("n", m); + I inc = argus.get("incx"); + I lda = argus.get("lda", m); rocblas_side side = char2rocblas_side(sideC); rocblas_int hot_calls = argus.iters; @@ -394,4 +394,4 @@ void testing_larf(Arguments& argus) #define EXTERN_TESTING_LARF(...) extern template void testing_larf<__VA_ARGS__>(Arguments&); -INSTANTIATE(EXTERN_TESTING_LARF, FOREACH_SCALAR_TYPE, APPLY_STAMP) +INSTANTIATE(EXTERN_TESTING_LARF, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/auxiliary/testing_larfg.cpp b/clients/common/auxiliary/testing_larfg.cpp index 641ba2d9e..662786fbc 100644 --- a/clients/common/auxiliary/testing_larfg.cpp +++ b/clients/common/auxiliary/testing_larfg.cpp @@ -29,4 +29,4 @@ #define TESTING_LARFG(...) template void testing_larfg<__VA_ARGS__>(Arguments&); -INSTANTIATE(TESTING_LARFG, FOREACH_SCALAR_TYPE, APPLY_STAMP) +INSTANTIATE(TESTING_LARFG, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/auxiliary/testing_larfg.hpp b/clients/common/auxiliary/testing_larfg.hpp index 831e1814b..0791aefcf 100644 --- a/clients/common/auxiliary/testing_larfg.hpp +++ b/clients/common/auxiliary/testing_larfg.hpp @@ -35,13 +35,8 @@ #include "common/misc/rocsolver_arguments.hpp" #include "common/misc/rocsolver_test.hpp" -template -void larfg_checkBadArgs(const rocblas_handle handle, - const rocblas_int n, - T da, - T dx, - const rocblas_int inc, - T dt) +template +void larfg_checkBadArgs(const rocblas_handle handle, const I n, T da, T dx, const I inc, T dt) { // handle EXPECT_ROCBLAS_STATUS(rocsolver_larfg(nullptr, n, da, dx, inc, dt), @@ -59,17 +54,17 @@ void larfg_checkBadArgs(const rocblas_handle handle, rocblas_status_invalid_pointer); // quick return with invalid pointers - EXPECT_ROCBLAS_STATUS(rocsolver_larfg(handle, 0, (T) nullptr, (T) nullptr, inc, (T) nullptr), + EXPECT_ROCBLAS_STATUS(rocsolver_larfg(handle, (I)0, (T) nullptr, (T) nullptr, inc, (T) nullptr), rocblas_status_success); } -template +template void 
testing_larfg_bad_arg() { // safe arguments rocblas_local_handle handle; - rocblas_int n = 2; - rocblas_int inc = 1; + I n = 2; + I inc = 1; // memory allocation device_strided_batch_vector da(1, 1, 1, 1); @@ -83,12 +78,12 @@ void testing_larfg_bad_arg() larfg_checkBadArgs(handle, n, da.data(), dx.data(), inc, dt.data()); } -template +template void larfg_initData(const rocblas_handle handle, - const rocblas_int n, + const I n, Td& da, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Th& ha, Th& hx, @@ -108,12 +103,12 @@ void larfg_initData(const rocblas_handle handle, } } -template +template void larfg_getError(const rocblas_handle handle, - const rocblas_int n, + const I n, Td& da, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Th& ha, Th& hx, @@ -139,12 +134,12 @@ void larfg_getError(const rocblas_handle handle, *max_err = norm_error('O', 1, n - 1, inc, hx[0], hxr[0]); } -template +template void larfg_getPerfData(const rocblas_handle handle, - const rocblas_int n, + const I n, Td& da, Td& dx, - const rocblas_int inc, + const I inc, Td& dt, Th& ha, Th& hx, @@ -202,13 +197,13 @@ void larfg_getPerfData(const rocblas_handle handle, *gpu_time_used /= hot_calls; } -template +template void testing_larfg(Arguments& argus) { // get arguments rocblas_local_handle handle; - rocblas_int n = argus.get("n"); - rocblas_int inc = argus.get("incx"); + I n = argus.get("n"); + I inc = argus.get("incx"); rocblas_int hot_calls = argus.iters; @@ -332,4 +327,4 @@ void testing_larfg(Arguments& argus) #define EXTERN_TESTING_LARFG(...) 
extern template void testing_larfg<__VA_ARGS__>(Arguments&); -INSTANTIATE(EXTERN_TESTING_LARFG, FOREACH_SCALAR_TYPE, APPLY_STAMP) +INSTANTIATE(EXTERN_TESTING_LARFG, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/misc/rocsolver.hpp b/clients/common/misc/rocsolver.hpp index b36170a9b..d5624b214 100644 --- a/clients/common/misc/rocsolver.hpp +++ b/clients/common/misc/rocsolver.hpp @@ -915,6 +915,18 @@ inline rocblas_status { return rocsolver_zlacgv(handle, n, x, incx); } + +inline rocblas_status + rocsolver_lacgv(rocblas_handle handle, int64_t n, rocblas_float_complex* x, int64_t incx) +{ + return rocsolver_clacgv_64(handle, n, x, incx); +} + +inline rocblas_status + rocsolver_lacgv(rocblas_handle handle, int64_t n, rocblas_double_complex* x, int64_t incx) +{ + return rocsolver_zlacgv_64(handle, n, x, incx); +} /*****************************************************/ /******************** LASWP ********************/ @@ -1007,6 +1019,38 @@ inline rocblas_status rocsolver_larfg(rocblas_handle handle, { return rocsolver_zlarfg(handle, n, alpha, x, incx, tau); } + +inline rocblas_status + rocsolver_larfg(rocblas_handle handle, int64_t n, float* alpha, float* x, int64_t incx, float* tau) +{ + return rocsolver_slarfg_64(handle, n, alpha, x, incx, tau); +} + +inline rocblas_status + rocsolver_larfg(rocblas_handle handle, int64_t n, double* alpha, double* x, int64_t incx, double* tau) +{ + return rocsolver_dlarfg_64(handle, n, alpha, x, incx, tau); +} + +inline rocblas_status rocsolver_larfg(rocblas_handle handle, + int64_t n, + rocblas_float_complex* alpha, + rocblas_float_complex* x, + int64_t incx, + rocblas_float_complex* tau) +{ + return rocsolver_clarfg_64(handle, n, alpha, x, incx, tau); +} + +inline rocblas_status rocsolver_larfg(rocblas_handle handle, + int64_t n, + rocblas_double_complex* alpha, + rocblas_double_complex* x, + int64_t incx, + rocblas_double_complex* tau) +{ + return rocsolver_zlarfg_64(handle, n, alpha, x, incx, tau); 
+} /*****************************************************/ /******************** LARF ********************/ @@ -1061,6 +1105,58 @@ inline rocblas_status rocsolver_larf(rocblas_handle handle, { return rocsolver_zlarf(handle, side, m, n, x, incx, alpha, A, lda); } + +inline rocblas_status rocsolver_larf(rocblas_handle handle, + rocblas_side side, + int64_t m, + int64_t n, + float* x, + int64_t incx, + float* alpha, + float* A, + int64_t lda) +{ + return rocsolver_slarf_64(handle, side, m, n, x, incx, alpha, A, lda); +} + +inline rocblas_status rocsolver_larf(rocblas_handle handle, + rocblas_side side, + int64_t m, + int64_t n, + double* x, + int64_t incx, + double* alpha, + double* A, + int64_t lda) +{ + return rocsolver_dlarf_64(handle, side, m, n, x, incx, alpha, A, lda); +} + +inline rocblas_status rocsolver_larf(rocblas_handle handle, + rocblas_side side, + int64_t m, + int64_t n, + rocblas_float_complex* x, + int64_t incx, + rocblas_float_complex* alpha, + rocblas_float_complex* A, + int64_t lda) +{ + return rocsolver_clarf_64(handle, side, m, n, x, incx, alpha, A, lda); +} + +inline rocblas_status rocsolver_larf(rocblas_handle handle, + rocblas_side side, + int64_t m, + int64_t n, + rocblas_double_complex* x, + int64_t incx, + rocblas_double_complex* alpha, + rocblas_double_complex* A, + int64_t lda) +{ + return rocsolver_zlarf_64(handle, side, m, n, x, incx, alpha, A, lda); +} /*****************************************************/ /******************** LARFT ********************/ diff --git a/clients/common/misc/rocsolver_dispatcher.hpp b/clients/common/misc/rocsolver_dispatcher.hpp index 7bd09f66a..48a417058 100644 --- a/clients/common/misc/rocsolver_dispatcher.hpp +++ b/clients/common/misc/rocsolver_dispatcher.hpp @@ -132,8 +132,10 @@ class rocsolver_dispatcher // Map for functions that support all precisions static const func_map map = { {"laswp", testing_laswp}, - {"larfg", testing_larfg}, - {"larf", testing_larf}, + {"larfg", testing_larfg}, + 
{"larfg_64", testing_larfg}, + {"larf", testing_larf}, + {"larf_64", testing_larf}, {"larft", testing_larft}, {"larfb", testing_larfb}, {"latrd", testing_latrd}, @@ -409,7 +411,8 @@ class rocsolver_dispatcher { // Map for functions that support only single-complex and double-complex precisions static const func_map map_complex = { - {"lacgv", testing_lacgv}, + {"lacgv", testing_lacgv}, + {"lacgv_64", testing_lacgv}, // ungxx {"ung2r", testing_orgxr_ungxr}, {"ungqr", testing_orgxr_ungxr}, diff --git a/clients/gtest/auxiliary/lacgv_gtest.cpp b/clients/gtest/auxiliary/lacgv_gtest.cpp index f5d3c790d..f00730559 100644 --- a/clients/gtest/auxiliary/lacgv_gtest.cpp +++ b/clients/gtest/auxiliary/lacgv_gtest.cpp @@ -33,7 +33,8 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef vector lacgv_tuple; +template +using lacgv_tuple = vector; // each range is a {n,inc} @@ -54,21 +55,39 @@ const vector> range = { {30, 3}, {30, -3}}; +const vector> range_64 = { + // quick return + {0, 1}, + // invalid + {-1, 1}, + {1, 0}, + // normal (valid) samples + {10, 1}, + {10, -1}, + {20, 2}, + {30, 3}, + {30, -3}}; + // for daily_lapack tests const vector> large_range = {{192, 10}, {192, -10}, {250, 20}, {500, 30}, {1500, 40}, {1500, -40}}; -Arguments lacgv_setup_arguments(lacgv_tuple tup) +const vector> large_range_64 + = {{192, 10}, {192, -10}, {250, 20}, {500, 30}, {1500, 40}, {1500, -40}}; + +template +Arguments lacgv_setup_arguments(lacgv_tuple tup) { Arguments arg; - arg.set("n", tup[0]); - arg.set("incx", tup[1]); + arg.set("n", tup[0]); + arg.set("incx", tup[1]); return arg; } -class LACGV : public ::TestWithParam +template +class LACGV_BASE : public ::TestWithParam> { protected: void TearDown() override @@ -79,15 +98,22 @@ class LACGV : public ::TestWithParam template void run_tests() { - Arguments arg = lacgv_setup_arguments(GetParam()); + Arguments arg = lacgv_setup_arguments(this->GetParam()); - if(arg.peek("n") == 0) - testing_lacgv_bad_arg(); + 
if(arg.peek("n") == 0) + testing_lacgv_bad_arg(); - testing_lacgv(arg); + testing_lacgv(arg); } }; +class LACGV : public LACGV_BASE +{ +}; +class LACGV_64 : public LACGV_BASE +{ +}; + // non-batch tests TEST_P(LACGV, __float_complex) @@ -100,6 +126,20 @@ TEST_P(LACGV, __double_complex) run_tests(); } +TEST_P(LACGV_64, __float_complex) +{ + run_tests(); +} + +TEST_P(LACGV_64, __double_complex) +{ + run_tests(); +} + INSTANTIATE_TEST_SUITE_P(daily_lapack, LACGV, ValuesIn(large_range)); INSTANTIATE_TEST_SUITE_P(checkin_lapack, LACGV, ValuesIn(range)); + +INSTANTIATE_TEST_SUITE_P(daily_lapack, LACGV_64, ValuesIn(large_range_64)); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, LACGV_64, ValuesIn(range_64)); diff --git a/clients/gtest/auxiliary/larf_gtest.cpp b/clients/gtest/auxiliary/larf_gtest.cpp index 1fa290817..d54cf1266 100644 --- a/clients/gtest/auxiliary/larf_gtest.cpp +++ b/clients/gtest/auxiliary/larf_gtest.cpp @@ -33,7 +33,8 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, vector> larf_tuple; +template +using larf_tuple = std::tuple, vector>; // each size_range vector is a {M,N,lda} @@ -69,22 +70,49 @@ const vector> matrix_size_range = { {20, 15, 20}, {35, 35, 50}}; +const vector> matrix_size_range_64 = { + // quick return + {0, 10, 1}, + {10, 0, 10}, + // invalid + {-1, 10, 1}, + {10, -1, 10}, + {10, 10, 5}, + // normal (valid) samples + {12, 20, 12}, + {20, 15, 20}, + {35, 35, 50}}; +const vector> incx_range_64 = { + // invalid + {0, 0}, + // normal (valid) samples + {-10, 0}, + {-5, 1}, + {-1, 0}, + {1, 1}, + {5, 0}, + {10, 1}}; + // for daily_lapack tests const vector> large_matrix_size_range = {{192, 192, 192}, {640, 300, 700}, {1024, 2000, 1024}, {2547, 2547, 2550}}; -Arguments larf_setup_arguments(larf_tuple tup) +const vector> large_matrix_size_range_64 + = {{192, 192, 192}, {640, 300, 700}, {1024, 2000, 1024}, {2547, 2547, 2550}}; + +template +Arguments larf_setup_arguments(larf_tuple tup) { - vector 
matrix_size = std::get<0>(tup); - vector inc = std::get<1>(tup); + vector matrix_size = std::get<0>(tup); + vector inc = std::get<1>(tup); Arguments arg; - arg.set("m", matrix_size[0]); - arg.set("n", matrix_size[1]); - arg.set("lda", matrix_size[2]); + arg.set("m", matrix_size[0]); + arg.set("n", matrix_size[1]); + arg.set("lda", matrix_size[2]); - arg.set("incx", inc[0]); + arg.set("incx", inc[0]); arg.set("side", inc[1] == 1 ? 'R' : 'L'); arg.timing = 0; @@ -92,7 +120,8 @@ Arguments larf_setup_arguments(larf_tuple tup) return arg; } -class LARF : public ::TestWithParam +template +class LARF_BASE : public ::TestWithParam> { protected: void TearDown() override @@ -103,15 +132,23 @@ class LARF : public ::TestWithParam template void run_tests() { - Arguments arg = larf_setup_arguments(GetParam()); + Arguments arg = larf_setup_arguments(this->GetParam()); - if(arg.peek("m") == 0 && arg.peek("incx") == 0) - testing_larf_bad_arg(); + if(arg.peek("m") == 0 && arg.peek("incx") == 0) + testing_larf_bad_arg(); - testing_larf(arg); + testing_larf(arg); } }; +class LARF : public LARF_BASE +{ +}; + +class LARF_64 : public LARF_BASE +{ +}; + // non-batch tests TEST_P(LARF, __float) @@ -134,6 +171,26 @@ TEST_P(LARF, __double_complex) run_tests(); } +TEST_P(LARF_64, __float) +{ + run_tests(); +} + +TEST_P(LARF_64, __double) +{ + run_tests(); +} + +TEST_P(LARF_64, __float_complex) +{ + run_tests(); +} + +TEST_P(LARF_64, __double_complex) +{ + run_tests(); +} + INSTANTIATE_TEST_SUITE_P(daily_lapack, LARF, Combine(ValuesIn(large_matrix_size_range), ValuesIn(incx_range))); @@ -141,3 +198,11 @@ INSTANTIATE_TEST_SUITE_P(daily_lapack, INSTANTIATE_TEST_SUITE_P(checkin_lapack, LARF, Combine(ValuesIn(matrix_size_range), ValuesIn(incx_range))); + +INSTANTIATE_TEST_SUITE_P(daily_lapack, + LARF_64, + Combine(ValuesIn(large_matrix_size_range_64), ValuesIn(incx_range_64))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + LARF_64, + Combine(ValuesIn(matrix_size_range_64), 
ValuesIn(incx_range_64))); diff --git a/clients/gtest/auxiliary/larfg_gtest.cpp b/clients/gtest/auxiliary/larfg_gtest.cpp index d484e2f58..eec5cbf5f 100644 --- a/clients/gtest/auxiliary/larfg_gtest.cpp +++ b/clients/gtest/auxiliary/larfg_gtest.cpp @@ -33,7 +33,8 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple larfg_tuple; +template +using larfg_tuple = std::tuple; // case when n = 0 and incx = 0 also execute the bad arguments test // (null handle, null pointers and invalid values) @@ -49,6 +50,17 @@ const vector incx_range = { 10, }; +const vector incx_range_64 = { + // invalid + -1, + 0, + // normal (valid) samples + 1, + 5, + 8, + 10, +}; + // for checkin_lapack tests const vector n_size_range = { // quick return @@ -62,6 +74,18 @@ const vector n_size_range = { 35, }; +const vector n_size_range_64 = { + // quick return + 0, + // invalid + -1, + // normal (valid) samples + 1, + 12, + 20, + 35, +}; + // for daily_lapack tests const vector large_n_size_range = { 192, @@ -70,22 +94,31 @@ const vector large_n_size_range = { 2547, }; -Arguments larfg_setup_arguments(larfg_tuple tup) +const vector large_n_size_range_64 = { + 192, + 640, + 1024, + 2547, +}; + +template +Arguments larfg_setup_arguments(larfg_tuple tup) { - int n_size = std::get<0>(tup); - int inc = std::get<1>(tup); + I n_size = std::get<0>(tup); + I inc = std::get<1>(tup); Arguments arg; - arg.set("n", n_size); - arg.set("incx", inc); + arg.set("n", n_size); + arg.set("incx", inc); arg.timing = 0; return arg; } -class LARFG : public ::TestWithParam +template +class LARFG_BASE : public ::TestWithParam> { protected: void TearDown() override @@ -96,15 +129,22 @@ class LARFG : public ::TestWithParam template void run_tests() { - Arguments arg = larfg_setup_arguments(GetParam()); + Arguments arg = larfg_setup_arguments(this->GetParam()); - if(arg.peek("n") == 0 && arg.peek("incx") == 0) - testing_larfg_bad_arg(); + if(arg.peek("n") == 0 && arg.peek("incx") == 0) 
+ testing_larfg_bad_arg(); - testing_larfg(arg); + testing_larfg(arg); } }; +class LARFG : public LARFG_BASE +{ +}; +class LARFG_64 : public LARFG_BASE +{ +}; + // non-batch tests TEST_P(LARFG, __float) @@ -127,8 +167,36 @@ TEST_P(LARFG, __double_complex) run_tests(); } +TEST_P(LARFG_64, __float) +{ + run_tests(); +} + +TEST_P(LARFG_64, __double) +{ + run_tests(); +} + +TEST_P(LARFG_64, __float_complex) +{ + run_tests(); +} + +TEST_P(LARFG_64, __double_complex) +{ + run_tests(); +} + INSTANTIATE_TEST_SUITE_P(daily_lapack, LARFG, Combine(ValuesIn(large_n_size_range), ValuesIn(incx_range))); INSTANTIATE_TEST_SUITE_P(checkin_lapack, LARFG, Combine(ValuesIn(n_size_range), ValuesIn(incx_range))); + +INSTANTIATE_TEST_SUITE_P(daily_lapack, + LARFG_64, + Combine(ValuesIn(large_n_size_range_64), ValuesIn(incx_range_64))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + LARFG_64, + Combine(ValuesIn(n_size_range_64), ValuesIn(incx_range_64))); diff --git a/docs/reference/auxiliary.rst b/docs/reference/auxiliary.rst index 4b1497f82..18f05a86c 100644 --- a/docs/reference/auxiliary.rst +++ b/docs/reference/auxiliary.rst @@ -24,13 +24,13 @@ The auxiliary functions are divided into the following categories: * i, j, and k are used as general purpose indices. In some legacy LAPACK APIs, k could be a parameter indicating some problem/matrix dimension. - * Depending on the context, when it is necessary to index rows, columns and blocks or submatrices, - i is assigned to rows, j to columns and k to blocks. l is always used to index - matrices/problems in a batch. + * Depending on the context, when it is necessary to index rows, columns and blocks or submatrices, + i is assigned to rows, j to columns and k to blocks. l is always used to index + matrices/problems in a batch. * x[i] stands for the i-th element of vector x, while A[i,j] represents the element in the i-th row and j-th column of matrix A. Indices are 1-based, i.e. x[1] is the first element of x. 
+ * To identify a block in a matrix or a matrix in the batch, k and l are used as sub-indices * x_i :math:`=x_i`; we sometimes use both notations, :math:`x_i` when displaying mathematical equations, and x_i in the text describing the function parameters. * If X is a real vector or matrix, :math:`X^T` indicates its transpose; if X is complex, then @@ -54,6 +54,10 @@ Vector and Matrix manipulations rocsolver_lacgv() --------------------------------------- +.. doxygenfunction:: rocsolver_zlacgv_64 + :outline: +.. doxygenfunction:: rocsolver_clacgv_64 + :outline: .. doxygenfunction:: rocsolver_zlacgv :outline: .. doxygenfunction:: rocsolver_clacgv @@ -97,6 +101,14 @@ Householder reflections rocsolver_larfg() --------------------------------------- +.. doxygenfunction:: rocsolver_zlarfg_64 + :outline: +.. doxygenfunction:: rocsolver_clarfg_64 + :outline: +.. doxygenfunction:: rocsolver_dlarfg_64 + :outline: +.. doxygenfunction:: rocsolver_slarfg_64 + :outline: .. doxygenfunction:: rocsolver_zlarfg :outline: .. doxygenfunction:: rocsolver_clarfg @@ -121,6 +133,14 @@ rocsolver_larft() rocsolver_larf() --------------------------------------- +.. doxygenfunction:: rocsolver_zlarf_64 + :outline: +.. doxygenfunction:: rocsolver_clarf_64 + :outline: +.. doxygenfunction:: rocsolver_dlarf_64 + :outline: +.. doxygenfunction:: rocsolver_slarf_64 + :outline: .. doxygenfunction:: rocsolver_zlarf :outline: .. 
doxygenfunction:: rocsolver_clarf diff --git a/library/include/rocsolver/rocsolver-functions.h b/library/include/rocsolver/rocsolver-functions.h index d9f573e49..32f5a77d5 100644 --- a/library/include/rocsolver/rocsolver-functions.h +++ b/library/include/rocsolver/rocsolver-functions.h @@ -178,6 +178,16 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zlacgv(rocblas_handle handle, const rocblas_int n, rocblas_double_complex* x, const rocblas_int incx); + +ROCSOLVER_EXPORT rocblas_status rocsolver_clacgv_64(rocblas_handle handle, + const int64_t n, + rocblas_float_complex* x, + const int64_t incx); + +ROCSOLVER_EXPORT rocblas_status rocsolver_zlacgv_64(rocblas_handle handle, + const int64_t n, + rocblas_double_complex* x, + const int64_t incx); //! @} /*! @{ @@ -346,6 +356,34 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zlarfg(rocblas_handle handle, rocblas_double_complex* x, const rocblas_int incx, rocblas_double_complex* tau); + +ROCSOLVER_EXPORT rocblas_status rocsolver_slarfg_64(rocblas_handle handle, + const int64_t n, + float* alpha, + float* x, + const int64_t incx, + float* tau); + +ROCSOLVER_EXPORT rocblas_status rocsolver_dlarfg_64(rocblas_handle handle, + const int64_t n, + double* alpha, + double* x, + const int64_t incx, + double* tau); + +ROCSOLVER_EXPORT rocblas_status rocsolver_clarfg_64(rocblas_handle handle, + const int64_t n, + rocblas_float_complex* alpha, + rocblas_float_complex* x, + const int64_t incx, + rocblas_float_complex* tau); + +ROCSOLVER_EXPORT rocblas_status rocsolver_zlarfg_64(rocblas_handle handle, + const int64_t n, + rocblas_double_complex* alpha, + rocblas_double_complex* x, + const int64_t incx, + rocblas_double_complex* tau); //! @} /*! 
@{ @@ -537,6 +575,46 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zlarf(rocblas_handle handle, const rocblas_double_complex* alpha, rocblas_double_complex* A, const rocblas_int lda); + +ROCSOLVER_EXPORT rocblas_status rocsolver_slarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + float* x, + const int64_t incx, + const float* alpha, + float* A, + const int64_t lda); + +ROCSOLVER_EXPORT rocblas_status rocsolver_dlarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + double* x, + const int64_t incx, + const double* alpha, + double* A, + const int64_t lda); + +ROCSOLVER_EXPORT rocblas_status rocsolver_clarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + rocblas_float_complex* x, + const int64_t incx, + const rocblas_float_complex* alpha, + rocblas_float_complex* A, + const int64_t lda); + +ROCSOLVER_EXPORT rocblas_status rocsolver_zlarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + rocblas_double_complex* x, + const int64_t incx, + const rocblas_double_complex* alpha, + rocblas_double_complex* A, + const int64_t lda); //! @} /*! 
@{ diff --git a/library/src/auxiliary/rocauxiliary_lacgv.cpp b/library/src/auxiliary/rocauxiliary_lacgv.cpp index 38fea7e9c..7c848b704 100644 --- a/library/src/auxiliary/rocauxiliary_lacgv.cpp +++ b/library/src/auxiliary/rocauxiliary_lacgv.cpp @@ -29,9 +29,8 @@ ROCSOLVER_BEGIN_NAMESPACE -template -rocblas_status - rocsolver_lacgv_impl(rocblas_handle handle, const rocblas_int n, T* x, const rocblas_int incx) +template +rocblas_status rocsolver_lacgv_impl(rocblas_handle handle, const I n, T* x, const I incx) { ROCSOLVER_ENTER_TOP("lacgv", "-n", n, "--incx", incx); @@ -44,11 +43,11 @@ rocblas_status return st; // working with unshifted arrays - rocblas_int shiftx = 0; + rocblas_stride shiftx = 0; // normal (non-batched non-strided) execution rocblas_stride stridex = 0; - rocblas_int batch_count = 1; + I batch_count = 1; // this function does not require memory work space if(rocblas_is_device_memory_size_query(handle)) @@ -84,4 +83,20 @@ rocblas_status rocsolver_zlacgv(rocblas_handle handle, return rocsolver::rocsolver_lacgv_impl(handle, n, x, incx); } +rocblas_status rocsolver_clacgv_64(rocblas_handle handle, + const int64_t n, + rocblas_float_complex* x, + const int64_t incx) +{ + return rocsolver::rocsolver_lacgv_impl(handle, n, x, incx); +} + +rocblas_status rocsolver_zlacgv_64(rocblas_handle handle, + const int64_t n, + rocblas_double_complex* x, + const int64_t incx) +{ + return rocsolver::rocsolver_lacgv_impl(handle, n, x, incx); +} + } // extern C diff --git a/library/src/auxiliary/rocauxiliary_lacgv.hpp b/library/src/auxiliary/rocauxiliary_lacgv.hpp index a0e212f78..037d04ed7 100644 --- a/library/src/auxiliary/rocauxiliary_lacgv.hpp +++ b/library/src/auxiliary/rocauxiliary_lacgv.hpp @@ -37,38 +37,37 @@ ROCSOLVER_BEGIN_NAMESPACE -template , int> = 0> -ROCSOLVER_KERNEL void conj_in_place(const rocblas_int m, - const rocblas_int n, +template , int> = 0> +ROCSOLVER_KERNEL void conj_in_place(const I m, + const I n, U A, - const rocblas_int shifta, - const 
rocblas_int lda, + const rocblas_stride shifta, + const I lda, const rocblas_stride stridea) { // do nothing } -template , int> = 0> -ROCSOLVER_KERNEL void conj_in_place(const rocblas_int m, - const rocblas_int n, +template , int> = 0> +ROCSOLVER_KERNEL void conj_in_place(const I m, + const I n, U A, - const rocblas_int shifta, - const rocblas_int lda, + const rocblas_stride shifta, + const I lda, const rocblas_stride stridea) { - int i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; - int j = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; - int b = hipBlockIdx_z; + I i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + I j = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; + I b = hipBlockIdx_z; T* Ap = load_ptr_batch(A, b, shifta, stridea); if(i < m && j < n) - Ap[i + j * lda] = conj(Ap[i + j * lda]); + Ap[i + j * lda] = conj(Ap[i + j * (int64_t)lda]); } -template -rocblas_status - rocsolver_lacgv_argCheck(rocblas_handle handle, const rocblas_int n, const rocblas_int incx, T x) +template +rocblas_status rocsolver_lacgv_argCheck(rocblas_handle handle, const I n, const I incx, T x) { // order is important for unit tests: @@ -90,14 +89,14 @@ rocblas_status return rocblas_status_continue; } -template > +template > rocblas_status rocsolver_lacgv_template(rocblas_handle handle, - const rocblas_int n, + const I n, U x, - const rocblas_int shiftx, - const rocblas_int incx, + const rocblas_stride shiftx, + const I incx, const rocblas_stride stridex, - const rocblas_int batch_count) + const I batch_count) { ROCSOLVER_ENTER("lacgv", "n:", n, "shiftX:", shiftx, "incx:", incx, "bc:", batch_count); @@ -109,12 +108,13 @@ rocblas_status rocsolver_lacgv_template(rocblas_handle handle, rocblas_get_stream(handle, &stream); // handle negative increments - rocblas_int offset = incx < 0 ? shiftx - (n - 1) * incx : shiftx; + rocblas_stride offset = incx < 0 ? 
shiftx - (n - 1) * incx : shiftx; // conjugate x - rocblas_int blocks = (n - 1) / 64 + 1; + constexpr int LACGV_NTHREADS = 64; + I blocks = (n - 1) / LACGV_NTHREADS + 1; ROCSOLVER_LAUNCH_KERNEL(conj_in_place, dim3(1, blocks, batch_count), dim3(1, 64, 1), 0, - stream, 1, n, x, offset, incx, stridex); + stream, (I)1, n, x, offset, incx, stridex); return rocblas_status_success; } diff --git a/library/src/auxiliary/rocauxiliary_larf.cpp b/library/src/auxiliary/rocauxiliary_larf.cpp index 296bdfb68..154ab109b 100644 --- a/library/src/auxiliary/rocauxiliary_larf.cpp +++ b/library/src/auxiliary/rocauxiliary_larf.cpp @@ -29,16 +29,16 @@ ROCSOLVER_BEGIN_NAMESPACE -template +template rocblas_status rocsolver_larf_impl(rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, T* x, - const rocblas_int incx, + const I incx, const T* alpha, T* A, - const rocblas_int lda) + const I lda) { ROCSOLVER_ENTER_TOP("larf", "--side", side, "-m", m, "-n", n, "--incx", incx, "--lda", lda); @@ -51,14 +51,14 @@ rocblas_status rocsolver_larf_impl(rocblas_handle handle, return st; // working with unshifted arrays - rocblas_int shiftA = 0; - rocblas_int shiftx = 0; + rocblas_stride shiftA = 0; + rocblas_stride shiftx = 0; // normal (non-batched non-strided) execution rocblas_stride stridex = 0; rocblas_stride stridea = 0; rocblas_stride stridep = 0; - rocblas_int batch_count = 1; + I batch_count = 1; // memory workspace sizes: // size for constants in rocblas calls @@ -155,4 +155,58 @@ rocblas_status rocsolver_zlarf(rocblas_handle handle, alpha, A, lda); } +rocblas_status rocsolver_slarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + float* x, + const int64_t incx, + const float* alpha, + float* A, + const int64_t lda) +{ + return rocsolver::rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); +} + +rocblas_status rocsolver_dlarf_64(rocblas_handle handle, + const rocblas_side 
side, + const int64_t m, + const int64_t n, + double* x, + const int64_t incx, + const double* alpha, + double* A, + const int64_t lda) +{ + return rocsolver::rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); +} + +rocblas_status rocsolver_clarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + rocblas_float_complex* x, + const int64_t incx, + const rocblas_float_complex* alpha, + rocblas_float_complex* A, + const int64_t lda) +{ + return rocsolver::rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, + A, lda); +} + +rocblas_status rocsolver_zlarf_64(rocblas_handle handle, + const rocblas_side side, + const int64_t m, + const int64_t n, + rocblas_double_complex* x, + const int64_t incx, + const rocblas_double_complex* alpha, + rocblas_double_complex* A, + const int64_t lda) +{ + return rocsolver::rocsolver_larf_impl(handle, side, m, n, x, incx, + alpha, A, lda); +} + } // extern C diff --git a/library/src/auxiliary/rocauxiliary_larf.hpp b/library/src/auxiliary/rocauxiliary_larf.hpp index bbef04016..7da4881c6 100644 --- a/library/src/auxiliary/rocauxiliary_larf.hpp +++ b/library/src/auxiliary/rocauxiliary_larf.hpp @@ -37,11 +37,11 @@ ROCSOLVER_BEGIN_NAMESPACE -template +template void rocsolver_larf_getMemorySize(const rocblas_side side, - const rocblas_int m, - const rocblas_int n, - const rocblas_int batch_count, + const I m, + const I n, + const I batch_count, size_t* size_scalars, size_t* size_Abyx, size_t* size_workArr) @@ -74,13 +74,13 @@ void rocsolver_larf_getMemorySize(const rocblas_side side, *size_workArr = 0; } -template +template rocblas_status rocsolver_larf_argCheck(rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, - const rocblas_int lda, - const rocblas_int incx, + const I m, + const I n, + const I lda, + const I incx, T x, T A, U alpha) @@ -107,22 +107,22 @@ rocblas_status rocsolver_larf_argCheck(rocblas_handle handle, return 
rocblas_status_continue; } -template > +template > rocblas_status rocsolver_larf_template(rocblas_handle handle, const rocblas_side side, - const rocblas_int m, - const rocblas_int n, + const I m, + const I n, U x, - const rocblas_int shiftx, - const rocblas_int incx, + const rocblas_stride shiftx, + const I incx, const rocblas_stride stridex, const T* alpha, const rocblas_stride stridep, U A, - const rocblas_int shiftA, - const rocblas_int lda, + const rocblas_stride shiftA, + const I lda, const rocblas_stride stridea, - const rocblas_int batch_count, + const I batch_count, T* scalars, T* Abyx, T** workArr) @@ -144,7 +144,7 @@ rocblas_status rocsolver_larf_template(rocblas_handle handle, // determine side and order of H bool leftside = (side == rocblas_side_left); - rocblas_int order = m; + I order = m; rocblas_operation trans = rocblas_operation_none; if(leftside) { @@ -165,13 +165,13 @@ rocblas_status rocsolver_larf_template(rocblas_handle handle, // compute the rank-1 update (A + tau*X*W' or A + tau*W*X') if(leftside) { - rocblasCall_ger(handle, m, n, alpha, stridep, x, shiftx, incx, stridex, Abyx, 0, - 1, order, A, shiftA, lda, stridea, batch_count, workArr); + rocblasCall_ger(handle, m, n, alpha, stridep, x, shiftx, incx, stridex, Abyx, + 0, 1, order, A, shiftA, lda, stridea, batch_count, workArr); } else { - rocblasCall_ger(handle, m, n, alpha, stridep, Abyx, 0, 1, order, x, shiftx, - incx, stridex, A, shiftA, lda, stridea, batch_count, workArr); + rocblasCall_ger(handle, m, n, alpha, stridep, Abyx, 0, 1, order, x, shiftx, + incx, stridex, A, shiftA, lda, stridea, batch_count, workArr); } rocblas_set_pointer_mode(handle, old_mode); diff --git a/library/src/auxiliary/rocauxiliary_larfg.cpp b/library/src/auxiliary/rocauxiliary_larfg.cpp index a439ecee6..3489e4751 100644 --- a/library/src/auxiliary/rocauxiliary_larfg.cpp +++ b/library/src/auxiliary/rocauxiliary_larfg.cpp @@ -29,13 +29,9 @@ ROCSOLVER_BEGIN_NAMESPACE -template -rocblas_status 
rocsolver_larfg_impl(rocblas_handle handle, - const rocblas_int n, - T* alpha, - T* x, - const rocblas_int incx, - T* tau) +template +rocblas_status + rocsolver_larfg_impl(rocblas_handle handle, const I n, T* alpha, T* x, const I incx, T* tau) { // TODO: How to get alpha for bench logging ROCSOLVER_ENTER_TOP("larfg", "-n", n, "--incx", incx); @@ -49,13 +45,13 @@ rocblas_status rocsolver_larfg_impl(rocblas_handle handle, return st; // working with unshifted arrays - rocblas_int shifta = 0; - rocblas_int shiftx = 0; + rocblas_stride shifta = 0; + rocblas_stride shiftx = 0; // normal (non-batched non-strided) execution rocblas_stride stridex = 0; rocblas_stride strideP = 0; - rocblas_int batch_count = 1; + I batch_count = 1; // memory workspace sizes: // size of re-usable workspace @@ -131,4 +127,44 @@ rocblas_status rocsolver_zlarfg(rocblas_handle handle, return rocsolver::rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); } +rocblas_status rocsolver_slarfg_64(rocblas_handle handle, + const int64_t n, + float* alpha, + float* x, + const int64_t incx, + float* tau) +{ + return rocsolver::rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); +} + +rocblas_status rocsolver_dlarfg_64(rocblas_handle handle, + const int64_t n, + double* alpha, + double* x, + const int64_t incx, + double* tau) +{ + return rocsolver::rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); +} + +rocblas_status rocsolver_clarfg_64(rocblas_handle handle, + const int64_t n, + rocblas_float_complex* alpha, + rocblas_float_complex* x, + const int64_t incx, + rocblas_float_complex* tau) +{ + return rocsolver::rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); +} + +rocblas_status rocsolver_zlarfg_64(rocblas_handle handle, + const int64_t n, + rocblas_double_complex* alpha, + rocblas_double_complex* x, + const int64_t incx, + rocblas_double_complex* tau) +{ + return rocsolver::rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); +} + } // extern C diff --git 
a/library/src/auxiliary/rocauxiliary_larfg.hpp b/library/src/auxiliary/rocauxiliary_larfg.hpp index 0d0a36ba4..a849130f9 100644 --- a/library/src/auxiliary/rocauxiliary_larfg.hpp +++ b/library/src/auxiliary/rocauxiliary_larfg.hpp @@ -37,15 +37,15 @@ ROCSOLVER_BEGIN_NAMESPACE -template , int> = 0> +template , int> = 0> ROCSOLVER_KERNEL void set_taubeta(T* tau, const rocblas_stride strideP, T* norms, U alpha, - const rocblas_int shifta, + const rocblas_stride shifta, const rocblas_stride stride) { - int b = hipBlockIdx_x; + I b = hipBlockIdx_x; T* a = load_ptr_batch(alpha, b, shifta, stride); T* t = tau + b * strideP; @@ -71,16 +71,16 @@ ROCSOLVER_KERNEL void set_taubeta(T* tau, } } -template , int> = 0> +template , int> = 0> ROCSOLVER_KERNEL void set_taubeta(T* tau, const rocblas_stride strideP, T* norms, U alpha, - const rocblas_int shifta, + const rocblas_stride shifta, const rocblas_stride stride) { using S = decltype(std::real(T{})); - int b = hipBlockIdx_x; + I b = hipBlockIdx_x; S r, rr, ri, ar, ai; T* a = load_ptr_batch(alpha, b, shifta, stride); @@ -119,11 +119,8 @@ ROCSOLVER_KERNEL void set_taubeta(T* tau, } } -template -void rocsolver_larfg_getMemorySize(const rocblas_int n, - const rocblas_int batch_count, - size_t* size_work, - size_t* size_norms) +template +void rocsolver_larfg_getMemorySize(const I n, const I batch_count, size_t* size_work, size_t* size_norms) { // if quick return no workspace needed if(n == 0 || batch_count == 0) @@ -138,18 +135,14 @@ void rocsolver_larfg_getMemorySize(const rocblas_int n, // size of re-usable workspace // TODO: replace with rocBLAS call - constexpr int ROCBLAS_DOT_NB = 512; + constexpr I ROCBLAS_DOT_NB = 512; *size_work = n > 2 ? 
(n - 2) / ROCBLAS_DOT_NB + 2 : 1; *size_work *= sizeof(T) * batch_count; } -template -rocblas_status rocsolver_larfg_argCheck(rocblas_handle handle, - const rocblas_int n, - const rocblas_int incx, - T alpha, - T x, - U tau) +template +rocblas_status + rocsolver_larfg_argCheck(rocblas_handle handle, const I n, const I incx, T alpha, T x, U tau) { // order is important for unit tests: @@ -171,18 +164,18 @@ rocblas_status rocsolver_larfg_argCheck(rocblas_handle handle, return rocblas_status_continue; } -template > +template > rocblas_status rocsolver_larfg_template(rocblas_handle handle, - const rocblas_int n, + const I n, U alpha, - const rocblas_int shifta, + const rocblas_stride shifta, U x, - const rocblas_int shiftx, - const rocblas_int incx, + const rocblas_stride shiftx, + const I incx, const rocblas_stride stridex, T* tau, const rocblas_stride strideP, - const rocblas_int batch_count, + const I batch_count, T* work, T* norms) { @@ -218,8 +211,8 @@ rocblas_status rocsolver_larfg_template(rocblas_handle handle, // set value of tau and beta and scalling factor for vector x // alpha <- beta, norms <- scaling - ROCSOLVER_LAUNCH_KERNEL(set_taubeta, dim3(batch_count), dim3(1), 0, stream, tau, strideP, - norms, alpha, shifta, stridex); + ROCSOLVER_LAUNCH_KERNEL((set_taubeta), dim3(batch_count), dim3(1), 0, stream, tau, + strideP, norms, alpha, shifta, stridex); // compute vector v=x*norms rocblasCall_scal(handle, n - 1, norms, 1, x, shiftx, incx, stridex, batch_count);