Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid execution of sterf #865

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h
## (Unreleased) rocSOLVER

### Added

* Hybrid computation support for existing routines:
- STERF

### Changed
### Removed
### Optimized
Expand All @@ -22,7 +26,7 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h
* Algorithm selection mechanism for hybrid computation
* Hybrid computation support for existing routines:
- BDSQR
- GESVD
- GESVD

### Optimized

Expand Down
13 changes: 13 additions & 0 deletions clients/common/auxiliary/testing_sterf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,19 @@ void testing_sterf(Arguments& argus)

rocblas_int hot_calls = argus.iters;

// When the test arguments carry a non-zero alg_mode, switch STERF to hybrid
// execution on this handle and verify that the setting took effect.
if(argus.alg_mode)
{
// Select the hybrid mode for STERF; the setter itself must report success.
EXPECT_ROCBLAS_STATUS(
rocsolver_set_alg_mode(handle, rocsolver_function_sterf, rocsolver_alg_mode_hybrid),
rocblas_status_success);

// Read the mode back through the getter ...
rocsolver_alg_mode alg_mode;
EXPECT_ROCBLAS_STATUS(rocsolver_get_alg_mode(handle, rocsolver_function_sterf, &alg_mode),
rocblas_status_success);

// ... and check that it matches the hybrid mode that was just set.
EXPECT_EQ(alg_mode, rocsolver_alg_mode_hybrid);
}

// check non-supported values
// N/A

Expand Down
13 changes: 13 additions & 0 deletions clients/common/lapack/testing_syev_heev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,19 @@ void testing_syev_heev(Arguments& argus)
rocblas_int bc = argus.batch_count;
rocblas_int hot_calls = argus.iters;

// When the test arguments carry a non-zero alg_mode, enable hybrid execution
// on this handle and verify that the setting took effect.
// NOTE(review): the mode is set for rocsolver_function_sterf rather than a
// SYEV/HEEV-specific id, presumably because SYEV/HEEV delegates to STERF
// internally — confirm against the library implementation.
if(argus.alg_mode)
{
// Select the hybrid mode for STERF; the setter itself must report success.
EXPECT_ROCBLAS_STATUS(
rocsolver_set_alg_mode(handle, rocsolver_function_sterf, rocsolver_alg_mode_hybrid),
rocblas_status_success);

// Read the mode back through the getter ...
rocsolver_alg_mode alg_mode;
EXPECT_ROCBLAS_STATUS(rocsolver_get_alg_mode(handle, rocsolver_function_sterf, &alg_mode),
rocblas_status_success);

// ... and check that it matches the hybrid mode that was just set.
EXPECT_EQ(alg_mode, rocsolver_alg_mode_hybrid);
}

// check non-supported values
if(uplo == rocblas_fill_full || evect == rocblas_evect_tridiagonal)
{
Expand Down
26 changes: 25 additions & 1 deletion clients/gtest/auxiliary/sterf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ Arguments sterf_setup_arguments(sterf_tuple tup)
return arg;
}

class STERF : public ::TestWithParam<sterf_tuple>
template <rocblas_int MODE>
class STERF_BASE : public ::TestWithParam<sterf_tuple>
{
protected:
void TearDown() override
Expand All @@ -77,6 +78,7 @@ class STERF : public ::TestWithParam<sterf_tuple>
void run_tests()
{
Arguments arg = sterf_setup_arguments(GetParam());
arg.alg_mode = MODE;

if(arg.peek<rocblas_int>("n") == 0)
testing_sterf_bad_arg<T>();
Expand All @@ -85,6 +87,14 @@ class STERF : public ::TestWithParam<sterf_tuple>
}
};

// STERF fixture with MODE = 0. STERF_BASE::run_tests() copies MODE into
// arg.alg_mode, so these tests run with alg_mode == 0.
class STERF : public STERF_BASE<0>
{
};

// STERF fixture with MODE = 1. STERF_BASE::run_tests() copies MODE into
// arg.alg_mode, so these tests run with a non-zero alg_mode, which the
// testing code uses as the request to enable hybrid execution.
class STERF_HYBRID : public STERF_BASE<1>
{
};

// non-batch tests

