Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize LU factorization without pivoting #680

Open
wants to merge 75 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
da9bc3c
development snapshot
EdDAzevedo Jan 31, 2024
d5986e3
snapshot compile ok
EdDAzevedo Feb 6, 2024
103104c
initial checking potf2_simple
EdDAzevedo Feb 6, 2024
2a70104
fix bug related to setting info
EdDAzevedo Feb 6, 2024
3b1da14
use potf2_lds for n <= 64
EdDAzevedo Feb 6, 2024
f3f7c9b
cleanup unused code
EdDAzevedo Feb 6, 2024
302aeff
initial version of packed storage
EdDAzevedo Feb 20, 2024
2819d15
increase block size for POTF2 by using packed storage in LDS
EdDAzevedo Feb 20, 2024
ac035c1
initial version to use rocblasCall_trsm
EdDAzevedo Feb 20, 2024
a3289b8
option to compute as Lower triangular
EdDAzevedo Feb 20, 2024
953e1ce
Merge branch 'develop' into SWDEV-409738_develop_debug
EdDAzevedo Feb 21, 2024
3b0d091
block size for POTRF Cholesky factorization is related to type
EdDAzevedo Feb 21, 2024
48c6feb
work around for problem with memsize related to rocblas trsm
EdDAzevedo Feb 22, 2024
3959b98
Add potf2_lds to specialized kernels folder
tfalders Feb 23, 2024
2122c83
Merge pull request #8 from tfalders/SWDEV-409738-reorg
EdDAzevedo Feb 23, 2024
fecfc19
remove optimization for mixed complex and real arithmetic
EdDAzevedo Feb 24, 2024
2340964
use rocblas_trsm for double and double_complex types for better numer…
EdDAzevedo Feb 24, 2024
f66d2fb
add specialized kernels for getf2_nopiv
EdDAzevedo Feb 26, 2024
90a2264
add block size for GETRF_NOPIV
EdDAzevedo Feb 26, 2024
4024078
option to use getrf_nopiv
EdDAzevedo Feb 26, 2024
6718103
initial checkin of getf2_nopiv
EdDAzevedo Feb 26, 2024
aa9edd5
include getf2_nopiv_run_small
EdDAzevedo Feb 26, 2024
46e43b4
initial checking option to use getrf_nopiv
EdDAzevedo Feb 26, 2024
169437e
initial checking of getf2_nopiv and getrf_nopiv
EdDAzevedo Feb 26, 2024
fe27c46
add option to call getrf_nopiv
EdDAzevedo Feb 26, 2024
49ec7ea
correct minor issue with size_iinfo
EdDAzevedo Feb 27, 2024
feba842
minor change to call in rocblasCall_trsm_mem
EdDAzevedo Feb 27, 2024
738c0b0
debug snapshot
EdDAzevedo Feb 28, 2024
9699078
Merge branch 'rocm_develop' into opt_getf2_nopiv_debug
EdDAzevedo Feb 28, 2024
c7cc063
update blocksize for GETRF_NOPIV
EdDAzevedo Feb 29, 2024
2906564
continue processing even when zero detected to be backward compatible
EdDAzevedo Feb 29, 2024
b7f4483
update calls to rocblasCall_trsm_mem()
EdDAzevedo Feb 29, 2024
11c4129
add test to avoid division by zero
EdDAzevedo Feb 29, 2024
786254b
minor change to be compatible with getf2
EdDAzevedo Feb 29, 2024
04e3700
add check for sfmin
EdDAzevedo Mar 1, 2024
4a0e260
Merge branch 'opt_getf2_nopiv_develop' into opt_getf2_nopiv
EdDAzevedo Mar 1, 2024
04ca4c9
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo Mar 4, 2024
3ee5f91
with clang-format
EdDAzevedo Mar 4, 2024
e351b2c
use std::max,min to replace max,min
EdDAzevedo Mar 5, 2024
d86f8a6
use std::max,min to replace max,min
EdDAzevedo Mar 5, 2024
0f69c17
adjust nb based on n
EdDAzevedo Mar 11, 2024
b707ee5
use exact amount of lds shared memory
EdDAzevedo Mar 12, 2024
7d91451
add check for lds size
EdDAzevedo Mar 12, 2024
78bd2a2
option to use rocsolver wrapper to rocblas
EdDAzevedo Mar 13, 2024
fa9722c
use rocblas directly for non-batched problem
EdDAzevedo Mar 13, 2024
2b1a61e
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo Apr 5, 2024
5b4869a
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo Apr 8, 2024
ae7600a
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo Apr 15, 2024
fef85b5
use rocsolver_trsm_mem instead of rocblasCall_trsm_mem
EdDAzevedo Apr 15, 2024
03068d9
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo Apr 16, 2024
561a2b3
minor update to rocsolver_trsm_mem
EdDAzevedo Apr 16, 2024
d4665fc
minor correction to rocsolver_trsm_mem
EdDAzevedo Apr 16, 2024
53e7411
minor update to rocsolver_trsm
EdDAzevedo Apr 16, 2024
ed4ee87
use rocsolver blas
EdDAzevedo Apr 16, 2024
9e73254
debug snapshot
EdDAzevedo Apr 27, 2024
85dfe37
update getrf_npvt_stopping_nb
EdDAzevedo Apr 28, 2024
db049e1
debug snapshot
EdDAzevedo Apr 28, 2024
31c8387
debug snapshot
EdDAzevedo Apr 29, 2024
33f6f4b
debug snapshot
EdDAzevedo Apr 29, 2024
91e4ca6
debug snapshot
EdDAzevedo Apr 30, 2024
e7dfacc
set stopping nb for getrf_nopvit
EdDAzevedo Apr 30, 2024
cb323cb
minor cleanup
EdDAzevedo Apr 30, 2024
064299c
add more test cases
EdDAzevedo Apr 30, 2024
7f33c39
remove large long running test cases
EdDAzevedo May 1, 2024
889e4c5
query lds to set block size
EdDAzevedo May 1, 2024
ce91838
fix nb block size
EdDAzevedo May 1, 2024
0c788e8
Merge branch 'rocm_develop' into opt_getf2_nopiv
EdDAzevedo May 1, 2024
04fd144
resolve merge conflict
EdDAzevedo May 1, 2024
5a5d747
minor cleanup in blocksize
EdDAzevedo May 2, 2024
7cc63c6
minor cleanup
EdDAzevedo May 2, 2024
a9034d4
clang format
EdDAzevedo May 2, 2024
a64900e
clang-format
EdDAzevedo May 3, 2024
be6224d
Merge branch 'rocm_develop' into opt_getf2_nopiv_recursive
EdDAzevedo May 17, 2024
25ca39f
add namespace
EdDAzevedo May 17, 2024
94a5ebd
add namespace
EdDAzevedo May 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions clients/gtest/lapack/getf2_getrf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ const vector<vector<int>> large_matrix_size_range = {
{1000, 1024, 0},
};

