Skip to content

Commit

Permalink
rocsolver_gemm test
Browse files Browse the repository at this point in the history
  • Loading branch information
AGonzales-amd committed Jan 22, 2025
1 parent eefef23 commit e937d3f
Show file tree
Hide file tree
Showing 20 changed files with 965 additions and 137 deletions.
8 changes: 7 additions & 1 deletion clients/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ##########################################################################
# Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -58,6 +58,7 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS)
add_library(clients-common INTERFACE)
target_include_directories(clients-common INTERFACE
${CMAKE_CURRENT_SOURCE_DIR}
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../library/src/include>
)
target_link_libraries(clients-common INTERFACE
${LAPACK_LIBRARIES}
Expand Down Expand Up @@ -154,6 +155,10 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS)
common/refact/testing_csrrf_solve.cpp
)

set(rocunit_inst_files
common/unit/testing_gemm.cpp
)

set(common_source_files
common/misc/lapack_host_reference.cpp
common/misc/rocsolver_test.cpp
Expand All @@ -164,6 +169,7 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS)
${rocauxiliary_inst_files}
${roclapack_inst_files}
${rocrefact_inst_files}
${rocunit_inst_files}
)

prepend_path("${CMAKE_CURRENT_SOURCE_DIR}/" common_source_files common_source_paths)
Expand Down
14 changes: 13 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* **************************************************************************
* Copyright (C) 2016-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2016-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -621,6 +621,18 @@ try
" Indicates if a matrix should be transposed.\n"
" ")

("transA",
value<char>()->default_value('N'),
"N = no transpose, T = transpose, C = conjugate transpose.\n"
" Indicates if matrix A should be transposed.\n"
" ")

("transB",
value<char>()->default_value('N'),
"N = no transpose, T = transpose, C = conjugate transpose.\n"
" Indicates if matrix B should be transposed.\n"
" ")

("uplo",
value<char>()->default_value('U'),
"U = upper, L = lower.\n"
Expand Down
12 changes: 11 additions & 1 deletion clients/common/misc/rocsolver_dispatcher.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* **************************************************************************
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -113,6 +113,9 @@
#include "common/refact/testing_csrrf_splitlu.hpp"
#include "common/refact/testing_csrrf_sumlu.hpp"

// unit
#include "common/unit/testing_gemm.hpp"

struct str_less
{
bool operator()(const char* a, const char* b) const
Expand Down Expand Up @@ -304,6 +307,13 @@ class rocsolver_dispatcher
{"geblttrs_npvt", testing_geblttrs_npvt<false, false, T>},
{"geblttrs_npvt_batched", testing_geblttrs_npvt<true, true, T>},
{"geblttrs_npvt_strided_batched", testing_geblttrs_npvt<false, true, T>},
// unit
{"gemm", testing_gemm<false, false, T, rocblas_int>},
{"gemm_batched", testing_gemm<true, true, T, rocblas_int>},
{"gemm_strided_batched", testing_gemm<false, true, T, rocblas_int>},
{"gemm_64", testing_gemm<false, false, T, int64_t>},
{"gemm_batched_64", testing_gemm<true, true, T, int64_t>},
{"gemm_strided_batched_64", testing_gemm<false, true, T, int64_t>},
};

// Grab function from the map and execute
Expand Down
32 changes: 32 additions & 0 deletions clients/common/unit/testing_gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* **************************************************************************
* Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
* *************************************************************************/

#include "testing_gemm.hpp"

#define TESTING_GEMM(...) template void testing_gemm<__VA_ARGS__>(Arguments&);

INSTANTIATE(TESTING_GEMM, FOREACH_MATRIX_DATA_LAYOUT, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP)
Loading

0 comments on commit e937d3f

Please sign in to comment.