TEST_P(STERF, __float)
Expand All @@ -97,6 +107,20 @@ TEST_P(STERF, __double)
run_tests<double>();
}

// Hybrid-mode counterpart of the single-precision STERF test.
TEST_P(STERF_HYBRID, __float)
{
run_tests<float>();
}

// Hybrid-mode counterpart of the double-precision STERF test.
TEST_P(STERF_HYBRID, __double)
{
run_tests<double>();
}

// Default-mode STERF suites: daily_lapack runs over large_matrix_size_range,
// checkin_lapack over matrix_size_range.
INSTANTIATE_TEST_SUITE_P(daily_lapack, STERF, ValuesIn(large_matrix_size_range));

INSTANTIATE_TEST_SUITE_P(checkin_lapack, STERF, ValuesIn(matrix_size_range));

// Hybrid-mode STERF suites, instantiated over the same two size ranges.
INSTANTIATE_TEST_SUITE_P(daily_lapack, STERF_HYBRID, ValuesIn(large_matrix_size_range));

INSTANTIATE_TEST_SUITE_P(checkin_lapack, STERF_HYBRID, ValuesIn(matrix_size_range));
90 changes: 88 additions & 2 deletions clients/gtest/lapack/syev_heev_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Arguments syev_heev_setup_arguments(syev_heev_tuple tup)
return arg;
}

template <rocblas_int MODE>
class SYEV_HEEV : public ::TestWithParam<syev_heev_tuple>
{
protected:
Expand All @@ -93,6 +94,7 @@ class SYEV_HEEV : public ::TestWithParam<syev_heev_tuple>
void run_tests()
{
Arguments arg = syev_heev_setup_arguments(GetParam());
arg.alg_mode = MODE;

if(arg.peek<rocblas_int>("n") == 0 && arg.peek<char>("evect") == 'N'
&& arg.peek<char>("uplo") == 'L')
Expand All @@ -103,11 +105,19 @@ class SYEV_HEEV : public ::TestWithParam<syev_heev_tuple>
}
};

class SYEV : public SYEV_HEEV
class SYEV : public SYEV_HEEV<0>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: is there a convention that all-uppercase names are reserved for compile-time constants or #define macros? If so, perhaps consider using mixed case for these class names? Just a thought.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that changing the capitalization of the test class will also change the capitalization of the test suite output. I like the all caps text in the test output as it makes it very easy to pick out the function name.

{
};

class HEEV : public SYEV_HEEV
// HEEV fixture with MODE = 0. SYEV_HEEV::run_tests() copies MODE into
// arg.alg_mode, so these tests run with alg_mode == 0.
class HEEV : public SYEV_HEEV<0>
{
};

// SYEV fixture with MODE = 1. SYEV_HEEV::run_tests() copies MODE into
// arg.alg_mode, so these tests run with a non-zero alg_mode, which the
// testing code uses as the request to enable hybrid execution.
class SYEV_HYBRID : public SYEV_HEEV<1>
{
};

// HEEV fixture with MODE = 1. SYEV_HEEV::run_tests() copies MODE into
// arg.alg_mode, so these tests run with a non-zero alg_mode, which the
// testing code uses as the request to enable hybrid execution.
class HEEV_HYBRID : public SYEV_HEEV<1>
{
};

Expand All @@ -133,6 +143,26 @@ TEST_P(HEEV, __double_complex)
run_tests<false, false, rocblas_double_complex>();
}

// Hybrid-mode non-batch tests: real types run under SYEV_HYBRID, complex
// types under HEEV_HYBRID.
// NOTE(review): the two bool template arguments presumably select the
// batched / strided variants (false,false here; true,true in the batched
// tests; false,true in the strided_batched tests) — confirm against the
// run_tests template header, which is not visible in this hunk.
TEST_P(SYEV_HYBRID, __float)
{
run_tests<false, false, float>();
}

TEST_P(SYEV_HYBRID, __double)
{
run_tests<false, false, double>();
}

TEST_P(HEEV_HYBRID, __float_complex)
{
run_tests<false, false, rocblas_float_complex>();
}

TEST_P(HEEV_HYBRID, __double_complex)
{
run_tests<false, false, rocblas_double_complex>();
}

// batched tests