const vector<int> large_n_size_range = {
45, 64, 520, 1024, 2000,
};
const vector<int> large_n_size_range = {45, 64, 520, 1024, 2000};

Arguments getrf_setup_arguments(getrf_tuple tup)
{
Expand Down
5 changes: 5 additions & 0 deletions library/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ set(rocsolver_specialized_source
specialized/roclapack_potf2_specialized_kernels_d.cpp
specialized/roclapack_potf2_specialized_kernels_c.cpp
specialized/roclapack_potf2_specialized_kernels_z.cpp
# getf2_nopiv
specialized/roclapack_getf2_nopiv_specialized_kernels_s.cpp
specialized/roclapack_getf2_nopiv_specialized_kernels_d.cpp
specialized/roclapack_getf2_nopiv_specialized_kernels_c.cpp
specialized/roclapack_getf2_nopiv_specialized_kernels_z.cpp
)

if(OPTIMAL)
Expand Down
15 changes: 15 additions & 0 deletions library/src/include/ideal_sizes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,21 @@
#define GETRF_NPVT_BATCH_BLKSIZES_COMPLEX 0, -16, -32, -48, 64, 128
#endif

// ---------------------------------------------------------------
// size of submatrix that can fit in 64 KBytes of LDS shared memory
//
// GETRF_NOPIV_BLOCKSIZE(T) is the order nb of the largest square
// nb x nb block of elements of type T used by getf2_nopiv / getrf_nopiv.
// The value is selected by sizeof(T):
//   sizeof(T) == 4  -> 128  (128 * 128 * 4 = 65536 bytes)
//   sizeof(T) == 8  ->  90  ( 90 *  90 * 8 = 64800 bytes)
//   otherwise       ->  64  ( 64 *  64 * 16 = 65536 bytes for 16-byte T)
// The definition can be overridden by defining the macro beforehand.
// ---------------------------------------------------------------
#ifndef GETRF_NOPIV_BLOCKSIZE
#define GETRF_NOPIV_BLOCKSIZE(T) ((sizeof(T) == 4) ? 128 : (sizeof(T) == 8) ? 90 : 64)
#endif

// --------------------------------------------
// assume there are at least 8 MBytes of last level cache;
// terminate recursion if the matrix can fit in cache
//
// GETRF_NOPIV_STOPPING_NB(T) is selected by sizeof(T):
//   sizeof(T) == 4  -> 1408  (1408 * 1408 * 4  = 7929856 bytes)
//   sizeof(T) == 8  -> 1024  (1024 * 1024 * 8  = 8388608 bytes)
//   otherwise       ->  704  ( 704 *  704 * 16 = 7929856 bytes for 16-byte T)
// The definition can be overridden by defining the macro beforehand.
// --------------------------------------------
#ifndef GETRF_NOPIV_STOPPING_NB
#define GETRF_NOPIV_STOPPING_NB(T) ((sizeof(T) == 4) ? 1408 : (sizeof(T) == 8) ? 1024 : 704)
#endif

/****************************** getri *****************************************
*******************************************************************************/
#ifndef GETRI_MAX_COLS
Expand Down
10 changes: 10 additions & 0 deletions library/src/include/rocsolver_run_specialized_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,16 @@ rocblas_status potf2_run_small(rocblas_handle handle,
rocblas_int* info,
const rocblas_int batch_count);

// Launcher for the specialized small-size getf2_nopiv kernel.
// Declared here only; rocsolver_getf2_nopiv_template calls it when
// n <= GETRF_NOPIV_BLOCKSIZE(T).
// NOTE(review): the definitions are presumably instantiated per precision in
// specialized/roclapack_getf2_nopiv_specialized_kernels_{s,d,c,z}.cpp -- confirm.
//
// handle      rocblas handle
// n           order of the square matrices
// AA          matrices to factorize (type U; pointer or array of pointers)
// shiftA      offset applied to each matrix
// lda         leading dimension of each matrix
// strideA     stride between consecutive matrices in the batch
// info        per-instance status output
// batch_count number of matrices in the batch
template <typename T, typename U>
rocblas_status getf2_nopiv_run_small(rocblas_handle handle,
const rocblas_int n,
U AA,
const rocblas_int shiftA,
const rocblas_int lda,
const rocblas_stride strideA,
rocblas_int* info,
const rocblas_int batch_count);

#ifdef OPTIMAL

template <typename T, typename I, typename INFO, typename U>
Expand Down
120 changes: 120 additions & 0 deletions library/src/lapack/roclapack_getf2_nopiv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/************************************************************************
* Derived from the BSD3-licensed
* LAPACK routine (version 3.7.0) --
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* December 2016
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
* *************************************************************************/

#pragma once

#include "rocblas.hpp"
#include "rocsolver/rocsolver.h"
#include "rocsolver_run_specialized_kernels.hpp"

ROCSOLVER_BEGIN_NAMESPACE

// Argument validation shared by the getf2_nopiv entry points.
//
// Returns:
//   rocblas_status_invalid_size    if n < 0, lda < n or batch_count < 0
//   rocblas_status_invalid_pointer if A is null while n != 0, or info is null
//                                  while batch_count != 0 (not checked during a
//                                  device-memory-size query)
//   rocblas_status_continue        otherwise
template <typename T>
rocblas_status rocsolver_getf2_nopiv_potrf_argCheck(rocblas_handle handle,
                                                    const rocblas_int n,
                                                    const rocblas_int lda,
                                                    T A,
                                                    rocblas_int* info,
                                                    const rocblas_int batch_count = 1)
{
    // NOTE: the unit tests depend on the order of the checks below.

    // 1. invalid size
    const bool sizes_ok = (n >= 0) && (lda >= n) && (batch_count >= 0);
    if(!sizes_ok)
        return rocblas_status_invalid_size;

    // 2. skip pointer check if querying memory size
    if(rocblas_is_device_memory_size_query(handle))
        return rocblas_status_continue;

    // 3. invalid pointers
    const bool missing_A = (n != 0) && !A;
    const bool missing_info = (batch_count != 0) && !info;
    if(missing_A || missing_info)
        return rocblas_status_invalid_pointer;

    return rocblas_status_continue;
}

// Unblocked LU factorization without pivoting (getf2_nopiv) of a batch of
// n-by-n matrices, delegated to the specialized small-size kernel.
//
// Only sizes n <= GETRF_NOPIV_BLOCKSIZE(T) are supported; larger sizes trip the
// assert in debug builds and yield rocblas_status_internal_error otherwise.
//
// handle      rocblas handle (its pointer mode is temporarily set to device
//             and restored before returning)
// n           order of the matrices
// A           matrices to factorize in place
// shiftA      offset applied to each matrix
// lda         leading dimension of each matrix
// strideA     stride between consecutive matrices in the batch
// info        per-instance status output; reset to 0 before the factorization
// batch_count number of matrices in the batch
//
// Returns rocblas_status_success on the quick-return paths, otherwise the
// status reported by getf2_nopiv_run_small.
template <typename T, typename U, bool COMPLEX = rocblas_is_complex<T>>
rocblas_status rocsolver_getf2_nopiv_template(rocblas_handle handle,
                                              const rocblas_int n,
                                              U A,
                                              const rocblas_int shiftA,
                                              const rocblas_int lda,
                                              const rocblas_stride strideA,
                                              rocblas_int* info,
                                              const rocblas_int batch_count)
{
    ROCSOLVER_ENTER("getf2_nopiv", "n:", n, "shiftA:", shiftA, "lda:", lda, "bc:", batch_count);

    // quick return if zero instances in batch
    if(batch_count == 0)
        return rocblas_status_success;

    hipStream_t stream;
    rocblas_get_stream(handle, &stream);

    rocblas_int blocksReset = (batch_count - 1) / BS1 + 1;
    dim3 gridReset(blocksReset, 1, 1);
    dim3 threads(BS1, 1, 1);

    // info=0 for every instance in the batch (assume success until the
    // factorization reports otherwise)
    ROCSOLVER_LAUNCH_KERNEL(reset_info, gridReset, threads, 0, stream, info, batch_count, 0);

    // quick return if no dimensions
    if(n == 0)
        return rocblas_status_success;

    // everything must be executed with scalars on the device
    rocblas_pointer_mode old_mode;
    rocblas_get_pointer_mode(handle, &old_mode);
    rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device);

    assert(n <= GETRF_NOPIV_BLOCKSIZE(T));

    // sizes beyond the specialized kernel's limit are an internal error
    rocblas_status istat = rocblas_status_internal_error;
    if(n <= GETRF_NOPIV_BLOCKSIZE(T))
    {
        // ----------------------
        // use specialized kernel
        // ----------------------
        // propagate the launcher's status instead of silently discarding it
        istat = getf2_nopiv_run_small<T>(handle, n, A, shiftA, lda, strideA, info, batch_count);
    }

    // always restore the caller's pointer mode, including on failure
    rocblas_set_pointer_mode(handle, old_mode);
    return istat;
}

