From a4baff03d72eea071740dfc9c0c7e4cacf54ac49 Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Wed, 8 Jan 2025 23:39:03 +0000 Subject: [PATCH 1/7] add mfma internal gemm kernel and device functions --- library/src/include/lib_device_helpers.hpp | 17 ++ .../roclapack_gemm_device_functions.hpp | 258 +++++++++++++++++ .../roclapack_gemm_specialized_kernels.hpp | 263 ++++++++++++------ 3 files changed, 446 insertions(+), 92 deletions(-) create mode 100644 library/src/specialized/roclapack_gemm_device_functions.hpp diff --git a/library/src/include/lib_device_helpers.hpp b/library/src/include/lib_device_helpers.hpp index dc011f03a..5a18ae26c 100644 --- a/library/src/include/lib_device_helpers.hpp +++ b/library/src/include/lib_device_helpers.hpp @@ -119,6 +119,23 @@ __device__ void scale_tridiag(const rocblas_int start, const rocblas_int end, T* } } +template , int> = 0> +__device__ T shfl(T val, I src) +{ + return __shfl(val, src); +} + +template , int> = 0> +__device__ T shfl(T val, I src) +{ + using S = decltype(std::real(T{})); + + auto r = __shfl(val.real(), src); + auto i = __shfl(val.imag(), src); + + return rocblas_complex_num(r, i); +} + // ********************************************************** // GPU kernels that are used by many rocsolver functions // ********************************************************** diff --git a/library/src/specialized/roclapack_gemm_device_functions.hpp b/library/src/specialized/roclapack_gemm_device_functions.hpp new file mode 100644 index 000000000..e8d567557 --- /dev/null +++ b/library/src/specialized/roclapack_gemm_device_functions.hpp @@ -0,0 +1,258 @@ +/* ************************************************************************** + * Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * *************************************************************************/ + +#pragma once + +#include "rocblas.hpp" +#include "rocsolver/rocsolver.h" + +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define ROCSOLVER_MFMA_ENABLED 1 +#else +#define ROCSOLVER_MFMA_ENABLED 0 +#endif // defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + +ROCSOLVER_BEGIN_NAMESPACE + +#if ROCSOLVER_MFMA_ENABLED + +template +struct mfma_16x16x4_base +{ + using RegT = T; + using AccT = __attribute__( (__vector_size__(4 * sizeof(T)) )) T; +}; + +template +struct mfma_16x16x4; + +// float specialization +template <> +struct mfma_16x16x4: public mfma_16x16x4_base +{ + __device__ inline auto operator()(const RegT& a, const RegT& b, const AccT& c) const + { + return __builtin_amdgcn_mfma_f32_16x16x4f32(a, b, c, 0, 0, 0); + } +}; + +// double specialization +template <> +struct mfma_16x16x4: public mfma_16x16x4_base +{ + __device__ inline auto operator()(const RegT& a, const RegT& b, const AccT& c) const + { + return __builtin_amdgcn_mfma_f64_16x16x4f64(a, b, c, 0, 0, 0); + } +}; + +// complex specialization +template +struct mfma_16x16x4 +{ + using RegT = T; + using AccT = std::array; + + using S = decltype(std::real(T{})); + using RegS = typename mfma_16x16x4_base::RegT; + using AccS = typename mfma_16x16x4_base::AccT; + + __device__ inline auto operator()(const RegT& a, const RegT& b, const AccT& c) const + { + RegS ar = a.real(); + RegS ai = a.imag(); + RegS br = b.real(); + RegS bi = b.imag(); + AccS cr = {c[0].real(), c[1].real(), c[2].real(), c[3].real()}; + AccS ci = {c[0].imag(), c[1].imag(), c[2].imag(), c[3].imag()}; + AccS zero = {0}; + + const auto mfma_S = mfma_16x16x4(); + + // real x real + auto arbr = mfma_S(ar, br, zero); + + // real x imag + auto arbi = mfma_S(ar, bi, zero); + + // imag x real + auto aibr = mfma_S(ai, br, zero); + + // imag x imag + auto aibi = mfma_S(ai, bi, zero); + + // cr += r x r - i x i + cr += arbr - aibi; + // ci += r x i + i x r + ci += arbi + aibr; + + return AccT{rocblas_complex_num(cr[0], ci[0]), + rocblas_complex_num(cr[1], ci[1]), + rocblas_complex_num(cr[2], ci[2]), + rocblas_complex_num(cr[3], ci[3])}; + } +}; + +template || std::is_same_v, int> = 0> +__device__ inline I get_c_col(I li, I lj, I gpri, I inc_C, I ldc) +{ + return lj; +} + +template || std::is_same_v, int> = 0> +__device__ inline I get_c_col(I li, I lj, I gpri, I inc_C, I ldc) +{ + return lj; +} + +template || std::is_same_v, int> = 0> +__device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) +{ + return gpri + li * 4; +} + +template || std::is_same_v, int> = 0> +__device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) +{ + return gpri * 4 + li; +} + +/** GEMM device function to compute C = alpha * A * B + beta * C. + + Where C is an m x n matrix, A is an m x p matrix, and B is an + p x n matrix. This is a wave function, every lane of the wave + must perform call this function. + - m: 0 < m <= 16 + - n: 0 < n <= 16 + - p: 0 < p +**/ +// Run with warpSize sized block +template +__device__ void gemm_16x16xp(rocblas_operation transA, + rocblas_operation transB, + I m, + I n, + I p, + T alpha, + const T *A, + I inc_A, + I lda, + const T *B, + I inc_B, + I ldb, + T beta, + T *C, + I inc_C, + I ldc) +{ + using T4 = typename mfma_16x16x4::AccT; + + const I lid = threadIdx.x % warpSize; + + const I cmajor_i_16x4 = lid % 16; + const I cmajor_j_16x4 = lid / 16; + + const I cmajor_i_4x16 = lid % 4; + const I cmajor_j_4x16 = lid / 4; + + const I rmajor_i_4x16 = cmajor_j_16x4; + const I rmajor_j_4x16 = cmajor_i_16x4; + + // addresses to transpose B from col-major to row-major + // and transpose C from row-major to col-major + const auto c2r_src = rmajor_j_4x16 * 4 + rmajor_i_4x16; + const auto r2c_src = cmajor_i_4x16 * 16 + cmajor_j_4x16; + + T4 dmn = {0}; + + for(I kb = 0; kb < p; kb += 4) + { + // read A and B in col-major + T amk = 0; + T bkn = 0; + + // load A + if(transA == rocblas_operation_none) + { + // read col major 16x4 A + if(cmajor_i_16x4 < m && (kb + cmajor_j_16x4) < p) + amk = A[(kb + cmajor_j_16x4) * lda + cmajor_i_16x4 * inc_A]; + } + else + { + // read col major 4x16 op(A) + if(cmajor_j_4x16 < m && (kb + cmajor_i_4x16) < p) + amk = A[cmajor_j_4x16 * lda + (kb + cmajor_i_4x16) * inc_A]; + + // transpose op(A) to 16x4 + amk = shfl(amk, c2r_src); + } + + // load B + if(transB == rocblas_operation_none) + { + // read col major 4x16 B + if(cmajor_j_4x16 < n && (kb + cmajor_i_4x16) < p) + bkn = B[cmajor_j_4x16 * ldb + (kb + cmajor_i_4x16) * inc_B]; + + // transpose B to row major + bkn = shfl(bkn, c2r_src); + } + else + { + // read col major 16x4 op(B) + if(cmajor_i_16x4 < n && (kb + cmajor_j_16x4) < p) + bkn = B[(kb + cmajor_j_16x4) * ldb + cmajor_i_16x4 * inc_B]; + } + + if constexpr (rocblas_is_complex) + if(transA == rocblas_operation_conjugate_transpose) + amk = conj(amk); + if(transB == rocblas_operation_conjugate_transpose) + bkn = conj(bkn); + + dmn = mfma_16x16x4()(amk, bkn, dmn); + } + +#pragma unroll + for (I i = 0; i < 4; ++i) + { + const I c_col = get_c_col(cmajor_i_4x16, cmajor_j_4x16, i, inc_C, ldc); + const I c_row = get_c_row(cmajor_i_4x16, cmajor_j_4x16, i, inc_C, ldc); + const I idx = (c_col * ldc) + (c_row * inc_C); + + // transpose C to col major + dmn[i] = shfl(dmn[i], r2c_src); + + if(c_col < n && c_row < m) + C[idx] = alpha * dmn[i] + beta * C[idx]; + } +} + +#endif // ROCSOLVER_MFMA_ENABLED + +ROCSOLVER_END_NAMESPACE \ No newline at end of file diff --git a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp index 255374de6..3305c492e 100644 --- a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp @@ -28,6 +28,7 @@ #pragma once #include "rocsolver_run_specialized_kernels.hpp" +#include "roclapack_gemm_device_functions.hpp" #include @@ -73,7 +74,7 @@ ROCSOLVER_KERNEL void gemm_kernel(const I m, T* B = load_ptr_batch(BB, bid, shiftB, strideB); T* C = load_ptr_batch(CC, bid, shiftC, strideC); - // gemm function assuming no conjugation + // gemm function T temp = 0; if(i < m && j < n) { @@ -87,68 +88,108 @@ ROCSOLVER_KERNEL void gemm_kernel(const I m, } } -// /** Optimized kernel that executes a simple gemm A = BC -// where A, B and C are sub blocks of the same matrix MM with -// leading dimension ldim and stride. A, B and C are -// located in MM by their respective shifts. - -// Call this kernel with 'batch_count' groups in z, and enough -// groups in x and y to cover all the 'm' rows and 'n' columns of C. -// Size of shared memory per group should be: -// lmemsize = k * (hipBlockDim_x + hipBlockDim_y) * sizeof(T); **/ -// template -// ROCSOLVER_KERNEL void gemm_kernel(const rocblas_int m, -// const rocblas_int n, -// const rocblas_int k, -// U MM, -// const rocblas_int shiftA, -// const rocblas_int shiftB, -// const rocblas_int shiftC, -// const rocblas_int ldim, -// const rocblas_stride stride) -// { -// // indices -// int id = hipBlockIdx_z; -// int tx = hipThreadIdx_x; -// int ty = hipThreadIdx_y; -// int bdx = hipBlockDim_x; -// int bdy = hipBlockDim_y; -// int i = hipBlockIdx_x * bdx + tx; -// int j = hipBlockIdx_y * bdy + ty; - -// // batch instance -// T* A = load_ptr_batch(MM, id, shiftA, stride); -// T* B = load_ptr_batch(MM, id, shiftB, stride); -// T* C = load_ptr_batch(MM, id, shiftC, stride); - -// // shared mem setup -// extern __shared__ double lmem[]; -// T* a = reinterpret_cast(lmem); -// T* b = a + k * bdx; -// T c; - -// // local row and column of the shared arrays -// a += tx * k; -// b += ty * k; - -// // read A and B into shared mem -// for(int kk = ty; kk < k; kk += bdy) -// a[kk] = i < m ? A[i + kk * ldim] : 0; -// for(int kk = tx; kk < k; kk += bdx) -// b[kk] = j < n ? B[kk + j * ldim] : 0; -// __syncthreads(); - -// if(i < m && j < n) -// { -// // update c -// c = C[i + j * ldim]; -// for(int kk = 0; kk < k; ++kk) -// c -= a[kk] * b[kk]; - -// // write back to global memory -// C[i + j * ldim] = c; -// } -// } +#if ROCSOLVER_MFMA_ENABLED + +/** GEMM device function to compute C = alpha * A * B + beta * C. + + Call this kernel with 'batch_count' groups in z. Each wave in x + and y computes 16 of 'm' rows and 16 of the 'n' columns of C. **/ +template +ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, + rocblas_operation transB, + const I m, + const I n, + const I p, + V alpha, + U1 AA, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 BB, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + V beta, + U3 CC, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC) +{ + const I bid_x = blockIdx.x; + const I bid_y = blockIdx.y; + + const I numWaves_x = blockDim.x / warpSize; + const I numWaves_y = blockDim.y; + + const I wid_x = threadIdx.x / warpSize; + const I wid_y = threadIdx.y; + + const I block_row = 16 * (numWaves_x * bid_x + wid_x); + const I block_col = 16 * (numWaves_y * bid_y + wid_y); + + if(block_row >= m || block_col >= n) + { + return; + } + + const I m_bar = (block_row + 16) <= m ? 16 : m % 16; + const I n_bar = (block_col + 16) <= n ? 16 : n % 16; + + I batch_id = hipBlockIdx_z; + + // batch instance + T a = load_scalar(alpha, batch_id, 0); + T b = load_scalar(beta, batch_id, 0); + T* A = load_ptr_batch(AA, batch_id, shiftA, strideA); + T* B = load_ptr_batch(BB, batch_id, shiftB, strideB); + T* C = load_ptr_batch(CC, batch_id, shiftC, strideC); + + A += block_row * (transA == rocblas_operation_none ? inca : lda); + B += block_col * (transB == rocblas_operation_none ? ldb : incb); + + // C(bid_x,bid_y) += A(bid_x,:) * B(:,bid_y) + gemm_16x16xp(transA, transB, m_bar, n_bar, p, + a, + A, + inca, + lda, + B, + incb, + ldb, + b, + C + (block_col * ldc + block_row * incc), + incc, + ldc); +} + +#else // ROCSOLVER_MFMA_ENABLED +template +ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, + rocblas_operation transB, + const I m, + const I n, + const I p, + V alpha, + U1 AA, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 BB, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + V beta, + U3 CC, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC){} +#endif // ROCSOLVER_MFMA_ENABLED /************************************************************* Launchers of specialized kernels @@ -201,41 +242,79 @@ rocblas_status rocsolver_gemm(rocblas_handle handle, rocblas_pointer_mode pmode; rocblas_get_pointer_mode(handle, &pmode); - // matrices can be transposed by swapping inc and ld - I lda1 = inca; - I lda2 = lda; - I ldb1 = incb; - I ldb2 = ldb; - if(transA != rocblas_operation_none) - { - lda1 = lda; - lda2 = inca; - } - if(transB != rocblas_operation_none) - { - ldb1 = ldb; - ldb2 = incb; - } + // get warp size + int device; + HIP_CHECK(hipGetDevice(&device)); + hipDeviceProp_t deviceProperties; + HIP_CHECK(hipGetDeviceProperties(&deviceProperties, device)); - const bool conjA = transA == rocblas_operation_conjugate_transpose; - const bool conjB = transB == rocblas_operation_conjugate_transpose; + std::string deviceArch(deviceProperties.gcnArchName); - // launch specialized kernel - I blocksx = (m - 1) / BS2 + 1; - I blocksy = (n - 1) / BS2 + 1; - dim3 grid(blocksx, blocksy, batch_count); - dim3 threads(BS2, BS2, 1); - if(pmode == rocblas_pointer_mode_device) + if((deviceArch.find("gfx90a") != std::string::npos) + || (deviceArch.find("gfx940") != std::string::npos) + || (deviceArch.find("gfx941") != std::string::npos) + || (deviceArch.find("gfx942") != std::string::npos)) { - ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, alpha, conjA, - A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, - strideB, beta, C, shiftC, incc, ldc, strideC); + const auto warpSize = deviceProperties.warpSize; + + // launch specialized kernel + const I numWarpsX = 4; + const I numWarpsY = 4; + const I blocksx = (m + (numWarpsX * 16 - 1)) / (numWarpsX * 16); + const I blocksy = (n + (numWarpsY * 16 - 1)) / (numWarpsY * 16); + dim3 grid(blocksx, blocksy, batch_count); + dim3 threads(numWarpsX * warpSize, numWarpsY, 1); + if(pmode == rocblas_pointer_mode_device) + { + ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, m, n, k, alpha, + A, shiftA, inca, lda, strideA, B, shiftB, incb, ldb, + strideB, beta, C, shiftC, incc, ldc, strideC); + } + else + { + ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, m, n, k, *alpha, + A, shiftA, inca, lda, strideA, B, shiftB, incb, ldb, + strideB, *beta, C, shiftC, incc, ldc, strideC); + } } else { - ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, *alpha, conjA, - A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, - strideB, *beta, C, shiftC, incc, ldc, strideC); + // matrices can be transposed by swapping inc and ld + I lda1 = inca; + I lda2 = lda; + I ldb1 = incb; + I ldb2 = ldb; + if(transA != rocblas_operation_none) + { + lda1 = lda; + lda2 = inca; + } + if(transB != rocblas_operation_none) + { + ldb1 = ldb; + ldb2 = incb; + } + + const bool conjA = transA == rocblas_operation_conjugate_transpose; + const bool conjB = transB == rocblas_operation_conjugate_transpose; + + // launch specialized kernel + I blocksx = (m - 1) / BS2 + 1; + I blocksy = (n - 1) / BS2 + 1; + dim3 grid(blocksx, blocksy, batch_count); + dim3 threads(BS2, BS2, 1); + if(pmode == rocblas_pointer_mode_device) + { + ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, alpha, conjA, + A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, + strideB, beta, C, shiftC, incc, ldc, strideC); + } + else + { + ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, *alpha, conjA, + A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, + strideB, *beta, C, shiftC, incc, ldc, strideC); + } } return rocblas_status_success; From 64d9825f0af4eacbc1e2fad6cae31f48b663dc65 Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Tue, 14 Jan 2025 19:41:44 +0000 Subject: [PATCH 2/7] temp change --- CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fdd906004..e061b3087 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,7 +57,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(ROCSOLVER_EMBED_FMT "Hide libfmt symbols" ON) option(OPTIMAL "Build specialized kernels for small matrix sizes" ON) option(ROCSOLVER_FIND_PACKAGE_LAPACK_CONFIG "Skip module mode search for LAPACK" ON) -option(ROCSOLVER_USE_INTERNAL_BLAS "Use internal implementation of GEMM and TRSM for debugging." OFF) +# Temp Change default for CI +option(ROCSOLVER_USE_INTERNAL_BLAS "Use internal implementation of GEMM and TRSM for debugging." ON) # Add our CMake helper files to the lookup path list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") From 262fb78ef5cd74d3564c945db218dba652872cab Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Wed, 15 Jan 2025 22:01:21 +0000 Subject: [PATCH 3/7] revert temp change --- CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e061b3087..fdd906004 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,8 +57,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(ROCSOLVER_EMBED_FMT "Hide libfmt symbols" ON) option(OPTIMAL "Build specialized kernels for small matrix sizes" ON) option(ROCSOLVER_FIND_PACKAGE_LAPACK_CONFIG "Skip module mode search for LAPACK" ON) -# Temp Change default for CI -option(ROCSOLVER_USE_INTERNAL_BLAS "Use internal implementation of GEMM and TRSM for debugging." ON) +option(ROCSOLVER_USE_INTERNAL_BLAS "Use internal implementation of GEMM and TRSM for debugging." OFF) # Add our CMake helper files to the lookup path list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") From 24e73642ec0a5f87ac3dd319c0e5dab2d450bd35 Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Thu, 16 Jan 2025 20:16:53 +0000 Subject: [PATCH 4/7] documentation and formatting --- .../roclapack_gemm_device_functions.hpp | 95 ++++++++----- .../roclapack_gemm_specialized_kernels.hpp | 126 ++++++++---------- 2 files changed, 118 insertions(+), 103 deletions(-) diff --git a/library/src/specialized/roclapack_gemm_device_functions.hpp b/library/src/specialized/roclapack_gemm_device_functions.hpp index e8d567557..bc860d42e 100644 --- a/library/src/specialized/roclapack_gemm_device_functions.hpp +++ b/library/src/specialized/roclapack_gemm_device_functions.hpp @@ -44,7 +44,7 @@ template struct mfma_16x16x4_base { using RegT = T; - using AccT = __attribute__( (__vector_size__(4 * sizeof(T)) )) T; + using AccT = __attribute__((__vector_size__(4 * sizeof(T)))) T; }; template @@ -52,7 +52,7 @@ struct mfma_16x16x4; // float specialization template <> -struct mfma_16x16x4: public mfma_16x16x4_base +struct mfma_16x16x4 : public mfma_16x16x4_base { __device__ inline auto operator()(const RegT& a, const RegT& b, const AccT& c) const { @@ -62,7 +62,7 @@ struct mfma_16x16x4: public mfma_16x16x4_base // double specialization template <> -struct mfma_16x16x4: public mfma_16x16x4_base +struct mfma_16x16x4 : public mfma_16x16x4_base { __device__ inline auto operator()(const RegT& a, const RegT& b, const AccT& c) const { @@ -110,32 +110,38 @@ struct mfma_16x16x4 // ci += r x i + i x r ci += arbi + aibr; - return AccT{rocblas_complex_num(cr[0], ci[0]), - rocblas_complex_num(cr[1], ci[1]), - rocblas_complex_num(cr[2], ci[2]), - rocblas_complex_num(cr[3], ci[3])}; + return AccT{rocblas_complex_num(cr[0], ci[0]), rocblas_complex_num(cr[1], ci[1]), + rocblas_complex_num(cr[2], ci[2]), rocblas_complex_num(cr[3], ci[3])}; } }; -template || std::is_same_v, int> = 0> +template || std::is_same_v, int> = 0> __device__ inline I get_c_col(I li, I lj, I gpri, I inc_C, I ldc) { return lj; } -template || std::is_same_v, int> = 0> +template || std::is_same_v, int> = 0> __device__ inline I get_c_col(I li, I lj, I gpri, I inc_C, I ldc) { return lj; } -template || std::is_same_v, int> = 0> +template || std::is_same_v, int> = 0> __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) { return gpri + li * 4; } -template || std::is_same_v, int> = 0> +template || std::is_same_v, int> = 0> __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) { return gpri * 4 + li; @@ -146,28 +152,45 @@ __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) Where C is an m x n matrix, A is an m x p matrix, and B is an p x n matrix. This is a wave function, every lane of the wave must perform call this function. - - m: 0 < m <= 16 - - n: 0 < n <= 16 - - p: 0 < p + + transA form of op(A). + transB form of op(B). + m number of rows of matrix C. + 0 < m <= 16 + n number of columns of matrix C. + 0 < n <= 16 + p number of rows of matrices op(A) and number of rows of matrix op(B). + 0 < p + alpha scalar alpha. + A pointer to matrix A. + inc_A stride from the start of one row to the next of matrix A. + lda leading dimension of A. + B pointer to matrix B. + inc_B stride from the start of one row to the next of matrix B. + ldb leading dimension of B. + C pointer to matrix C. + inc_C stride from the start of one row to the next of matrix C. + ldc leading dimension of C. + **/ // Run with warpSize sized block template __device__ void gemm_16x16xp(rocblas_operation transA, - rocblas_operation transB, - I m, - I n, - I p, - T alpha, - const T *A, - I inc_A, - I lda, - const T *B, - I inc_B, - I ldb, - T beta, - T *C, - I inc_C, - I ldc) + rocblas_operation transB, + I m, + I n, + I p, + T alpha, + const T* A, + I inc_A, + I lda, + const T* B, + I inc_B, + I ldb, + T beta, + T* C, + I inc_C, + I ldc) { using T4 = typename mfma_16x16x4::AccT; @@ -182,9 +205,9 @@ __device__ void gemm_16x16xp(rocblas_operation transA, const I rmajor_i_4x16 = cmajor_j_16x4; const I rmajor_j_4x16 = cmajor_i_16x4; - // addresses to transpose B from col-major to row-major + // addresses to transpose B from col-major to row-major // and transpose C from row-major to col-major - const auto c2r_src = rmajor_j_4x16 * 4 + rmajor_i_4x16; + const auto c2r_src = rmajor_j_4x16 * 4 + rmajor_i_4x16; const auto r2c_src = cmajor_i_4x16 * 16 + cmajor_j_4x16; T4 dmn = {0}; @@ -203,7 +226,7 @@ __device__ void gemm_16x16xp(rocblas_operation transA, amk = A[(kb + cmajor_j_16x4) * lda + cmajor_i_16x4 * inc_A]; } else - { + { // read col major 4x16 op(A) if(cmajor_j_4x16 < m && (kb + cmajor_i_4x16) < p) amk = A[cmajor_j_4x16 * lda + (kb + cmajor_i_4x16) * inc_A]; @@ -229,17 +252,17 @@ __device__ void gemm_16x16xp(rocblas_operation transA, bkn = B[(kb + cmajor_j_16x4) * ldb + cmajor_i_16x4 * inc_B]; } - if constexpr (rocblas_is_complex) + if constexpr(rocblas_is_complex) if(transA == rocblas_operation_conjugate_transpose) amk = conj(amk); - if(transB == rocblas_operation_conjugate_transpose) - bkn = conj(bkn); + if(transB == rocblas_operation_conjugate_transpose) + bkn = conj(bkn); dmn = mfma_16x16x4()(amk, bkn, dmn); } #pragma unroll - for (I i = 0; i < 4; ++i) + for(I i = 0; i < 4; ++i) { const I c_col = get_c_col(cmajor_i_4x16, cmajor_j_4x16, i, inc_C, ldc); const I c_row = get_c_row(cmajor_i_4x16, cmajor_j_4x16, i, inc_C, ldc); diff --git a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp index 3305c492e..a94174c10 100644 --- a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp @@ -27,8 +27,8 @@ #pragma once -#include "rocsolver_run_specialized_kernels.hpp" #include "roclapack_gemm_device_functions.hpp" +#include "rocsolver_run_specialized_kernels.hpp" #include @@ -96,27 +96,27 @@ ROCSOLVER_KERNEL void gemm_kernel(const I m, and y computes 16 of 'm' rows and 16 of the 'n' columns of C. **/ template ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, - rocblas_operation transB, - const I m, - const I n, - const I p, - V alpha, - U1 AA, - rocblas_stride shiftA, - I inca, - I lda, - rocblas_stride strideA, - U2 BB, - rocblas_stride shiftB, - I incb, - I ldb, - rocblas_stride strideB, - V beta, - U3 CC, - rocblas_stride shiftC, - I incc, - I ldc, - rocblas_stride strideC) + rocblas_operation transB, + const I m, + const I n, + const I p, + V alpha, + U1 AA, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 BB, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + V beta, + U3 CC, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC) { const I bid_x = blockIdx.x; const I bid_y = blockIdx.y; @@ -151,44 +151,36 @@ ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, B += block_col * (transB == rocblas_operation_none ? ldb : incb); // C(bid_x,bid_y) += A(bid_x,:) * B(:,bid_y) - gemm_16x16xp(transA, transB, m_bar, n_bar, p, - a, - A, - inca, - lda, - B, - incb, - ldb, - b, - C + (block_col * ldc + block_row * incc), - incc, - ldc); + gemm_16x16xp(transA, transB, m_bar, n_bar, p, a, A, inca, lda, B, incb, ldb, b, + C + (block_col * ldc + block_row * incc), incc, ldc); } #else // ROCSOLVER_MFMA_ENABLED template ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, - rocblas_operation transB, - const I m, - const I n, - const I p, - V alpha, - U1 AA, - rocblas_stride shiftA, - I inca, - I lda, - rocblas_stride strideA, - U2 BB, - rocblas_stride shiftB, - I incb, - I ldb, - rocblas_stride strideB, - V beta, - U3 CC, - rocblas_stride shiftC, - I incc, - I ldc, - rocblas_stride strideC){} + rocblas_operation transB, + const I m, + const I n, + const I p, + V alpha, + U1 AA, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 BB, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + V beta, + U3 CC, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC) +{ +} #endif // ROCSOLVER_MFMA_ENABLED /************************************************************* @@ -266,15 +258,15 @@ rocblas_status rocsolver_gemm(rocblas_handle handle, dim3 threads(numWarpsX * warpSize, numWarpsY, 1); if(pmode == rocblas_pointer_mode_device) { - ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, m, n, k, alpha, - A, shiftA, inca, lda, strideA, B, shiftB, incb, ldb, - strideB, beta, C, shiftC, incc, ldc, strideC); + ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, + m, n, k, alpha, A, shiftA, inca, lda, strideA, B, shiftB, incb, + ldb, strideB, beta, C, shiftC, incc, ldc, strideC); } else { - ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, m, n, k, *alpha, - A, shiftA, inca, lda, strideA, B, shiftB, incb, ldb, - strideB, *beta, C, shiftC, incc, ldc, strideC); + ROCSOLVER_LAUNCH_KERNEL((mfma_gemm_kernel), grid, threads, 0, stream, transA, transB, + m, n, k, *alpha, A, shiftA, inca, lda, strideA, B, shiftB, incb, + ldb, strideB, *beta, C, shiftC, incc, ldc, strideC); } } else @@ -305,15 +297,15 @@ rocblas_status rocsolver_gemm(rocblas_handle handle, dim3 threads(BS2, BS2, 1); if(pmode == rocblas_pointer_mode_device) { - ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, alpha, conjA, - A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, - strideB, beta, C, shiftC, incc, ldc, strideC); + ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, alpha, + conjA, A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, + ldb2, strideB, beta, C, shiftC, incc, ldc, strideC); } else { - ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, *alpha, conjA, - A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, ldb2, - strideB, *beta, C, shiftC, incc, ldc, strideC); + ROCSOLVER_LAUNCH_KERNEL((gemm_kernel), grid, threads, 0, stream, m, n, k, *alpha, + conjA, A, shiftA, lda1, lda2, strideA, conjB, B, shiftB, ldb1, + ldb2, strideB, *beta, C, shiftC, incc, ldc, strideC); } } From eefef23d8aa8f771f24fc547476a020e7c58179c Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Fri, 17 Jan 2025 16:54:29 +0000 Subject: [PATCH 5/7] fixes --- library/src/specialized/roclapack_gemm_device_functions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/src/specialized/roclapack_gemm_device_functions.hpp b/library/src/specialized/roclapack_gemm_device_functions.hpp index bc860d42e..11a09e74d 100644 --- a/library/src/specialized/roclapack_gemm_device_functions.hpp +++ b/library/src/specialized/roclapack_gemm_device_functions.hpp @@ -159,7 +159,7 @@ __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) 0 < m <= 16 n number of columns of matrix C. 0 < n <= 16 - p number of rows of matrices op(A) and number of rows of matrix op(B). + p number of columns of matrix op(A) and number of rows of matrix op(B). 0 < p alpha scalar alpha. A pointer to matrix A. From e937d3f317d9eeec7278053730c99106626b97ed Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Wed, 22 Jan 2025 20:36:10 +0000 Subject: [PATCH 6/7] rocsolver_gemm test --- clients/CMakeLists.txt | 8 +- clients/benchmarks/client.cpp | 14 +- clients/common/misc/rocsolver_dispatcher.hpp | 12 +- clients/common/unit/testing_gemm.cpp | 32 ++ clients/common/unit/testing_gemm.hpp | 424 ++++++++++++++++++ clients/gtest/CMakeLists.txt | 8 +- clients/gtest/unit/gemm_gtest.cpp | 327 ++++++++++++++ common/include/rocblas_utility.hpp | 20 +- library/src/include/lib_device_helpers.hpp | 19 +- .../roclapack_gemm_device_functions.hpp | 17 +- .../rocsolver_run_specialized_kernels.hpp | 98 ++-- .../rocauxiliary_larf_specialized_kernels.hpp | 3 +- ...rocauxiliary_larfg_specialized_kernels.hpp | 3 +- .../roclapack_gemm_specialized_kernels.hpp | 99 ++-- .../roclapack_ger_specialized_kernels.hpp | 3 +- .../roclapack_getf2_specialized_kernels.hpp | 3 +- .../roclapack_getri_specialized_kernels.hpp | 3 +- .../roclapack_potf2_specialized_kernels.hpp | 3 +- .../roclapack_trsm_specialized_kernels.hpp | 3 +- .../roclapack_trtri_specialized_kernels.hpp | 3 +- 20 files changed, 965 insertions(+), 137 deletions(-) create mode 100644 clients/common/unit/testing_gemm.cpp create mode 100644 clients/common/unit/testing_gemm.hpp create mode 100644 clients/gtest/unit/gemm_gtest.cpp rename library/src/{specialized => include}/roclapack_gemm_device_functions.hpp (97%) diff --git a/clients/CMakeLists.txt b/clients/CMakeLists.txt index 386330bf4..6af117ba1 100755 --- a/clients/CMakeLists.txt +++ b/clients/CMakeLists.txt @@ -1,5 +1,5 @@ # ########################################################################## -# Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -58,6 +58,7 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS) add_library(clients-common INTERFACE) target_include_directories(clients-common INTERFACE ${CMAKE_CURRENT_SOURCE_DIR} + $ ) target_link_libraries(clients-common INTERFACE ${LAPACK_LIBRARIES} @@ -154,6 +155,10 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS) common/refact/testing_csrrf_solve.cpp ) + set(rocunit_inst_files + common/unit/testing_gemm.cpp + ) + set(common_source_files common/misc/lapack_host_reference.cpp common/misc/rocsolver_test.cpp @@ -164,6 +169,7 @@ if(BUILD_CLIENTS_BENCHMARKS OR BUILD_CLIENTS_TESTS) ${rocauxiliary_inst_files} ${roclapack_inst_files} ${rocrefact_inst_files} + ${rocunit_inst_files} ) prepend_path("${CMAKE_CURRENT_SOURCE_DIR}/" common_source_files common_source_paths) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 75882fc2a..4f543fd13 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2016-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2016-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -621,6 +621,18 @@ try " Indicates if a matrix should be transposed.\n" " ") + ("transA", + value()->default_value('N'), + "N = no transpose, T = transpose, C = conjugate transpose.\n" + " Indicates if matrix A should be transposed.\n" + " ") + + ("transB", + value()->default_value('N'), + "N = no transpose, T = transpose, C = conjugate transpose.\n" + " Indicates if matrix B should be transposed.\n" + " ") + ("uplo", value()->default_value('U'), "U = upper, L = lower.\n" diff --git a/clients/common/misc/rocsolver_dispatcher.hpp b/clients/common/misc/rocsolver_dispatcher.hpp index 6e2947303..789a83dca 100644 --- a/clients/common/misc/rocsolver_dispatcher.hpp +++ b/clients/common/misc/rocsolver_dispatcher.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -113,6 +113,9 @@ #include "common/refact/testing_csrrf_splitlu.hpp" #include "common/refact/testing_csrrf_sumlu.hpp" +// unit +#include "common/unit/testing_gemm.hpp" + struct str_less { bool operator()(const char* a, const char* b) const @@ -304,6 +307,13 @@ class rocsolver_dispatcher {"geblttrs_npvt", testing_geblttrs_npvt}, {"geblttrs_npvt_batched", testing_geblttrs_npvt}, {"geblttrs_npvt_strided_batched", testing_geblttrs_npvt}, + // unit + {"gemm", testing_gemm}, + {"gemm_batched", testing_gemm}, + {"gemm_strided_batched", testing_gemm}, + {"gemm_64", testing_gemm}, + {"gemm_batched_64", testing_gemm}, + {"gemm_strided_batched_64", testing_gemm}, }; // Grab function from the map and execute diff --git a/clients/common/unit/testing_gemm.cpp b/clients/common/unit/testing_gemm.cpp new file mode 100644 index 000000000..230ceece7 --- /dev/null +++ b/clients/common/unit/testing_gemm.cpp @@ -0,0 +1,32 @@ +/* ************************************************************************** + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * *************************************************************************/ + +#include "testing_gemm.hpp" + +#define TESTING_GEMM(...) template void testing_gemm<__VA_ARGS__>(Arguments&); + +INSTANTIATE(TESTING_GEMM, FOREACH_MATRIX_DATA_LAYOUT, FOREACH_SCALAR_TYPE, FOREACH_INT_TYPE, APPLY_STAMP) diff --git a/clients/common/unit/testing_gemm.hpp b/clients/common/unit/testing_gemm.hpp new file mode 100644 index 000000000..51bbd5ae6 --- /dev/null +++ b/clients/common/unit/testing_gemm.hpp @@ -0,0 +1,424 @@ +/* ************************************************************************** + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * *************************************************************************/ + +#pragma once + +#include "common/misc/client_util.hpp" +#include "common/misc/clientcommon.hpp" +#include "common/misc/lapack_host_reference.hpp" +#include "common/misc/norm.hpp" +#include "common/misc/rocsolver.hpp" +#include "common/misc/rocsolver_arguments.hpp" +#include "common/misc/rocsolver_test.hpp" + +#include "rocblas_utility.hpp" +#include "rocsolver_run_specialized_kernels.hpp" + +template +void gemm_initData(const rocblas_handle handle, Td& dA, Th& hA, Td& dB, Th& hB, Td& dC, Th& hC) +{ + if(CPU) + { + rocblas_init(hA, true); + rocblas_init(hB, true); + rocblas_init(hC, true); + } + + if(GPU) + { + // now copy to the GPU + CHECK_HIP_ERROR(dA.transfer_from(hA)); + CHECK_HIP_ERROR(dB.transfer_from(hB)); + CHECK_HIP_ERROR(dC.transfer_from(hC)); + } +} + +template +void gemm_getError(const rocblas_handle handle, + const rocblas_operation transA, + const rocblas_operation transB, + const I m, + const I n, + const I k, + Ud& dalpha, + Td& dA, + const I inca, + const I lda, + const rocblas_stride stA, + Td& dB, + const I incb, + const I ldb, + const rocblas_stride stB, + Ud& dbeta, + Td& dC, + const I incc, + const I ldc, + const rocblas_stride stC, + const I bc, + Uh& halpha, + Uh& hbeta, + Th& hA, + Th& hB, + Th& hC, + Th& hCRes, + double* max_err) +{ + // input data initialization + gemm_initData(handle, dA, hA, dB, hB, dC, hC); + + // execute computations + // GPU + CHECK_ROCBLAS_ERROR( + (rocsolver::rocsolver_gemm)(handle, transA, transB, m, n, k, dalpha.data(), dA.data(), + 0, inca, lda, stA, dB.data(), 0, incb, ldb, stB, + dbeta.data(), dC.data(), 0, incc, ldc, stC, bc, nullptr)); + CHECK_HIP_ERROR(hCRes.transfer_from(dC)); + + // CPU lapack + for(I b = 0; b < bc; ++b) + { + cpu_gemm(transA, transB, m, n, k, halpha[0][0], hA[b], lda, hB[b], ldb, hbeta[0][0], hC[b], + ldc); + } + + // error is ||hC - hCRes|| / ||hC|| + // (THIS DOES NOT ACCOUNT FOR NUMERICAL REPRODUCIBILITY ISSUES. + // IT MIGHT BE REVISITED IN THE FUTURE) + // using frobenius norm + double err; + *max_err = 0; + for(I b = 0; b < bc; ++b) + { + err = norm_error('F', m, n, ldc, hC[b], hCRes[b]); + *max_err = err > *max_err ? err : *max_err; + } +} + +template +void gemm_getPerfData(const rocblas_handle handle, + const rocblas_operation transA, + const rocblas_operation transB, + const I m, + const I n, + const I k, + Ud& dalpha, + Td& dA, + const I inca, + const I lda, + const rocblas_stride stA, + Td& dB, + const I incb, + const I ldb, + const rocblas_stride stB, + Ud& dbeta, + Td& dC, + const I incc, + const I ldc, + const rocblas_stride stC, + const I bc, + Uh& halpha, + Uh& hbeta, + Th& hA, + Th& hB, + Th& hC, + double* gpu_time_used, + double* cpu_time_used, + const rocblas_int hot_calls, + const int profile, + const bool profile_kernels, + const bool perf) +{ + if(!perf) + { + gemm_initData(handle, dA, hA, dB, hB, dC, hC); + + // cpu-lapack performance (only if not in perf mode) + *cpu_time_used = get_time_us_no_sync(); + for(I b = 0; b < bc; ++b) + { + cpu_gemm(transA, transB, m, n, k, halpha[0][0], hA[b], lda, hB[b], ldb, hbeta[0][0], + hC[b], ldc); + } + *cpu_time_used = get_time_us_no_sync() - *cpu_time_used; + } + + gemm_initData(handle, dA, hA, dB, hB, dC, hC); + + // cold calls + for(int iter = 0; iter < 2; iter++) + { + gemm_initData(handle, dA, hA, dB, hB, dC, hC); + + CHECK_ROCBLAS_ERROR((rocsolver::rocsolver_gemm)(handle, transA, transB, m, n, k, + dalpha.data(), dA.data(), 0, inca, + lda, stA, dB.data(), 0, incb, ldb, + stB, dbeta.data(), dC.data(), 0, incc, + ldc, stC, bc, nullptr)); + } + + // gpu-lapack performance + hipStream_t stream; + CHECK_ROCBLAS_ERROR(rocblas_get_stream(handle, &stream)); + double start; + + if(profile > 0) + { + if(profile_kernels) + rocsolver_log_set_layer_mode(rocblas_layer_mode_log_profile + | rocblas_layer_mode_ex_log_kernel); + else + rocsolver_log_set_layer_mode(rocblas_layer_mode_log_profile); + rocsolver_log_set_max_levels(profile); + } + + for(rocblas_int iter = 0; iter < hot_calls; iter++) + { + gemm_initData(handle, dA, hA, dB, hB, dC, hC); + + start = get_time_us_sync(stream); + CHECK_ROCBLAS_ERROR((rocsolver::rocsolver_gemm)(handle, transA, transB, m, n, k, + dalpha.data(), dA.data(), 0, inca, + lda, stA, dB.data(), 0, incb, ldb, + stB, dbeta.data(), dC.data(), 0, incc, + ldc, stC, bc, nullptr)); + *gpu_time_used += get_time_us_sync(stream) - start; + } + *gpu_time_used /= hot_calls; +} + +template +void testing_gemm(Arguments& argus) +{ + // get arguments + rocblas_local_handle handle; + I m = argus.get("m"); + I n = argus.get("n"); + I k = argus.get("k"); + I inca = argus.get("inca", 1); + I incb = argus.get("incb", 1); + I incc = argus.get("incc", 1); + + char tA = argus.get("transA"); + char tB = argus.get("transB"); + rocblas_operation transA = char2rocblas_operation(tA); + rocblas_operation transB = char2rocblas_operation(tB); + I mk = transA == rocblas_operation_none ? m : k; + I km = transA == rocblas_operation_none ? k : m; + I kn = transB == rocblas_operation_none ? k : n; + I nk = transB == rocblas_operation_none ? n : k; + + I lda = argus.get("lda", mk); + I ldb = argus.get("ldb", kn); + I ldc = argus.get("ldc", m); + rocblas_stride stA = argus.get("strideA", lda * km); + rocblas_stride stB = argus.get("strideB", ldb * nk); + rocblas_stride stC = argus.get("strideC", ldc * n); + + T alpha = argus.get("alpha", 1); + T beta = argus.get("beta", 1); + I bc = argus.batch_count; + I hot_calls = argus.iters; + + rocblas_stride stCRes = (argus.unit_check || argus.norm_check) ? stC : 0; + + size_t size_A = size_t(lda) * km; + size_t size_B = size_t(ldb) * nk; + size_t size_C = size_t(ldc) * n; + size_t size_CRes = (argus.unit_check || argus.norm_check) ? size_C : 0; + + double max_error = 0, gpu_time_used = 0, cpu_time_used = 0; + + // check invalid sizes + bool invalid_size = (n < 0 || m < 0 || k < 0 || ldc < m || inca < 1 || incb < 1 || incc < 1 + || bc < 0 || lda < mk || ldb < kn); + + if(invalid_size) + { + if(argus.timing) + rocsolver_bench_inform(inform_invalid_size); + + return; + } + + // memory size query is necessary + if(argus.mem_query) + { + rocsolver_bench_inform(inform_mem_query, 0); + return; + } + + // check quick return + if(n == 0 || m == 0 || k == 0 || bc == 0) + { + if(argus.timing) + rocsolver_bench_inform(inform_quick_return); + + return; + } + + // memory allocations + host_strided_batch_vector halpha(1, 1, 1, 1); + host_strided_batch_vector hbeta(1, 1, 1, 1); + device_strided_batch_vector dalpha(1, 1, 1, 1); + device_strided_batch_vector dbeta(1, 1, 1, 1); + + halpha[0][0] = alpha; + hbeta[0][0] = beta; + + CHECK_HIP_ERROR(dalpha.transfer_from(halpha)); + CHECK_HIP_ERROR(dbeta.transfer_from(hbeta)); + + if(BATCHED) + { + // memory allocations + host_batch_vector hA(size_A, inca, bc); + host_batch_vector hB(size_B, incb, bc); + host_batch_vector hC(size_C, incc, bc); + host_batch_vector hCRes(size_CRes, incc, bc); + device_batch_vector dA(size_A, inca, bc); + device_batch_vector dB(size_B, incb, bc); + device_batch_vector dC(size_C, incc, bc); + if(size_A) + CHECK_HIP_ERROR(dA.memcheck()); + if(size_B) + CHECK_HIP_ERROR(dB.memcheck()); + if(size_C) + CHECK_HIP_ERROR(dC.memcheck()); + + // check computations + if(argus.unit_check || argus.norm_check) + { + gemm_getError(handle, transA, transB, m, n, k, dalpha, dA, inca, lda, stA, dB, + incb, ldb, stB, dbeta, dC, incc, ldc, stC, bc, halpha, hbeta, hA, + hB, hC, hCRes, &max_error); + } + + // collect performance data + if(argus.timing) + { + gemm_getPerfData(handle, transA, transB, m, n, k, dalpha, dA, inca, lda, stA, dB, + incb, ldb, stB, dbeta, dC, incc, ldc, stC, bc, halpha, hbeta, hA, + hB, hC, &gpu_time_used, &cpu_time_used, hot_calls, argus.profile, + argus.profile_kernels, argus.perf); + } + } + + else + { + // memory allocations + host_strided_batch_vector hA(size_A, inca, stA, bc); + host_strided_batch_vector hB(size_B, incb, stB, bc); + host_strided_batch_vector hC(size_C, incc, stC, bc); + host_strided_batch_vector hCRes(size_CRes, incc, stCRes, bc); + device_strided_batch_vector dA(size_A, inca, stA, bc); + device_strided_batch_vector dB(size_B, incb, stB, bc); + device_strided_batch_vector dC(size_C, incc, stC, bc); + if(size_A) + CHECK_HIP_ERROR(dA.memcheck()); + if(size_B) + CHECK_HIP_ERROR(dB.memcheck()); + if(size_C) + CHECK_HIP_ERROR(dC.memcheck()); + + // check computations + if(argus.unit_check || argus.norm_check) + { + gemm_getError(handle, transA, transB, m, n, k, dalpha, dA, inca, lda, stA, dB, + incb, ldb, stB, dbeta, dC, incc, ldc, stC, bc, halpha, hbeta, hA, + hB, hC, hCRes, &max_error); + } + + // collect performance data + if(argus.timing) + { + gemm_getPerfData(handle, transA, transB, m, n, k, dalpha, dA, inca, lda, stA, dB, + incb, ldb, stB, dbeta, dC, incc, ldc, stC, bc, halpha, hbeta, hA, + hB, hC, &gpu_time_used, &cpu_time_used, hot_calls, argus.profile, + argus.profile_kernels, argus.perf); + } + } + + // validate results for rocsolver-test + if(argus.unit_check) + ROCSOLVER_TEST_CHECK(T, max_error, m); + + // output results for rocsolver-bench + if(argus.timing) + { + if(!argus.perf) + { + rocsolver_bench_header("Arguments:"); + if(BATCHED) + { + rocsolver_bench_output("m", "n", "k", "lda", "ldb", "ldc", "transA", "transB", + "batch_count"); + rocsolver_bench_output(m, n, k, lda, ldb, ldc, tA, tB, bc); + } + else if(STRIDED) + { + rocsolver_bench_output("m", "n", "k", "lda", "ldb", "ldc", "strideA", "strideB", + "strideC", "transA", "transB", "batch_count"); + rocsolver_bench_output(m, n, k, lda, ldb, ldc, stA, stB, stC, tA, tB, bc); + } + else + { + rocsolver_bench_output("m", "n", "k", "lda", "ldb", "ldc", "transA", "transB"); + rocsolver_bench_output(m, n, k, lda, ldb, ldc, tA, tB); + } + rocsolver_bench_header("Results:"); + if(argus.norm_check) + { + rocsolver_bench_output("cpu_time_us", "gpu_time_us", "error"); + rocsolver_bench_output(cpu_time_used, gpu_time_used, max_error); + } + else + { + rocsolver_bench_output("cpu_time_us", "gpu_time_us"); + rocsolver_bench_output(cpu_time_used, gpu_time_used); + } + rocsolver_bench_endl(); + } + else + { + if(argus.norm_check) + rocsolver_bench_output(gpu_time_used, max_error); + else + rocsolver_bench_output(gpu_time_used); + } + } + + // ensure all arguments were consumed + argus.validate_consumed(); +} + +#define EXTERN_TESTING_GEMM(...) extern template void testing_gemm<__VA_ARGS__>(Arguments&); + +INSTANTIATE(EXTERN_TESTING_GEMM, + FOREACH_MATRIX_DATA_LAYOUT, + FOREACH_SCALAR_TYPE, + FOREACH_INT_TYPE, + APPLY_STAMP) diff --git a/clients/gtest/CMakeLists.txt b/clients/gtest/CMakeLists.txt index 8c8c4ca5c..2abd3b67c 100755 --- a/clients/gtest/CMakeLists.txt +++ b/clients/gtest/CMakeLists.txt @@ -1,5 +1,5 @@ # ########################################################################## -# Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -127,6 +127,11 @@ set(rocrefact_test_source refact/csrrf_solve_gtest.cpp ) +set(rocunit_test_source + # gemm + unit/gemm_gtest.cpp +) + set(others_test_source # unified memory model managed_malloc_gtest.cpp @@ -146,6 +151,7 @@ add_executable(rocsolver-test ${roclapack_test_source} ${rocauxiliary_test_source} ${rocrefact_test_source} + ${rocunit_test_source} ${others_test_source} ${rocsolver_test_source} ) diff --git a/clients/gtest/unit/gemm_gtest.cpp b/clients/gtest/unit/gemm_gtest.cpp new file mode 100644 index 000000000..452df2ea7 --- /dev/null +++ b/clients/gtest/unit/gemm_gtest.cpp @@ -0,0 +1,327 @@ +/* ************************************************************************** + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * *************************************************************************/ + +#include "common/unit/testing_gemm.hpp" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; +using namespace std; + +template +using gemm_tuple = tuple, I, vector, char, char, double, double>; + +// each m_size_range is a {m, lda, ldc} +// each k_size_range is a {k, ldb} + +// for checkin_lapack tests +const vector> m_size_range = { + // normal (valid) samples + {50, 50, 50}, + {150, 200, 180}}; + +const vector> k_size_range = { + // normal (valid) samples + {50, 50}, + {150, 200}}; + +const vector n_size_range = {16, 150}; + +const vector> m_size_range_64 = { + // normal (valid) samples + {50, 50, 50}, + {150, 200, 180}}; + +const vector> k_size_range_64 = { + // normal (valid) samples + {50, 50}, + {150, 200}}; + +const vector n_size_range_64 = {16, 150}; + +// for daily_lapack tests +const vector> large_m_size_range = { + {1000, 1024, 1024}, +}; + +const vector> large_k_size_range = { + {1000, 1024}, +}; + +const vector large_n_size_range = {1000}; + +const vector> large_m_size_range_64 = { + {1000, 1024, 1024}, +}; + +const vector> large_k_size_range_64 = { + {1000, 1024}, +}; + +const vector large_n_size_range_64 = {1000}; + +const vector all_operations = { + 'N', + 'T', + 'C', +}; + +const vector alphas = {1.5}; +const vector betas = {-2.0}; + +template +Arguments gemm_setup_arguments(gemm_tuple tup) +{ + vector m_size = std::get<0>(tup); + I n_size = std::get<1>(tup); + vector k_size = std::get<2>(tup); + char transA = std::get<3>(tup); + char transB = std::get<4>(tup); + double alpha = std::get<5>(tup); + double beta = std::get<6>(tup); + + Arguments arg; + + arg.set("m", m_size[0]); + arg.set("n", n_size); + arg.set("k", k_size[0]); + arg.set("lda", m_size[1]); + arg.set("ldb", k_size[1]); + arg.set("ldc", m_size[2]); + arg.set("transA", transA); + arg.set("transB", transB); + arg.set("alpha", alpha); + arg.set("beta", beta); + + // only testing standard use case/defaults for strides + + arg.timing = 0; + + return arg; +} + +template +class GEMM_BASE : public ::TestWithParam> +{ +protected: + void TearDown() override + { + EXPECT_EQ(hipGetLastError(), hipSuccess); + } + + template + void run_tests() + { + Arguments arg = gemm_setup_arguments(this->GetParam()); + + arg.batch_count = (BATCHED || STRIDED ? 3 : 1); + testing_gemm(arg); + } +}; + +class GEMM : public GEMM_BASE +{ +}; + +class GEMM_64 : public GEMM_BASE +{ +}; + +// non-batch tests + +TEST_P(GEMM, __float) +{ + run_tests(); +} + +TEST_P(GEMM, __double) +{ + run_tests(); +} + +TEST_P(GEMM, __float_complex) +{ + run_tests(); +} + +TEST_P(GEMM, __double_complex) +{ + run_tests(); +} + +// batched tests + +TEST_P(GEMM, batched__float) +{ + run_tests(); +} + +TEST_P(GEMM, batched__double) +{ + run_tests(); +} + +TEST_P(GEMM, batched__float_complex) +{ + run_tests(); +} + +TEST_P(GEMM, batched__double_complex) +{ + run_tests(); +} + +// strided_batched cases + +TEST_P(GEMM, strided_batched__float) +{ + run_tests(); +} + +TEST_P(GEMM, strided_batched__double) +{ + run_tests(); +} + +TEST_P(GEMM, strided_batched__float_complex) +{ + run_tests(); +} + +TEST_P(GEMM, strided_batched__double_complex) +{ + run_tests(); +} + +// 64-bit API + +// non-batch tests + +TEST_P(GEMM_64, __float) +{ + run_tests(); +} + +TEST_P(GEMM_64, __double) +{ + run_tests(); +} + +TEST_P(GEMM_64, __float_complex) +{ + run_tests(); +} + +TEST_P(GEMM_64, __double_complex) +{ + run_tests(); +} + +// batched tests + +TEST_P(GEMM_64, batched__float) +{ + run_tests(); +} + +TEST_P(GEMM_64, batched__double) +{ + run_tests(); +} + +TEST_P(GEMM_64, batched__float_complex) +{ + run_tests(); +} + +TEST_P(GEMM_64, batched__double_complex) +{ + run_tests(); +} + +// strided_batched cases + +TEST_P(GEMM_64, strided_batched__float) +{ + run_tests(); +} + +TEST_P(GEMM_64, strided_batched__double) +{ + run_tests(); +} + +TEST_P(GEMM_64, strided_batched__float_complex) +{ + run_tests(); +} + +TEST_P(GEMM_64, strided_batched__double_complex) +{ + run_tests(); +} + +INSTANTIATE_TEST_SUITE_P(daily_lapack, + GEMM, + Combine(ValuesIn(large_m_size_range), + ValuesIn(large_n_size_range), + ValuesIn(large_k_size_range), + ValuesIn(all_operations), + ValuesIn(all_operations), + ValuesIn(alphas), + ValuesIn(betas))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + GEMM, + Combine(ValuesIn(m_size_range), + ValuesIn(n_size_range), + ValuesIn(k_size_range), + ValuesIn(all_operations), + ValuesIn(all_operations), + ValuesIn(alphas), + ValuesIn(betas))); + +INSTANTIATE_TEST_SUITE_P(daily_lapack, + GEMM_64, + Combine(ValuesIn(large_m_size_range_64), + ValuesIn(large_n_size_range_64), + ValuesIn(large_k_size_range_64), + ValuesIn(all_operations), + ValuesIn(all_operations), + ValuesIn(alphas), + ValuesIn(betas))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + GEMM_64, + Combine(ValuesIn(m_size_range_64), + ValuesIn(n_size_range_64), + ValuesIn(k_size_range_64), + ValuesIn(all_operations), + ValuesIn(all_operations), + ValuesIn(alphas), + ValuesIn(betas))); diff --git a/common/include/rocblas_utility.hpp b/common/include/rocblas_utility.hpp index c5b96615c..e1c421c9e 100644 --- a/common/include/rocblas_utility.hpp +++ b/common/include/rocblas_utility.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2016-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2016-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -108,6 +108,24 @@ __device__ __host__ inline T conj(const T& z) return std::conj(z); } +// Exchange values between threads in a warp +template , int> = 0> +__device__ T shfl(T val, I src) +{ + return __shfl(val, src); +} + +template , int> = 0> +__device__ T shfl(T val, I src) +{ + using S = decltype(std::real(T{})); + + auto r = __shfl(val.real(), src); + auto i = __shfl(val.imag(), src); + + return rocblas_complex_num(r, i); +} + // Load a scalar. If the argument is a pointer, dereference it; otherwise copy // it. Allows the same kernels to be used for host and device scalars. diff --git a/library/src/include/lib_device_helpers.hpp b/library/src/include/lib_device_helpers.hpp index 5a18ae26c..b6af4775f 100644 --- a/library/src/include/lib_device_helpers.hpp +++ b/library/src/include/lib_device_helpers.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -119,23 +119,6 @@ __device__ void scale_tridiag(const rocblas_int start, const rocblas_int end, T* } } -template , int> = 0> -__device__ T shfl(T val, I src) -{ - return __shfl(val, src); -} - -template , int> = 0> -__device__ T shfl(T val, I src) -{ - using S = decltype(std::real(T{})); - - auto r = __shfl(val.real(), src); - auto i = __shfl(val.imag(), src); - - return rocblas_complex_num(r, i); -} - // ********************************************************** // GPU kernels that are used by many rocsolver functions // ********************************************************** diff --git a/library/src/specialized/roclapack_gemm_device_functions.hpp b/library/src/include/roclapack_gemm_device_functions.hpp similarity index 97% rename from library/src/specialized/roclapack_gemm_device_functions.hpp rename to library/src/include/roclapack_gemm_device_functions.hpp index 11a09e74d..3b29bd13f 100644 --- a/library/src/specialized/roclapack_gemm_device_functions.hpp +++ b/library/src/include/roclapack_gemm_device_functions.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -27,8 +27,7 @@ #pragma once -#include "rocblas.hpp" -#include "rocsolver/rocsolver.h" +#include "rocblas_utility.hpp" #if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define ROCSOLVER_MFMA_ENABLED 1 @@ -152,7 +151,7 @@ __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) Where C is an m x n matrix, A is an m x p matrix, and B is an p x n matrix. This is a wave function, every lane of the wave must perform call this function. - + transA form of op(A). transB form of op(B). m number of rows of matrix C. @@ -171,7 +170,7 @@ __device__ inline I get_c_row(I li, I lj, I gpri, I inc_C, I ldc) C pointer to matrix C. inc_C stride from the start of one row to the next of matrix C. ldc leading dimension of C. - + **/ // Run with warpSize sized block template @@ -253,10 +252,12 @@ __device__ void gemm_16x16xp(rocblas_operation transA, } if constexpr(rocblas_is_complex) + { if(transA == rocblas_operation_conjugate_transpose) amk = conj(amk); - if(transB == rocblas_operation_conjugate_transpose) - bkn = conj(bkn); + if(transB == rocblas_operation_conjugate_transpose) + bkn = conj(bkn); + } dmn = mfma_16x16x4()(amk, bkn, dmn); } @@ -278,4 +279,4 @@ __device__ void gemm_16x16xp(rocblas_operation transA, #endif // ROCSOLVER_MFMA_ENABLED -ROCSOLVER_END_NAMESPACE \ No newline at end of file +ROCSOLVER_END_NAMESPACE diff --git a/library/src/include/rocsolver_run_specialized_kernels.hpp b/library/src/include/rocsolver_run_specialized_kernels.hpp index 43ede3892..88844332d 100644 --- a/library/src/include/rocsolver_run_specialized_kernels.hpp +++ b/library/src/include/rocsolver_run_specialized_kernels.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -27,7 +27,7 @@ #pragma once -#include "rocblas.hpp" +#include "rocblas/rocblas.h" #include "rocsolver/rocsolver.h" ROCSOLVER_BEGIN_NAMESPACE @@ -188,55 +188,55 @@ rocblas_status rocsolver_trsm_upper(rocblas_handle handle, // gemm template -rocblas_status rocsolver_gemm(rocblas_handle handle, - rocblas_operation transA, - rocblas_operation transB, - I m, - I n, - I k, - const T* alpha, - U1 A, - rocblas_stride shiftA, - I lda, - rocblas_stride strideA, - U2 B, - rocblas_stride shiftB, - I ldb, - rocblas_stride strideB, - const T* beta, - U3 C, - rocblas_stride shiftC, - I ldc, - rocblas_stride strideC, - I batch_count, - T** work); +ROCSOLVER_EXPORT rocblas_status rocsolver_gemm(rocblas_handle handle, + rocblas_operation transA, + rocblas_operation transB, + I m, + I n, + I k, + const T* alpha, + U1 A, + rocblas_stride shiftA, + I lda, + rocblas_stride strideA, + U2 B, + rocblas_stride shiftB, + I ldb, + rocblas_stride strideB, + const T* beta, + U3 C, + rocblas_stride shiftC, + I ldc, + rocblas_stride strideC, + I batch_count, + T** work); template -rocblas_status rocsolver_gemm(rocblas_handle handle, - rocblas_operation transA, - rocblas_operation transB, - I m, - I n, - I k, - const T* alpha, - U1 A, - rocblas_stride shiftA, - I inca, - I lda, - rocblas_stride strideA, - U2 B, - rocblas_stride shiftB, - I incb, - I ldb, - rocblas_stride strideB, - const T* beta, - U3 C, - rocblas_stride shiftC, - I incc, - I ldc, - rocblas_stride strideC, - I batch_count, - T** work); +ROCSOLVER_EXPORT rocblas_status rocsolver_gemm(rocblas_handle handle, + rocblas_operation transA, + rocblas_operation transB, + I m, + I n, + I k, + const T* alpha, + U1 A, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 B, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + const T* beta, + U3 C, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC, + I batch_count, + T** work); // ger template diff --git a/library/src/specialized/rocauxiliary_larf_specialized_kernels.hpp b/library/src/specialized/rocauxiliary_larf_specialized_kernels.hpp index f07c2cbcc..e76ef2d36 100644 --- a/library/src/specialized/rocauxiliary_larf_specialized_kernels.hpp +++ b/library/src/specialized/rocauxiliary_larf_specialized_kernels.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -33,6 +33,7 @@ #pragma once #include "lapack_device_functions.hpp" +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/rocauxiliary_larfg_specialized_kernels.hpp b/library/src/specialized/rocauxiliary_larfg_specialized_kernels.hpp index 6f1f98744..781911878 100644 --- a/library/src/specialized/rocauxiliary_larfg_specialized_kernels.hpp +++ b/library/src/specialized/rocauxiliary_larfg_specialized_kernels.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -34,6 +34,7 @@ #include "../auxiliary/rocauxiliary_larfg.hpp" #include "lapack_device_functions.hpp" +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp index a94174c10..996cf7bd1 100644 --- a/library/src/specialized/roclapack_gemm_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_gemm_specialized_kernels.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -27,6 +27,7 @@ #pragma once +#include "rocblas.hpp" #include "roclapack_gemm_device_functions.hpp" #include "rocsolver_run_specialized_kernels.hpp" @@ -188,31 +189,31 @@ ROCSOLVER_KERNEL void mfma_gemm_kernel(rocblas_operation transA, *************************************************************/ template -rocblas_status rocsolver_gemm(rocblas_handle handle, - rocblas_operation transA, - rocblas_operation transB, - I m, - I n, - I k, - const T* alpha, - U1 A, - rocblas_stride shiftA, - I inca, - I lda, - rocblas_stride strideA, - U2 B, - rocblas_stride shiftB, - I incb, - I ldb, - rocblas_stride strideB, - const T* beta, - U3 C, - rocblas_stride shiftC, - I incc, - I ldc, - rocblas_stride strideC, - I batch_count, - T** work) +ROCSOLVER_EXPORT rocblas_status rocsolver_gemm(rocblas_handle handle, + rocblas_operation transA, + rocblas_operation transB, + I m, + I n, + I k, + const T* alpha, + U1 A, + rocblas_stride shiftA, + I inca, + I lda, + rocblas_stride strideA, + U2 B, + rocblas_stride shiftB, + I incb, + I ldb, + rocblas_stride strideB, + const T* beta, + U3 C, + rocblas_stride shiftC, + I incc, + I ldc, + rocblas_stride strideC, + I batch_count, + T** work) { ROCSOLVER_ENTER("gemm", "transA:", transA, "transB:", transB, "m:", m, "n:", n, "k:", k, "shiftA:", shiftA, "inca:", inca, "lda:", lda, "shiftB:", shiftB, "incb:", incb, @@ -317,28 +318,28 @@ rocblas_status rocsolver_gemm(rocblas_handle handle, *************************************************************/ template -inline rocblas_status rocsolver_gemm(rocblas_handle handle, - rocblas_operation transA, - rocblas_operation transB, - I m, - I n, - I k, - const T* alpha, - U1 A, - rocblas_stride shiftA, - I lda, - rocblas_stride strideA, - U2 B, - rocblas_stride shiftB, - I ldb, - rocblas_stride strideB, - const T* beta, - U3 C, - rocblas_stride shiftC, - I ldc, - rocblas_stride strideC, - I batch_count, - T** work) +ROCSOLVER_EXPORT inline rocblas_status rocsolver_gemm(rocblas_handle handle, + rocblas_operation transA, + rocblas_operation transB, + I m, + I n, + I k, + const T* alpha, + U1 A, + rocblas_stride shiftA, + I lda, + rocblas_stride strideA, + U2 B, + rocblas_stride shiftB, + I ldb, + rocblas_stride strideB, + const T* beta, + U3 C, + rocblas_stride shiftC, + I ldc, + rocblas_stride strideC, + I batch_count, + T** work) { return rocsolver_gemm(handle, transA, transB, m, n, k, alpha, A, shiftA, 1, lda, strideA, B, shiftB, 1, ldb, strideB, beta, C, shiftC, 1, ldc, strideC, @@ -350,7 +351,7 @@ inline rocblas_status rocsolver_gemm(rocblas_handle handle, *************************************************************/ #define INSTANTIATE_GEMM(T, I, U1, U2, U3) \ - template rocblas_status rocsolver_gemm( \ + template ROCSOLVER_EXPORT rocblas_status rocsolver_gemm( \ rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, I m, I n, I k, \ const T* alpha, U1 A, rocblas_stride shiftA, I lda, rocblas_stride strideA, U2 B, \ rocblas_stride shiftB, I ldb, rocblas_stride strideB, const T* beta, U3 C, \ diff --git a/library/src/specialized/roclapack_ger_specialized_kernels.hpp b/library/src/specialized/roclapack_ger_specialized_kernels.hpp index 38b68b88e..1f2cad223 100644 --- a/library/src/specialized/roclapack_ger_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_ger_specialized_kernels.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -27,6 +27,7 @@ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/roclapack_getf2_specialized_kernels.hpp b/library/src/specialized/roclapack_getf2_specialized_kernels.hpp index d27caadf7..463ea3297 100644 --- a/library/src/specialized/roclapack_getf2_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_getf2_specialized_kernels.hpp @@ -4,11 +4,12 @@ * Factorization and inversion of a million matrices using GPUs: Challenges * and countermeasures. Procedia Computer Science, 108, 606-615. * - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. * ***********************************************************************/ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/roclapack_getri_specialized_kernels.hpp b/library/src/specialized/roclapack_getri_specialized_kernels.hpp index e443871ca..26e230abd 100644 --- a/library/src/specialized/roclapack_getri_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_getri_specialized_kernels.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -32,6 +32,7 @@ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/roclapack_potf2_specialized_kernels.hpp b/library/src/specialized/roclapack_potf2_specialized_kernels.hpp index 7387f7acc..92a7c9fff 100644 --- a/library/src/specialized/roclapack_potf2_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_potf2_specialized_kernels.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -32,6 +32,7 @@ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" #include #include diff --git a/library/src/specialized/roclapack_trsm_specialized_kernels.hpp b/library/src/specialized/roclapack_trsm_specialized_kernels.hpp index 01ab387a7..323df390b 100644 --- a/library/src/specialized/roclapack_trsm_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_trsm_specialized_kernels.hpp @@ -1,5 +1,5 @@ /* ************************************************************************** - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -27,6 +27,7 @@ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE diff --git a/library/src/specialized/roclapack_trtri_specialized_kernels.hpp b/library/src/specialized/roclapack_trtri_specialized_kernels.hpp index 35fdc6847..09fff4ccb 100644 --- a/library/src/specialized/roclapack_trtri_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_trtri_specialized_kernels.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -32,6 +32,7 @@ #pragma once +#include "rocblas.hpp" #include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE From 288ac4dbffd8e2c2312033c1a3f9a0523351829c Mon Sep 17 00:00:00 2001 From: AGonzales-amd Date: Thu, 30 Jan 2025 01:07:18 +0000 Subject: [PATCH 7/7] windows ci --- clients/common/unit/testing_gemm.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/common/unit/testing_gemm.hpp b/clients/common/unit/testing_gemm.hpp index 51bbd5ae6..a9446c439 100644 --- a/clients/common/unit/testing_gemm.hpp +++ b/clients/common/unit/testing_gemm.hpp @@ -215,6 +215,7 @@ void testing_gemm(Arguments& argus) { // get arguments rocblas_local_handle handle; + rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device); I m = argus.get("m"); I n = argus.get("n"); I k = argus.get("k");