TEST_P(SYEV, batched__float)
Expand All @@ -155,6 +185,26 @@ TEST_P(HEEV, batched__double_complex)
run_tests<true, true, rocblas_double_complex>();
}

// Hybrid-mode batched tests: real types run under SYEV_HYBRID, complex
// types under HEEV_HYBRID. Both bool template arguments are true here.
TEST_P(SYEV_HYBRID, batched__float)
{
run_tests<true, true, float>();
}

TEST_P(SYEV_HYBRID, batched__double)
{
run_tests<true, true, double>();
}

TEST_P(HEEV_HYBRID, batched__float_complex)
{
run_tests<true, true, rocblas_float_complex>();
}

TEST_P(HEEV_HYBRID, batched__double_complex)
{
run_tests<true, true, rocblas_double_complex>();
}

// strided_batched tests

TEST_P(SYEV, strided_batched__float)
Expand All @@ -177,13 +227,49 @@ TEST_P(HEEV, strided_batched__double_complex)
run_tests<false, true, rocblas_double_complex>();
}

// Hybrid-mode strided_batched tests: real types run under SYEV_HYBRID,
// complex types under HEEV_HYBRID. The bool template arguments are
// (false, true) here.
TEST_P(SYEV_HYBRID, strided_batched__float)
{
run_tests<false, true, float>();
}

TEST_P(SYEV_HYBRID, strided_batched__double)
{
run_tests<false, true, double>();
}

TEST_P(HEEV_HYBRID, strided_batched__float_complex)
{
run_tests<false, true, rocblas_float_complex>();
}

TEST_P(HEEV_HYBRID, strided_batched__double_complex)
{
run_tests<false, true, rocblas_double_complex>();
}

// daily_lapack tests normal execution with medium to large sizes
INSTANTIATE_TEST_SUITE_P(daily_lapack, SYEV, Combine(ValuesIn(large_size_range), ValuesIn(op_range)));

INSTANTIATE_TEST_SUITE_P(daily_lapack, HEEV, Combine(ValuesIn(large_size_range), ValuesIn(op_range)));

// hybrid-mode fixtures, instantiated over the same daily_lapack parameter space
INSTANTIATE_TEST_SUITE_P(daily_lapack,
SYEV_HYBRID,
Combine(ValuesIn(large_size_range), ValuesIn(op_range)));

INSTANTIATE_TEST_SUITE_P(daily_lapack,
HEEV_HYBRID,
Combine(ValuesIn(large_size_range), ValuesIn(op_range)));

// checkin_lapack tests normal execution with small sizes, invalid sizes,
// quick returns, and corner cases
INSTANTIATE_TEST_SUITE_P(checkin_lapack, SYEV, Combine(ValuesIn(size_range), ValuesIn(op_range)));

INSTANTIATE_TEST_SUITE_P(checkin_lapack, HEEV, Combine(ValuesIn(size_range), ValuesIn(op_range)));

// hybrid-mode fixtures, instantiated over the same checkin_lapack parameter space
INSTANTIATE_TEST_SUITE_P(checkin_lapack,
SYEV_HYBRID,
Combine(ValuesIn(size_range), ValuesIn(op_range)));

INSTANTIATE_TEST_SUITE_P(checkin_lapack,
HEEV_HYBRID,
Combine(ValuesIn(size_range), ValuesIn(op_range)));
1 change: 1 addition & 0 deletions library/include/rocsolver/rocsolver-extra-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ typedef enum rocsolver_function_
{
rocsolver_function_bdsqr = 401,
rocsolver_function_gesvd = 402,
rocsolver_function_sterf = 403,
} rocsolver_function;

#endif /* ROCSOLVER_EXTRA_TYPES_H */
4 changes: 4 additions & 0 deletions library/include/rocsolver/rocsolver-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -4005,6 +4005,10 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zbdsqr(rocblas_handle handle,
The matrix is not represented explicitly, but rather as the array of
diagonal elements D and the array of symmetric off-diagonal elements E.

\note
A hybrid (CPU+GPU) approach is available for STERF, primarily intended for
homogeneous architectures. Use \ref rocsolver_set_alg_mode to enable it.

@param[in]
handle rocblas_handle.
@param[in]
Expand Down
Loading