ROCSOLVER_END_NAMESPACE
130 changes: 120 additions & 10 deletions library/src/lapack/roclapack_getrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* *************************************************************************/

#include "roclapack_getrf.hpp"
#include "roclapack_getrf_nopiv.hpp"

ROCSOLVER_BEGIN_NAMESPACE

Expand Down Expand Up @@ -109,6 +110,84 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle,
optim_mem, pivot);
}

// Top-level implementation of getrf_nopiv for the normal (non-batched,
// non-strided) case.
//
// Steps:
//   1. validate the handle and the arguments (getf2/getrf argument check with
//      pivoting disabled and a null ipiv),
//   2. compute the required workspace sizes,
//   3. if the handle is in device-memory-size-query mode, report the sizes and
//      return,
//   4. allocate the workspace and run rocsolver_getrf_nopiv_template.
//
// handle  rocblas handle
// m, n    dimensions of the matrix
// A       matrix to factorize
// lda     leading dimension of A
// info    status output
//
// Returns rocblas_status_invalid_handle for a null handle, the status of the
// argument check when it is not rocblas_status_continue,
// rocblas_status_memory_error if the workspace cannot be allocated, and
// otherwise the status of the memory-size query or of the factorization.
template <typename T, typename U>
rocblas_status rocsolver_getrf_nopiv_impl(rocblas_handle handle,
                                          const rocblas_int m,
                                          const rocblas_int n,
                                          U A,
                                          const rocblas_int lda,
                                          rocblas_int* info)
{
    const char* name = "getrf_nopiv";
    ROCSOLVER_ENTER_TOP(name, "-m", m, "-n", n, "--lda", lda);

    if(!handle)
        return rocblas_status_invalid_handle;

    // argument checking
    {
        bool const pivot = false;
        rocblas_int* const ipiv = nullptr;

        rocblas_status st = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot);
        if(st != rocblas_status_continue)
            return st;
    }

    // working with unshifted arrays
    rocblas_int shiftA = 0;

    // normal (non-batched non-strided) execution
    rocblas_stride strideA = 0;
    rocblas_int batch_count = 1;
    constexpr bool is_batched = false;
    constexpr bool is_strided = false;

    // memory workspace sizes:
    //
    // size of reusable workspace (and for calling TRSM)
    bool optim_mem = false;
    size_t size_work1 = 0;
    size_t size_work2 = 0;
    size_t size_work3 = 0;
    size_t size_work4 = 0;

    // size to store info about singularity of each subblock
    size_t size_iinfo = 0;

    rocsolver_getrf_nopiv_getMemorySize<is_batched, is_strided, T>(
        m, n, batch_count, &size_work1, &size_work2, &size_work3, &size_work4, &size_iinfo,
        &optim_mem);

    if(rocblas_is_device_memory_size_query(handle))
        return rocblas_set_optimal_device_memory_size(handle, size_work1, size_work2, size_work3,
                                                      size_work4, size_iinfo);

    // memory workspace allocation
    rocblas_device_malloc mem(handle, size_work1, size_work2, size_work3, size_work4, size_iinfo);
    if(!mem)
        return rocblas_status_memory_error;

    void* work1 = mem[0];
    void* work2 = mem[1];
    void* work3 = mem[2];
    void* work4 = mem[3];
    void* iinfo = mem[4];

    // execution (same batched/strided flags as used for the workspace query)
    return rocsolver_getrf_nopiv_template<is_batched, is_strided, T>(
        handle, m, n, A, shiftA, lda, strideA, info, batch_count, work1, work2, work3, work4,
        static_cast<rocblas_int*>(iinfo), optim_mem);
}
ROCSOLVER_END_NAMESPACE

/*
Expand Down Expand Up @@ -165,6 +244,7 @@ rocblas_status rocsolver_zgetrf(rocblas_handle handle,
true);
}

// Compile-time switch for the *getrf_npvt entry points in this file:
// when true they call rocsolver_getrf_nopiv_impl, otherwise they fall back to
// rocsolver_getrf_impl with a null ipiv and pivoting disabled.
static constexpr bool use_getrf_nopiv = true;
rocblas_status rocsolver_sgetrf_64(rocblas_handle handle,
const int64_t m,
const int64_t n,
Expand Down Expand Up @@ -218,8 +298,15 @@ rocblas_status rocsolver_sgetrf_npvt(rocblas_handle handle,
const rocblas_int lda,
rocblas_int* info)
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<float>(handle, m, n, A, lda, ipiv, info, false);
if(use_getrf_nopiv)
{
return rocsolver::rocsolver_getrf_nopiv_impl<float>(handle, m, n, A, lda, info);
}
else
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<float>(handle, m, n, A, lda, ipiv, info, false);
}
}

rocblas_status rocsolver_dgetrf_npvt(rocblas_handle handle,
Expand All @@ -229,8 +316,15 @@ rocblas_status rocsolver_dgetrf_npvt(rocblas_handle handle,
const rocblas_int lda,
rocblas_int* info)
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<double>(handle, m, n, A, lda, ipiv, info, false);
if(use_getrf_nopiv)
{
return rocsolver::rocsolver_getrf_nopiv_impl<double>(handle, m, n, A, lda, info);
}
else
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<double>(handle, m, n, A, lda, ipiv, info, false);
}
}

rocblas_status rocsolver_cgetrf_npvt(rocblas_handle handle,
Expand All @@ -240,9 +334,17 @@ rocblas_status rocsolver_cgetrf_npvt(rocblas_handle handle,
const rocblas_int lda,
rocblas_int* info)
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<rocblas_float_complex>(handle, m, n, A, lda, ipiv, info,
false);
if(use_getrf_nopiv)
{
return rocsolver::rocsolver_getrf_nopiv_impl<rocblas_float_complex>(handle, m, n, A, lda,
info);
}
else
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<rocblas_float_complex>(handle, m, n, A, lda, ipiv,
info, false);
}
}

rocblas_status rocsolver_zgetrf_npvt(rocblas_handle handle,
Expand All @@ -252,9 +354,17 @@ rocblas_status rocsolver_zgetrf_npvt(rocblas_handle handle,
const rocblas_int lda,
rocblas_int* info)
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<rocblas_double_complex>(handle, m, n, A, lda, ipiv, info,
false);
if(use_getrf_nopiv)
{
return rocsolver::rocsolver_getrf_nopiv_impl<rocblas_double_complex>(handle, m, n, A, lda,
info);
}
else
{
rocblas_int* ipiv = nullptr;
return rocsolver::rocsolver_getrf_impl<rocblas_double_complex>(handle, m, n, A, lda, ipiv,
info, false);
}
}

rocblas_status rocsolver_sgetrf_npvt_64(rocblas_handle handle,
Expand Down
Loading