From fcef6cd3c3ed4adf1e848077596770e5edf6840a Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 21 Jan 2025 09:55:05 +0800 Subject: [PATCH 01/17] move kpar into read_input_item --- source/module_base/global_variable.cpp | 1 - source/module_base/global_variable.h | 2 -- source/module_esolver/esolver_ks_lcao.cpp | 6 ++-- source/module_hsolver/hsolver_lcao.cpp | 6 ++-- source/module_io/input_conv.cpp | 32 ++------------------- source/module_io/read_input_item_system.cpp | 24 ++++++++++++++++ source/module_io/read_set_globalv.cpp | 1 + source/module_parameter/system_parameter.h | 1 + 8 files changed, 34 insertions(+), 39 deletions(-) mode change 100644 => 100755 source/module_io/read_set_globalv.cpp mode change 100644 => 100755 source/module_parameter/system_parameter.h diff --git a/source/module_base/global_variable.cpp b/source/module_base/global_variable.cpp index 616e8c3656..6957b9bc97 100644 --- a/source/module_base/global_variable.cpp +++ b/source/module_base/global_variable.cpp @@ -18,7 +18,6 @@ namespace GlobalV int NPROC = 1; ///< global number of process int KPAR = 1; ///< global number of pools -int KPAR_LCAO = 1; ///< global number of pools for LCAO diagonalization only int MY_RANK = 0; ///< global index of process int MY_POOL = 0; ///< global index of pool (count in pool) int MY_STOGROUP = 0; diff --git a/source/module_base/global_variable.h b/source/module_base/global_variable.h index 56629444d3..ae35025fef 100644 --- a/source/module_base/global_variable.h +++ b/source/module_base/global_variable.h @@ -28,7 +28,6 @@ namespace GlobalV // NAME : DCOLOR( color of each group) // NAME : GRANK( index of grid world) // NAME : GSIZE( number of processors in each grid world) -// NAME : KPAR_LCAO ( global number of pools for LCAO diagonalization only) //======================================================================== extern int NPROC; extern int KPAR; @@ -44,7 +43,6 @@ extern int DSIZE; extern int DCOLOR; extern int GRANK; extern int GSIZE; -extern int 
KPAR_LCAO; //========================================================== // NAME : ofs_running( contain information during runnnig) diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 922d3e19bf..7f2c4ca5ff 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -241,16 +241,16 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 12) if kpar is not divisible by nks, print a warning - if (GlobalV::KPAR_LCAO > 1) + if (PARAM.globalv.kpar_lcao > 1) { - if (this->kv.get_nks() % GlobalV::KPAR_LCAO != 0) + if (this->kv.get_nks() % PARAM.globalv.kpar_lcao != 0) { ModuleBase::WARNING("ESolver_KS_LCAO::before_all_runners", "nks is not divisible by kpar."); std::cout << "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" "%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl; - std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << GlobalV::KPAR_LCAO + std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << PARAM.globalv.kpar_lcao << ")." << std::endl; std::cout << " This may lead to poor load balance. 
It is strongly suggested to" << std::endl; std::cout << " set nks to be divisible by kpar, but if this is really what" << std::endl; diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 44deac1bbd..4c8ff835de 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -47,14 +47,14 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, if (this->method != "pexsi") { - if (GlobalV::KPAR_LCAO > 1 + if (PARAM.globalv.kpar_lcao > 1 && (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx")) { #ifdef __MPI - this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO); + this->parakSolve(pHamilt, psi, pes, PARAM.globalv.kpar_lcao); #endif } - else if (GlobalV::KPAR_LCAO == 1) + else if (PARAM.globalv.kpar_lcao == 1) { /// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors). for (int ik = 0; ik < psi.get_nk(); ++ik) diff --git a/source/module_io/input_conv.cpp b/source/module_io/input_conv.cpp index 65ff27dbc1..3ff770d7d8 100644 --- a/source/module_io/input_conv.cpp +++ b/source/module_io/input_conv.cpp @@ -171,37 +171,9 @@ void Input_Conv::Convert() ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "orbital_dir", PARAM.inp.orbital_dir); // GlobalV::global_pseudo_type = PARAM.inp.pseudo_type; - if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax") - { - } - - if (PARAM.inp.device == "gpu" && PARAM.inp.basis_type == "pw") - { - GlobalV::KPAR = base_device::information::get_device_kpar(PARAM.inp.kpar, PARAM.inp.bndpar); - } -#ifdef __LCAO - else if (PARAM.inp.basis_type == "lcao") { - /// GlobalV::KPAR_LCAO is used in LCAO diagonalization only - GlobalV::KPAR_LCAO = PARAM.inp.kpar; - /// all other parts of the code use GlobalV::KPAR = 1 - GlobalV::KPAR = 1; - } -#endif - else - { - GlobalV::KPAR = PARAM.inp.kpar; - } - if (PARAM.inp.device == "cpu" and PARAM.inp.precision == "single") - { -// cpu single 
precision is not supported while float_fftw lib is not available -#ifndef __ENABLE_FLOAT_FFTW - ModuleBase::WARNING_QUIT( - "Input_Conv", - "Single precision with cpu is not supported while float_fftw lib is not available; \ - \n Please recompile with cmake flag \"-DENABLE_FLOAT_FFTW=ON\".\n"); -#endif // __ENABLE_FLOAT_FFTW - } + GlobalV::KPAR = PARAM.inp.kpar; + #ifdef __LCAO diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index c95e785885..ad00b0a514 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -233,6 +233,19 @@ void ReadInput::item_system() item.annotation = "devide all processors into kpar groups and k points " "will be distributed among"; read_sync_int(input.kpar); + item.reset_value = [](const Input_Item& item, Parameter& para) { + if (para.inp.device == "gpu" && para.inp.basis_type == "pw") + { + para.input.kpar = base_device::information::get_device_kpar(para.inp.kpar, para.inp.bndpar); + } +#ifdef __LCAO + else if (para.inp.basis_type == "lcao") + { + para.sys.kpar_lcao = para.inp.kpar; + para.input.kpar = 1; + } +#endif + }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.basis_type == "lcao" && para.input.kpar > 1) { @@ -796,6 +809,17 @@ void ReadInput::item_system() const std::string warningstr = nofound_str(avail_list, "precision"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); } + + // cpu single precision is not supported while float_fftw lib is not available + if (para.inp.device == "cpu" && para.inp.precision == "single") + { +#ifndef __ENABLE_FLOAT_FFTW + ModuleBase::WARNING_QUIT( + "ReadInput", + "Single precision with cpu is not supported while float_fftw lib is not available; \ + \n Please recompile with cmake flag \"-DENABLE_FLOAT_FFTW=ON\".\n"); +#endif + } }; this->add_item(item); } diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp old mode 
100644 new mode 100755 index 1eb115fd9f..e80d1435fb --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -127,5 +127,6 @@ void ReadInput::set_globalv_bcast() add_double_bcast(sys.dq); add_int_bcast(sys.nqx); add_int_bcast(sys.nqxq); + add_int_bcast(sys.kpar_lcao); } } // namespace ModuleIO diff --git a/source/module_parameter/system_parameter.h b/source/module_parameter/system_parameter.h old mode 100644 new mode 100755 index 04d2ca870e..1401d015b0 --- a/source/module_parameter/system_parameter.h +++ b/source/module_parameter/system_parameter.h @@ -53,5 +53,6 @@ struct System_para bool double_grid = false; ///< true if "ndx,ndy,ndz" is larger than "nx,ny,nz" double uramping = -10.0 / 13.6; /// U-Ramping method (Ry) std::vector hubbard_u = {}; ///< Hubbard Coulomb interaction parameter U (Ry) + int kpar_lcao = 1; ///< global number of pools for LCAO diagonalization only }; #endif \ No newline at end of file From 23fc25e052f22b7de8252e03dcd5de28ddf3a507 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 21 Jan 2025 14:41:08 +0800 Subject: [PATCH 02/17] add para_linear_transform_op --- source/Makefile.Objects | 1 + source/module_hsolver/CMakeLists.txt | 1 + .../module_hsolver/para_linear_transform.cpp | 161 +++++++++++++++ source/module_hsolver/para_linear_transform.h | 55 +++++ source/module_hsolver/test/CMakeLists.txt | 17 +- .../test/test_para_linear_trans.cpp | 194 ++++++++++++++++++ 6 files changed, 426 insertions(+), 3 deletions(-) create mode 100644 source/module_hsolver/para_linear_transform.cpp create mode 100644 source/module_hsolver/para_linear_transform.h create mode 100644 source/module_hsolver/test/test_para_linear_trans.cpp diff --git a/source/Makefile.Objects b/source/Makefile.Objects index ad13d75976..6d416c8e7f 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -333,6 +333,7 @@ OBJS_HSOLVER=diago_cg.o\ diago_david.o\ diago_dav_subspace.o\ diago_bpcg.o\ + para_linear_transform.o\ hsolver.o\ 
hsolver_pw.o\ hsolver_lcaopw.o\ diff --git a/source/module_hsolver/CMakeLists.txt b/source/module_hsolver/CMakeLists.txt index 7f6c8ca4c6..6ccfa42c2e 100644 --- a/source/module_hsolver/CMakeLists.txt +++ b/source/module_hsolver/CMakeLists.txt @@ -4,6 +4,7 @@ list(APPEND objects diago_david.cpp diago_dav_subspace.cpp diago_bpcg.cpp + para_linear_transform.cpp hsolver_pw.cpp hsolver_lcaopw.cpp hsolver_pw_sdft.cpp diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp new file mode 100644 index 0000000000..4180d11ff1 --- /dev/null +++ b/source/module_hsolver/para_linear_transform.cpp @@ -0,0 +1,161 @@ +#include "para_linear_transform.h" +#include +#include +namespace hsolver +{ +template +void para_linear_transform_op::operator()(T* A, + const T alpha, + const T beta, + const T* U_global, + const int& nrow, + const int& LDA, + const int& ncol_loc, + const int& ncol_glo, +#ifdef __MPI + MPI_Comm col_world, +#endif + const int rank_col, + const int nproc_col + +) +{ + const Device* ctx = {}; +#ifdef __MPI + if (nproc_col > 1) + { + std::vector colA_loc(nproc_col); + MPI_Allgather(&ncol_loc, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + std::vector start_col(nproc_col); + start_col[0] = 0; + for (int ip = 1; ip < nproc_col; ++ip) + { + start_col[ip] = start_col[ip - 1] + colA_loc[ip - 1]; + } + int max_col = *std::max_element(colA_loc.begin(), colA_loc.end()); + std::vector requests(nproc_col); + + std::vector A_tmp(max_col * LDA); + T* A_tmp_device = A_tmp.data(); + if (std::is_same::value) + { + A_tmp_device = nullptr; + resmem_dev_op()(A_tmp_device, max_col * LDA); + } + T* A_tmp2 = nullptr; + resmem_dev_op()(A_tmp2, ncol_loc * LDA); + syncmem_dev_op()(A_tmp2, A, ncol_loc * LDA); + T* A_sum = nullptr; + resmem_dev_op()(A_sum, ncol_loc * LDA); + setmem_dev_op()(A_sum, 0.0, ncol_loc * LDA); + + // Send + for (int ip = 0; ip < nproc_col; ++ip) + { + if (rank_col != ip) + { + int size = LDA * ncol_loc; + 
Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); + } + } + + // Receive + T* U_local = nullptr; + resmem_dev_op()(U_local, max_col * ncol_loc); + const int start = start_col[rank_col]; + for (int ip = 0; ip < nproc_col; ++ip) + { + T real_beta = ip == 0 ? beta : 0; + const int start_row = start_col[ip]; + const int ncol_ip = colA_loc[ip]; + // get U_local + for (int i = 0; i < ncol_loc; ++i) + { + const T* U_glo_tmp = U_global + start_row + (i + start) * ncol_glo; + syncmem_dev_op()(U_local + i * ncol_ip, U_glo_tmp, ncol_ip); + } + + if (ip == rank_col) + { + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrow, + ncol_loc, + ncol_ip, + &alpha, + A, + LDA, + U_local, + ncol_ip, + &real_beta, + A_tmp2, + LDA); + } + else + { + int size = LDA * ncol_ip; + MPI_Status status; + Parallel_Common::recv_dev(A_tmp_device, size, ip, 0, col_world, &status, A_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrow, + ncol_loc, + ncol_ip, + &alpha, + A_tmp_device, + LDA, + U_local, + ncol_ip, + &real_beta, + A_tmp2, + LDA); + } + // sum all the results + T one = 1.0; + ModuleBase::axpy_op()(ctx, ncol_loc * LDA, &one, A_tmp2, 1, A_sum, 1); + } + syncmem_dev_op()(A, A_sum, ncol_loc * LDA); + delmem_dev_op()(U_local); + delmem_dev_op()(A_tmp2); + delmem_dev_op()(A_sum); + if (std::is_same::value) + { + delmem_dev_op()(A_tmp_device); + } + } + else +#endif + { + T* A_tmp = nullptr; + resmem_dev_op()(A_tmp, LDA * ncol_glo); + syncmem_dev_op()(A_tmp, A, LDA * ncol_loc); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrow, + ncol_glo, + ncol_glo, + &alpha, + A_tmp, + LDA, + U_global, + ncol_glo, + &beta, + A, + LDA); + delmem_dev_op()(A_tmp); + } +}; + +template struct para_linear_transform_op; +template struct para_linear_transform_op, base_device::DEVICE_CPU>; +template struct para_linear_transform_op, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template struct para_linear_transform_op; 
+template struct para_linear_transform_op, base_device::DEVICE_GPU>; +template struct para_linear_transform_op, base_device::DEVICE_GPU>; +#endif +} // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/para_linear_transform.h b/source/module_hsolver/para_linear_transform.h new file mode 100644 index 0000000000..b94dd763ab --- /dev/null +++ b/source/module_hsolver/para_linear_transform.h @@ -0,0 +1,55 @@ +#ifndef __PARA_LINEAR_TRANSFORM_H__ +#define __PARA_LINEAR_TRANSFORM_H__ +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#include "module_base/parallel_device.h" +#ifdef __MPI +#include "mpi.h" +#endif +namespace hsolver +{ + +template +struct para_linear_transform_op +{ + using syncmem_dev_op = base_device::memory::synchronize_memory_op; + using resmem_dev_op = base_device::memory::resize_memory_op; + using setmem_dev_op = base_device::memory::set_memory_op; + using delmem_dev_op = base_device::memory::delete_memory_op; + /** + * @brief A_global = alpha * A_global * U_global + beta * A_global + * A is a local matrix with nrow rows and ncol_loc columns + * U_global is a matrix with ncol_glo rows and ncol_glo columns + * @example rotate wave functions: A = A * U + * orthogonalize wave functions: A = A - A * U + * + * @param A : input/output matrix + * @param alpha : alpha + * @param beta : beta + * @param U_global : input matrix + * @param nrow : number of rows of A + * @param LDA : leading dimension of A + * @param ncol_loc : number of columns of A + * @param ncol_glo : number of columns and rows of U_global + * @param col_world : column communicator world + * @param rank_col : rank of col_world + * @param nproc_col : number of processes in col_world + * + */ + void operator()(T* A, + const T alpha, + const T beta, + const T* U_global, + const int& nrow, + const int& LDA, + const int& ncol_loc, + const int& ncol_glo, +#ifdef __MPI + 
MPI_Comm col_world, +#endif + const int rank_col, + const int nproc_col); +}; +} // namespace hsolver +#endif \ No newline at end of file diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index e44171912c..7165e895a7 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -12,7 +12,7 @@ if (ENABLE_MPI) AddTest( TARGET HSolver_bpcg LIBS parameter ${math_libs} base psi device container - SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../diago_iter_assist.cpp + SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp ../../module_basis/module_pw/test/test_tool.cpp ../../module_hamilt_general/operator.cpp ../../module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp @@ -77,13 +77,13 @@ if (ENABLE_MPI) AddTest( TARGET HSolver_pw LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp + SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ) AddTest( TARGET HSolver_sdft LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp + SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ) if(ENABLE_LCAO) @@ -159,6 +159,17 @@ AddTest( SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_elpa.cpp ../diago_scalapack.cpp ) +AddTest( + TARGET hsolver_linear_trans + LIBS parameter ${math_libs} base device MPI::MPI_CXX + SOURCES 
test_para_linear_trans.cpp ../para_linear_transform.cpp +) + +add_test(NAME hsolver_para_linear_trans + COMMAND mpirun -np 4 ./hsolver_linear_trans + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + find_program(BASH bash) if (ENABLE_MPI) add_test(NAME HSolver_cg_parallel diff --git a/source/module_hsolver/test/test_para_linear_trans.cpp b/source/module_hsolver/test/test_para_linear_trans.cpp new file mode 100644 index 0000000000..3acac64435 --- /dev/null +++ b/source/module_hsolver/test/test_para_linear_trans.cpp @@ -0,0 +1,194 @@ +#include "../para_linear_transform.h" + +#include +#ifdef __MPI +#include +#endif + +void random_data(std::vector& A_global, std::vector& U_global, double& alpha, double& beta) +{ + for (auto& val: A_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: U_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + alpha = std::rand() / (RAND_MAX + 1.0); + beta = std::rand() / (RAND_MAX + 1.0); +} +void random_data(std::vector>& A_global, + std::vector>& U_global, + std::complex& alpha, + std::complex& beta) +{ + for (auto& val: A_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: U_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + alpha = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + beta = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); +} +double get_double(std::complex& val) +{ + return val.real() + val.imag(); +} +double get_double(double& val) +{ + return val; +} + +template +class ParaLinearTransformTest : public ::testing::Test +{ + protected: + void SetUp() override + { + } + + void TearDown() override + { + } + void prepare(const int nrow, const int ncol_glo, const int LDA) + { + int rank = 0; + int nproc = 1; + int colA_start = 0; + this->ncol_glo = ncol_glo; + this->ncol_loc = ncol_glo; +#ifdef __MPI + 
MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + this->ncol_loc = ncol_glo / nproc; + if (rank < ncol_glo % nproc) + { + ncol_loc++; + } + std::vector ncolA_ip(nproc); + MPI_Allgather(&ncol_loc, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); + for (int i = 0; i < rank; ++i) + { + colA_start += ncolA_ip[i]; + } +#endif + A_global.resize(LDA * ncol_glo); + A_global_ref.resize(LDA * ncol_glo); + U_global.resize(ncol_glo * ncol_glo); + if (rank == 0) + { + random_data(A_global, U_global, alpha, beta); + A_global_ref = A_global; + std::vector A_global_tmp = A_global; + const base_device::DEVICE_CPU* ctx = {}; + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrow, + ncol_glo, + ncol_glo, + &alpha, + A_global_tmp.data(), + LDA, + U_global.data(), + ncol_glo, + &beta, + A_global_ref.data(), + LDA); + } + if (std::is_same::value) + { +#ifdef __MPI + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(A_global_ref.data(), A_global_ref.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); +#endif + } + else if (std::is_same>::value) + { +#ifdef __MPI + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(A_global_ref.data(), A_global_ref.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); +#endif + } + + A.resize(LDA * ncol_loc); + A_ref.resize(LDA * ncol_loc); + for (int i = 0; i < LDA * ncol_loc; ++i) + { + A[i] = A_global[colA_start * LDA + i]; + A_ref[i] = A_global_ref[colA_start * LDA + i]; + } + } + std::vector A; + std::vector A_ref; + std::vector A_global; + std::vector U_global; 
+ std::vector A_global_ref; + int ncol_glo = 1; + int ncol_loc = 1; + T alpha; + T beta; +}; + +typedef ::testing::Types> MyTypes; +TYPED_TEST_SUITE(ParaLinearTransformTest, MyTypes); + +TYPED_TEST(ParaLinearTransformTest, cpucase) +{ + const int nrow = 7; + const int ncol_glo = 13; + const int LDA = 9; + + this->prepare(nrow, ncol_glo, LDA); + int rank_col = 0, nproc_col = 1; +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_WORLD; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); +#endif + + hsolver::para_linear_transform_op()(this->A.data(), + this->alpha, + this->beta, + this->U_global.data(), + nrow, + LDA, + this->ncol_loc, + ncol_glo, +#ifdef __MPI + col_world, +#endif + rank_col, + nproc_col); + + for (int i = 0; i < this->ncol_loc; ++i) + { + for (int j = 0; j < nrow; ++j) + { + EXPECT_NEAR(get_double(this->A[j + i * LDA]), get_double(this->A_ref[j + i * LDA]), 1e-10); + } + } +} + +int main(int argc, char** argv) +{ +#ifdef __MPI + MPI_Init(&argc, &argv); +#endif + ::testing::InitGoogleTest(&argc, argv); + int result = RUN_ALL_TESTS(); +#ifdef __MPI + MPI_Finalize(); +#endif + return result; +} \ No newline at end of file From 7c4410ef6e17d62956c333ab98c14f2140631418 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 21 Jan 2025 21:23:44 +0800 Subject: [PATCH 03/17] arrange the order in read_input --- source/driver.cpp | 2 +- source/module_base/global_file.cpp | 18 +- source/module_base/global_variable.cpp | 2 +- source/module_base/global_variable.h | 2 +- source/module_base/parallel_global.cpp | 12 +- source/module_base/parallel_global.h | 4 +- source/module_base/test/global_file_test.cpp | 4 +- source/module_cell/test/klist_test_para.cpp | 4 +- source/module_elecstate/elecstate_print.cpp | 16 +- .../module_charge/charge_init.cpp | 6 +- .../test_mpi/charge_mpi_test.cpp | 6 +- .../hamilt_pwdft/parallel_grid.cpp | 4 +- .../module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 4 +- source/module_io/read_input.cpp | 61 ++--- 
source/module_io/read_input.h | 7 +- .../module_io/read_input_item_elec_stru.cpp | 1 + source/module_io/read_input_item_exx_dftu.cpp | 1 + source/module_io/read_input_item_output.cpp | 1 + .../module_io/read_input_item_postprocess.cpp | 2 + source/module_io/read_input_item_system.cpp | 4 +- source/module_io/read_set_globalv.cpp | 210 +++++++++--------- .../module_io/test/read_wfc_to_rho_test.cpp | 2 +- .../module_io/test/write_istate_info_test.cpp | 4 +- source/module_parameter/system_parameter.h | 3 +- source/module_psi/psi.cpp | 2 +- source/module_psi/psi.h | 2 +- 26 files changed, 191 insertions(+), 193 deletions(-) diff --git a/source/driver.cpp b/source/driver.cpp index 250ac12707..285afa426e 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -153,7 +153,7 @@ void Driver::reading() PARAM.inp.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, diff --git a/source/module_base/global_file.cpp b/source/module_base/global_file.cpp index b912996f4e..4da5e94045 100644 --- a/source/module_base/global_file.cpp +++ b/source/module_base/global_file.cpp @@ -153,36 +153,32 @@ void ModuleBase::Global_File::make_dir_out( #endif } - std::stringstream ss,ss1; - // mohan add 2010-09-12 if(out_alllog) { - ss << "running_" << calculation << "_" << rank + 1; - open_log(GlobalV::ofs_running, ss.str(), calculation, restart); + open_log(GlobalV::ofs_running, PARAM.globalv.log_file, calculation, restart); #if defined(__CUDA) || defined(__ROCM) - open_log(GlobalV::ofs_device, "device" + std::to_string(rank), calculation, restart); + open_log(GlobalV::ofs_device, "device" + std::to_string(rank) + ".log", calculation, restart); #endif } else { if(rank==0) { - ss << "running_" << calculation; - open_log(GlobalV::ofs_running, ss.str(), calculation, restart); + open_log(GlobalV::ofs_running, PARAM.globalv.log_file, calculation, restart); #if defined(__CUDA) || 
defined(__ROCM) - open_log(GlobalV::ofs_device, "device", calculation, restart); + open_log(GlobalV::ofs_device, "device.log", calculation, restart); #endif } } if(rank==0) { - open_log(GlobalV::ofs_warning, "warning", calculation, restart); + open_log(GlobalV::ofs_warning, "warning.log", calculation, restart); } #ifdef GATHER_INFO - open_log(GlobalV::ofs_info, "math_info_" + std::to_string(rank), calculation, restart); + open_log(GlobalV::ofs_info, "math_info_" + std::to_string(rank) + ".log", calculation, restart); #endif return; @@ -206,7 +202,7 @@ void ModuleBase::Global_File::open_log(std::ofstream &ofs, const std::string &fn // PARAM.globalv.global_out_dir : (default dir to store "*.log" file) //---------------------------------------------------------- std::stringstream ss; - ss << PARAM.globalv.global_out_dir << fn << ".log"; + ss << PARAM.globalv.global_out_dir << fn; if(calculation == "md" && restart) { diff --git a/source/module_base/global_variable.cpp b/source/module_base/global_variable.cpp index 6957b9bc97..6feaafa701 100644 --- a/source/module_base/global_variable.cpp +++ b/source/module_base/global_variable.cpp @@ -24,7 +24,7 @@ int MY_STOGROUP = 0; int NPROC_IN_POOL = 1; ///< local number of process in a pool int NPROC_IN_STOGROUP = 1; int RANK_IN_POOL = 0; ///< global index of pool (count in process), my_rank in each pool -int RANK_IN_STOGROUP = 0; +int RANK_IN_BPGROUP = 0; int DRANK = -1; ///< mohan add 2012-01-13, must be -1, so we can recognize who ///< didn't in DIAG_WORLD int DSIZE = KPAR; diff --git a/source/module_base/global_variable.h b/source/module_base/global_variable.h index ae35025fef..1bb1068384 100644 --- a/source/module_base/global_variable.h +++ b/source/module_base/global_variable.h @@ -37,7 +37,7 @@ extern int MY_STOGROUP; extern int NPROC_IN_POOL; extern int NPROC_IN_STOGROUP; extern int RANK_IN_POOL; -extern int RANK_IN_STOGROUP; +extern int RANK_IN_BPGROUP; extern int DRANK; extern int DSIZE; extern int DCOLOR; diff --git 
a/source/module_base/parallel_global.cpp b/source/module_base/parallel_global.cpp index 4081fd7207..bf1a28b4c8 100644 --- a/source/module_base/parallel_global.cpp +++ b/source/module_base/parallel_global.cpp @@ -254,7 +254,7 @@ void Parallel_Global::init_pools(const int& NPROC, const int& BNDPAR, const int& KPAR, int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, + int& RANK_IN_BPGROUP, int& MY_STOGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, @@ -269,7 +269,7 @@ void Parallel_Global::init_pools(const int& NPROC, BNDPAR, KPAR, NPROC_IN_STOGROUP, - RANK_IN_STOGROUP, + RANK_IN_BPGROUP, MY_STOGROUP, NPROC_IN_POOL, RANK_IN_POOL, @@ -317,7 +317,7 @@ void Parallel_Global::divide_pools(const int& NPROC, const int& BNDPAR, const int& KPAR, int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, + int& RANK_IN_BPGROUP, int& MY_STOGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, @@ -359,15 +359,15 @@ void Parallel_Global::divide_pools(const int& NPROC, if(BNDPAR > 1) { NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; - RANK_IN_STOGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; + RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; MY_STOGROUP = bndpar_group.my_group; - MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_STOGROUP, &STO_WORLD); + MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_BPGROUP, &STO_WORLD); MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD); } else { NPROC_IN_STOGROUP = NPROC; - RANK_IN_STOGROUP = MY_RANK; + RANK_IN_BPGROUP = MY_RANK; MY_STOGROUP = 0; MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD); MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD); diff --git a/source/module_base/parallel_global.h b/source/module_base/parallel_global.h index 1fcf933f7b..59df6aadd3 100644 --- a/source/module_base/parallel_global.h +++ b/source/module_base/parallel_global.h @@ -49,7 +49,7 @@ void init_pools(const int& NPROC, const int& BNDPAR, const int& KPAR, int& 
NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, + int& RANK_IN_BPGROUP, int& MY_STOGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, @@ -60,7 +60,7 @@ void divide_pools(const int& NPROC, const int& BNDPAR, const int& KPAR, int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, + int& RANK_IN_BPGROUP, int& MY_STOGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, diff --git a/source/module_base/test/global_file_test.cpp b/source/module_base/test/global_file_test.cpp index adc7f30fc7..365782cd80 100644 --- a/source/module_base/test/global_file_test.cpp +++ b/source/module_base/test/global_file_test.cpp @@ -80,10 +80,10 @@ TEST_F(GlobalFile,mkdiratom) TEST_F(GlobalFile,openlog) { std::ofstream ofs; - ModuleBase::Global_File::open_log(ofs,"Si","md",true); + ModuleBase::Global_File::open_log(ofs,"Si.log","md",true); EXPECT_TRUE(ofs.is_open()); ofs.close(); - ModuleBase::Global_File::open_log(ofs,"Si","md",false); + ModuleBase::Global_File::open_log(ofs,"Si.log","md",false); EXPECT_TRUE(ofs.is_open()); ofs.close(); std::string sss = "Si.log"; diff --git a/source/module_cell/test/klist_test_para.cpp b/source/module_cell/test/klist_test_para.cpp index 3bd41c85a6..163f9b740b 100644 --- a/source/module_cell/test/klist_test_para.cpp +++ b/source/module_cell/test/klist_test_para.cpp @@ -230,7 +230,7 @@ TEST_F(KlistParaTest, Set) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, @@ -287,7 +287,7 @@ TEST_F(KlistParaTest, SetAfterVC) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, diff --git a/source/module_elecstate/elecstate_print.cpp b/source/module_elecstate/elecstate_print.cpp index fc08baea4e..4a0bf22536 100644 --- a/source/module_elecstate/elecstate_print.cpp +++ b/source/module_elecstate/elecstate_print.cpp @@ 
-174,23 +174,11 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) { ModuleBase::WARNING_QUIT("print_eigenvalue", "Eigenvalues are too large!"); } - std::stringstream ss; - if(PARAM.inp.out_alllog) - { - ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << "_" << GlobalV::MY_RANK + 1 << ".log"; - } - else - { - ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << ".log"; - } - std::string filename = ss.str(); + + std::string filename = PARAM.globalv.global_out_dir + PARAM.globalv.log_file; std::vector ngk_tot = this->klist->ngk; #ifdef __MPI - if(!PARAM.inp.out_alllog) - { - Parallel_Common::bcast_string(filename); - } MPI_Allreduce(MPI_IN_PLACE, ngk_tot.data(), nks, MPI_INT, MPI_SUM, POOL_WORLD); #endif diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index f6a0b400d4..bec7d05bf4 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -61,7 +61,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, std::stringstream ssc; ssc << PARAM.globalv.global_readin_dir << "SPIN" << is + 1 << "_CHG.cube"; if (ModuleIO::read_vdata_palgrid(pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->rho[is], @@ -150,7 +150,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, // mohan update 2012-02-10, sunliang update 2023-03-09 if (ModuleIO::read_vdata_palgrid( pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? 
GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->kin_r[is], @@ -223,7 +223,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, std::stringstream ssc; ssc << PARAM.globalv.global_readin_dir << "SPIN" << is + 1 << "_CHG.cube"; if (ModuleIO::read_vdata_palgrid(pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->rho[is], diff --git a/source/module_elecstate/test_mpi/charge_mpi_test.cpp b/source/module_elecstate/test_mpi/charge_mpi_test.cpp index 040f5953d6..92e3801f2a 100644 --- a/source/module_elecstate/test_mpi/charge_mpi_test.cpp +++ b/source/module_elecstate/test_mpi/charge_mpi_test.cpp @@ -71,7 +71,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, @@ -117,7 +117,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, @@ -172,7 +172,7 @@ TEST_F(ChargeMpiTest, rho_mpi) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, diff --git a/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp b/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp index 5d6fe41764..1bb146fd0d 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp @@ -293,7 +293,7 @@ void Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh { // case 1: the first part of rho in processor 0. 
// and send zpeice to to other pools. - if(proc == 0 && GlobalV::RANK_IN_STOGROUP ==0) + if(proc == 0 && GlobalV::RANK_IN_BPGROUP ==0) { for(int ir=0; irbcastfuncs) - { - bcastfunc(param); - } } void ReadInput::create_directory(const Parameter& param) @@ -283,22 +305,6 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) resetvalue_item->reset_value(*resetvalue_item, param); } } - this->set_globalv(param); - - // 3) count the number of atom types from STRU file - if (this->check_ntype_flag) - { - check_ntype(param.globalv.global_in_stru, param.input.ntype); - } - - // 4) check the value of the parameters - for (auto& input_item: this->input_lists) - { - Input_Item* checkvalue_item = &(input_item.second); - if (checkvalue_item->check_value != nullptr) { - checkvalue_item->check_value(*checkvalue_item, param); - } - } } void ReadInput::write_txt_input(const Parameter& param, const std::string& filename) @@ -442,9 +448,10 @@ void ReadInput::check_ntype(const std::string& fn, int& param_ntype) { ModuleBase::WARNING_QUIT("ReadInput", "ntype should be greater than 0."); } - else { + else + { GlobalV::ofs_running << " 'ntype' is automatically set to " << param_ntype << std::endl; -} + } } int ReadInput::current_md_step(const std::string& file_dir) diff --git a/source/module_io/read_input.h b/source/module_io/read_input.h index a39d043fe7..6166985a0b 100644 --- a/source/module_io/read_input.h +++ b/source/module_io/read_input.h @@ -83,10 +83,11 @@ class ReadInput * @param item input_item */ void add_item(const Input_Item& item); - //set globalv parameters + // INPUT and STRU need to refer to each other in ABACUS, + // so it is necessary to obtain the file paths related to all inputs + void set_global_dir(Parameter& para); + // set globalv parameters void set_globalv(Parameter& para); - // add bcast functions for global values - void set_globalv_bcast(); // system items void item_system(); // items for electronic structure diff --git 
a/source/module_io/read_input_item_elec_stru.cpp b/source/module_io/read_input_item_elec_stru.cpp index 089f842de5..79fb24584b 100644 --- a/source/module_io/read_input_item_elec_stru.cpp +++ b/source/module_io/read_input_item_elec_stru.cpp @@ -248,6 +248,7 @@ void ReadInput::item_elec_stru() } }; sync_double(input.nupdown); + add_bool_bcast(sys.two_fermi); this->add_item(item); } { diff --git a/source/module_io/read_input_item_exx_dftu.cpp b/source/module_io/read_input_item_exx_dftu.cpp index dc7c6a6025..b74611bd79 100644 --- a/source/module_io/read_input_item_exx_dftu.cpp +++ b/source/module_io/read_input_item_exx_dftu.cpp @@ -398,6 +398,7 @@ void ReadInput::item_dftu() } }; sync_double(input.uramping_eV); + add_double_bcast(sys.uramping); this->add_item(item); } { diff --git a/source/module_io/read_input_item_output.cpp b/source/module_io/read_input_item_output.cpp index 8e4a59d046..e2c9dfb068 100644 --- a/source/module_io/read_input_item_output.cpp +++ b/source/module_io/read_input_item_output.cpp @@ -187,6 +187,7 @@ void ReadInput::item_output() } }; sync_string(input.out_level); + add_bool_bcast(sys.out_md_control); this->add_item(item); } { diff --git a/source/module_io/read_input_item_postprocess.cpp b/source/module_io/read_input_item_postprocess.cpp index 1087e24d95..46c1e18b4f 100644 --- a/source/module_io/read_input_item_postprocess.cpp +++ b/source/module_io/read_input_item_postprocess.cpp @@ -15,6 +15,7 @@ void ReadInput::item_postprocess() para.sys.dos_setemin = true; }; sync_double(input.dos_emin_ev); + add_bool_bcast(sys.dos_setemin); this->add_item(item); } { @@ -25,6 +26,7 @@ void ReadInput::item_postprocess() para.sys.dos_setemax = true; }; sync_double(input.dos_emax_ev); + add_bool_bcast(sys.dos_setemax); this->add_item(item); } { diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index ad00b0a514..59170608e0 100644 --- a/source/module_io/read_input_item_system.cpp +++ 
b/source/module_io/read_input_item_system.cpp @@ -253,6 +253,7 @@ void ReadInput::item_system() } }; this->add_item(item); + add_int_bcast(sys.kpar_lcao); } { Input_Item item("bndpar"); @@ -335,7 +336,6 @@ void ReadInput::item_system() item.annotation = "number of points along x axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.nx = intvalue; - para.sys.ncx = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.nx != 0) @@ -351,7 +351,6 @@ void ReadInput::item_system() item.annotation = "number of points along y axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.ny = intvalue; - para.sys.ncy = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.ny != 0) @@ -367,7 +366,6 @@ void ReadInput::item_system() item.annotation = "number of points along z axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.nz = intvalue; - para.sys.ncz = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.nz != 0) diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp index e80d1435fb..70f7c77a4b 100755 --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -1,132 +1,134 @@ #include "module_base/global_variable.h" -#include "module_base/tool_quit.h" #include "module_base/module_device/device.h" +#include "module_base/tool_quit.h" #include "module_parameter/parameter.h" #include "read_input.h" #include "read_input_tool.h" namespace ModuleIO { -void ReadInput::set_globalv(Parameter& para) +/// @note Here para.inp has been synchronized of all ranks. +/// Only para.inp in rank 0 is right. 
+/// So we need to broadcast the results to all ranks. +void ReadInput::set_global_dir(Parameter& para) { - { - /// caculate the global output directory - const std::string prefix = "OUT."; - para.sys.global_out_dir = prefix + para.inp.suffix + "/"; - para.sys.global_out_dir = to_dir(para.sys.global_out_dir); + /// caculate the global output directory + const std::string prefix = "OUT."; + para.sys.global_out_dir = prefix + para.inp.suffix + "/"; + para.sys.global_out_dir = to_dir(para.sys.global_out_dir); - /// get the global output directory - para.sys.global_stru_dir = para.globalv.global_out_dir + "STRU/"; - para.sys.global_stru_dir = to_dir(para.sys.global_stru_dir); + /// get the global output directory + para.sys.global_stru_dir = para.globalv.global_out_dir + "STRU/"; + para.sys.global_stru_dir = to_dir(para.sys.global_stru_dir); - /// get the global output directory - para.sys.global_matrix_dir = para.globalv.global_out_dir + "matrix/"; - para.sys.global_matrix_dir = to_dir(para.sys.global_matrix_dir); + /// get the global output directory + para.sys.global_matrix_dir = para.globalv.global_out_dir + "matrix/"; + para.sys.global_matrix_dir = to_dir(para.sys.global_matrix_dir); - /// get the global readin directory - para.sys.global_readin_dir = para.inp.read_file_dir + '/'; - para.sys.global_readin_dir = to_dir(para.sys.global_readin_dir); + /// get the global readin directory + para.sys.global_readin_dir = para.inp.read_file_dir + '/'; + para.sys.global_readin_dir = to_dir(para.sys.global_readin_dir); - /// get the stru file for md restart case - if (para.inp.calculation == "md" && para.mdp.md_restart) - { - int istep = current_md_step(para.sys.global_readin_dir); + /// get the stru file for md restart case + if (para.inp.calculation == "md" && para.mdp.md_restart) + { + int istep = current_md_step(para.sys.global_readin_dir); - if (para.inp.read_file_dir == to_dir("OUT." 
+ para.input.suffix)) - { - para.sys.global_in_stru = para.sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); - } - else - { - para.sys.global_in_stru = para.inp.read_file_dir + "STRU_MD_" + std::to_string(istep); - } + if (para.inp.read_file_dir == to_dir("OUT." + para.inp.suffix)) + { + para.sys.global_in_stru = para.sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); } else { - para.sys.global_in_stru = para.inp.stru_file; + para.sys.global_in_stru = para.inp.read_file_dir + "STRU_MD_" + std::to_string(istep); } + } + else + { + para.sys.global_in_stru = para.inp.stru_file; + } - /// caculate the gamma_only_pw and gamma_only_local - if (para.input.gamma_only) - { - para.sys.gamma_only_local = true; - } - if (para.sys.gamma_only_local) + // set the global log file + bool out_alllog = para.inp.out_alllog; +#ifdef __MPI + // because log_file is different for each rank, so we need to bcast the out_alllog + Parallel_Common::bcast_bool(out_alllog); +#endif + if (out_alllog) + { + PARAM.sys.log_file = "running_" + PARAM.inp.calculation + "_" + std::to_string(PARAM.sys.myrank + 1) + ".log"; + } + else + { + PARAM.sys.log_file = "running_" + PARAM.inp.calculation + ".log"; + } +#ifdef __MPI + Parallel_Common::bcast_string(para.sys.global_in_card); + Parallel_Common::bcast_string(para.sys.global_out_dir); + Parallel_Common::bcast_string(para.sys.global_readin_dir); + Parallel_Common::bcast_string(para.sys.global_stru_dir); + Parallel_Common::bcast_string(para.sys.global_matrix_dir); + Parallel_Common::bcast_string(para.sys.global_in_stru); +#endif +} + +/// @note Here para.inp has been synchronized of all ranks. +/// All para.inp have the same value. 
+void ReadInput::set_globalv(Parameter& para) +{ + /// caculate the gamma_only_pw and gamma_only_local + if (para.inp.gamma_only) + { + para.sys.gamma_only_local = true; + } + if (para.sys.gamma_only_local) + { + if (para.inp.esolver_type == "tddft") { - if (para.inp.esolver_type == "tddft") - { - GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; - para.sys.gamma_only_local = false; - } + GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; + para.sys.gamma_only_local = false; } - /// set deepks_setorb - if (para.input.deepks_scf || para.input.deepks_out_labels) + } + /// set deepks_setorb + if (para.inp.deepks_scf || para.inp.deepks_out_labels) + { + para.sys.deepks_setorb = true; + } + /// set the noncolin and lspinorb from nspin + switch (para.inp.nspin) + { + case 4: + if (para.inp.noncolin) { - para.sys.deepks_setorb = true; + para.sys.domag = true; + para.sys.domag_z = false; } - /// set the noncolin and lspinorb from nspin - switch (para.input.nspin) + else { - case 4: - if (para.input.noncolin) - { - para.sys.domag = true; - para.sys.domag_z = false; - } - else - { - para.sys.domag = false; - para.sys.domag_z = true; - } - para.sys.npol = 2; - break; - case 2: - case 1: para.sys.domag = false; - para.sys.domag_z = false; - para.sys.npol = 1; - default: - break; + para.sys.domag_z = true; } - - para.sys.nqx=static_cast((sqrt(para.inp.ecutwfc) / para.sys.dq + 4.0) * para.inp.cell_factor); - para.sys.nqxq=static_cast((sqrt(para.inp.ecutrho) / para.sys.dq + 4.0) * para.inp.cell_factor); + para.sys.npol = 2; + break; + case 2: + case 1: + para.sys.domag = false; + para.sys.domag_z = false; + para.sys.npol = 1; + default: + break; + } + /// set ncx,ncy,ncz + para.sys.ncx = para.inp.nx; + para.sys.ncy = para.inp.ny; + para.sys.ncz = para.inp.nz; +#ifdef __MPI + Parallel_Common::bcast_bool(para.sys.double_grid); +#endif + // calculate the number of nbands_local + para.sys.nbands_l = 
para.inp.nbands / para.inp.bndpar; + if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar) + { + para.sys.nbands_l++; } -} - -void ReadInput::set_globalv_bcast() -{ - // add_int_bcast(sys.myrank); - add_bool_bcast(sys.two_fermi); - add_bool_bcast(sys.use_uspp); - add_bool_bcast(sys.dos_setemin); - add_bool_bcast(sys.dos_setemax); - - add_int_bcast(sys.ncx); - add_int_bcast(sys.ncy); - add_int_bcast(sys.ncz); - add_bool_bcast(sys.out_md_control); - add_bool_bcast(sys.rpa_setorb); - - add_bool_bcast(sys.gamma_only_pw); - add_bool_bcast(sys.gamma_only_local); - - add_string_bcast(sys.global_in_card); - add_string_bcast(sys.global_out_dir); - add_string_bcast(sys.global_readin_dir); - add_string_bcast(sys.global_stru_dir); - add_string_bcast(sys.global_matrix_dir); - - add_bool_bcast(sys.deepks_setorb); - - add_bool_bcast(sys.domag); - add_bool_bcast(sys.domag_z); - add_int_bcast(sys.npol); - - add_bool_bcast(sys.double_grid); - add_double_bcast(sys.uramping); - add_double_bcast(sys.dq); - add_int_bcast(sys.nqx); - add_int_bcast(sys.nqxq); - add_int_bcast(sys.kpar_lcao); } } // namespace ModuleIO diff --git a/source/module_io/test/read_wfc_to_rho_test.cpp b/source/module_io/test/read_wfc_to_rho_test.cpp index 7717b1bbc0..0760c6b319 100644 --- a/source/module_io/test/read_wfc_to_rho_test.cpp +++ b/source/module_io/test/read_wfc_to_rho_test.cpp @@ -285,7 +285,7 @@ int main(int argc, char** argv) PARAM.inp.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index 0afdad3608..435de37722 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -54,7 +54,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - 
GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, @@ -104,7 +104,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) PARAM.input.bndpar, GlobalV::KPAR, GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, + GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, diff --git a/source/module_parameter/system_parameter.h b/source/module_parameter/system_parameter.h index 1401d015b0..d1258ffc6b 100755 --- a/source/module_parameter/system_parameter.h +++ b/source/module_parameter/system_parameter.h @@ -33,7 +33,6 @@ struct System_para int ncx = 0, ncy = 0, ncz = 0; ///< three dimension of FFT charge/grid, same as "nx,ny,nz" bool out_md_control = false; ///< true if "out_level" is set - bool rpa_setorb = false; ///< true if "rpa" is set bool gamma_only_pw = false; ///< true if "gamma_only" is true and "basis_type" is "pw" ///< for plane wave basis. bool gamma_only_local = false; ///< true if "gamma_only" is true and "lcao" @@ -44,6 +43,7 @@ struct System_para std::string global_readin_dir = ""; ///< global readin directory std::string global_stru_dir = ""; ///< global structure directory std::string global_matrix_dir = ""; ///< global matrix directory + std::string log_file = "log"; ///< log file name bool deepks_setorb = false; ///< true if "deepks" is set int npol = 1; ///< number of polarization @@ -54,5 +54,6 @@ struct System_para double uramping = -10.0 / 13.6; /// U-Ramping method (Ry) std::vector hubbard_u = {}; ///< Hubbard Coulomb interaction parameter U (Ry) int kpar_lcao = 1; ///< global number of pools for LCAO diagonalization only + int nbands_l = 0; ///< number of bands of each band parallel calculation }; #endif \ No newline at end of file diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 78eb202766..07452554b1 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -296,7 +296,7 @@ const int& 
Psi::get_current_ngk() const } template -const int Psi::get_npol() const +int Psi::get_npol() const { if (PARAM.inp.nspin == 4) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 75e13433ea..9b427c070d 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -134,7 +134,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; - const int get_npol() const; + int get_npol() const; private: T* psi = nullptr; // avoid using C++ STL From c5dc8cea0fdd33a8a0a580f313ed0fc886b49a0f Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Wed, 22 Jan 2025 09:22:29 +0800 Subject: [PATCH 04/17] change name --- source/driver.cpp | 4 +- source/module_base/global_variable.cpp | 4 +- source/module_base/global_variable.h | 4 +- source/module_base/parallel_global.cpp | 24 ++--- source/module_base/parallel_global.h | 8 +- source/module_cell/cal_atoms_info.h | 10 ++ source/module_cell/test/klist_test_para.cpp | 4 +- source/module_elecstate/elecstate_print.cpp | 2 +- source/module_elecstate/elecstate_pw_sdft.cpp | 2 +- .../test_mpi/charge_mpi_test.cpp | 6 +- source/module_esolver/esolver_ks.cpp | 4 +- source/module_esolver/esolver_sdft_pw.cpp | 2 +- .../hamilt_pwdft/parallel_grid.cpp | 2 +- .../hamilt_stodft/sto_elecond.cpp | 4 +- .../hamilt_stodft/sto_forces.cpp | 2 +- .../hamilt_stodft/sto_stress_pw.cpp | 4 +- .../module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 12 +-- source/module_hsolver/hsolver_pw_sdft.cpp | 4 +- .../module_hsolver/test/diago_lapack_test.cpp | 10 +- source/module_io/read_input.cpp | 4 +- source/module_io/read_input.h | 11 ++- source/module_io/read_set_globalv.cpp | 94 +++++++++---------- .../module_io/test/read_wfc_to_rho_test.cpp | 2 +- .../module_io/test/write_istate_info_test.cpp | 4 +- source/module_io/write_cube.cpp | 2 +- source/module_io/write_istate_info.cpp | 2 +- source/module_parameter/parameter.cpp | 10 -- source/module_parameter/parameter.h | 6 -- 
source/module_psi/psi_init.cpp | 6 +- 29 files changed, 120 insertions(+), 133 deletions(-) diff --git a/source/driver.cpp b/source/driver.cpp index 285afa426e..efd76ca276 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -152,9 +152,9 @@ void Driver::reading() GlobalV::MY_RANK, PARAM.inp.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_base/global_variable.cpp b/source/module_base/global_variable.cpp index 6feaafa701..71f41c1b4b 100644 --- a/source/module_base/global_variable.cpp +++ b/source/module_base/global_variable.cpp @@ -20,9 +20,9 @@ int NPROC = 1; ///< global number of process int KPAR = 1; ///< global number of pools int MY_RANK = 0; ///< global index of process int MY_POOL = 0; ///< global index of pool (count in pool) -int MY_STOGROUP = 0; +int MY_BNDGROUP = 0; int NPROC_IN_POOL = 1; ///< local number of process in a pool -int NPROC_IN_STOGROUP = 1; +int NPROC_IN_BNDGROUP = 1; int RANK_IN_POOL = 0; ///< global index of pool (count in process), my_rank in each pool int RANK_IN_BPGROUP = 0; int DRANK = -1; ///< mohan add 2012-01-13, must be -1, so we can recognize who diff --git a/source/module_base/global_variable.h b/source/module_base/global_variable.h index 1bb1068384..7014912457 100644 --- a/source/module_base/global_variable.h +++ b/source/module_base/global_variable.h @@ -33,9 +33,9 @@ extern int NPROC; extern int KPAR; extern int MY_RANK; extern int MY_POOL; -extern int MY_STOGROUP; +extern int MY_BNDGROUP; extern int NPROC_IN_POOL; -extern int NPROC_IN_STOGROUP; +extern int NPROC_IN_BNDGROUP; extern int RANK_IN_POOL; extern int RANK_IN_BPGROUP; extern int DRANK; diff --git a/source/module_base/parallel_global.cpp b/source/module_base/parallel_global.cpp index bf1a28b4c8..b9d87f2700 100644 --- a/source/module_base/parallel_global.cpp +++ 
b/source/module_base/parallel_global.cpp @@ -253,9 +253,9 @@ void Parallel_Global::init_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, + int& NPROC_IN_BNDGROUP, int& RANK_IN_BPGROUP, - int& MY_STOGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL) @@ -268,9 +268,9 @@ void Parallel_Global::init_pools(const int& NPROC, MY_RANK, BNDPAR, KPAR, - NPROC_IN_STOGROUP, + NPROC_IN_BNDGROUP, RANK_IN_BPGROUP, - MY_STOGROUP, + MY_BNDGROUP, NPROC_IN_POOL, RANK_IN_POOL, MY_POOL); @@ -316,16 +316,16 @@ void Parallel_Global::divide_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, + int& NPROC_IN_BNDGROUP, int& RANK_IN_BPGROUP, - int& MY_STOGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL) { // note: the order of k-point parallelization and band parallelization is important // The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL - // and MY_STOGROUP will be the same as well. + // and MY_BNDGROUP will be the same as well. 
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0) { std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups (" @@ -358,17 +358,17 @@ void Parallel_Global::divide_pools(const int& NPROC, if(BNDPAR > 1) { - NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; + NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; - MY_STOGROUP = bndpar_group.my_group; - MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_BPGROUP, &STO_WORLD); + MY_BNDGROUP = bndpar_group.my_group; + MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &STO_WORLD); MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD); } else { - NPROC_IN_STOGROUP = NPROC; + NPROC_IN_BNDGROUP = NPROC; RANK_IN_BPGROUP = MY_RANK; - MY_STOGROUP = 0; + MY_BNDGROUP = 0; MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD); MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD); } diff --git a/source/module_base/parallel_global.h b/source/module_base/parallel_global.h index 59df6aadd3..71e933a33e 100644 --- a/source/module_base/parallel_global.h +++ b/source/module_base/parallel_global.h @@ -48,9 +48,9 @@ void init_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, + int& NPROC_IN_BNDGROUP, int& RANK_IN_BPGROUP, - int& MY_STOGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL); @@ -59,9 +59,9 @@ void divide_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, + int& NPROC_IN_BNDGROUP, int& RANK_IN_BPGROUP, - int& MY_STOGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL); diff --git a/source/module_cell/cal_atoms_info.h b/source/module_cell/cal_atoms_info.h index 3b59d57540..66b759204d 100644 --- a/source/module_cell/cal_atoms_info.h +++ 
b/source/module_cell/cal_atoms_info.h @@ -68,6 +68,16 @@ class CalAtomsInfo nelec_spin[1] = (para.inp.nelec - para.inp.nupdown ) / 2.0; } elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands); + // calculate the number of nbands_local + para.sys.nbands_l = para.inp.nbands; + if (inp.ks_solver == "bpcg") // only bpcg support band parallel + { + para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; + if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar) + { + para.sys.nbands_l++; + } + } return; } }; diff --git a/source/module_cell/test/klist_test_para.cpp b/source/module_cell/test/klist_test_para.cpp index 163f9b740b..531e5eb64b 100644 --- a/source/module_cell/test/klist_test_para.cpp +++ b/source/module_cell/test/klist_test_para.cpp @@ -229,7 +229,7 @@ TEST_F(KlistParaTest, Set) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, @@ -286,7 +286,7 @@ TEST_F(KlistParaTest, SetAfterVC) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, diff --git a/source/module_elecstate/elecstate_print.cpp b/source/module_elecstate/elecstate_print.cpp index 4a0bf22536..92c49427cc 100644 --- a/source/module_elecstate/elecstate_print.cpp +++ b/source/module_elecstate/elecstate_print.cpp @@ -205,7 +205,7 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) #ifdef __MPI MPI_Barrier(MPI_COMM_WORLD); #endif - bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_STOGROUP == 0); + bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_BNDGROUP == 0); if (GlobalV::MY_POOL == ip && ip_flag) { const int start_ik = nks_np * is; diff --git a/source/module_elecstate/elecstate_pw_sdft.cpp 
b/source/module_elecstate/elecstate_pw_sdft.cpp index bef6277adb..42952a7670 100644 --- a/source/module_elecstate/elecstate_pw_sdft.cpp +++ b/source/module_elecstate/elecstate_pw_sdft.cpp @@ -19,7 +19,7 @@ void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) setmem_var_op()(this->rho[is], 0, this->charge->nrxx); } - if (GlobalV::MY_STOGROUP == 0) + if (GlobalV::MY_BNDGROUP == 0 || PARAM.inp.ks_solver == "bpcg") { for (int ik = 0; ik < psi.get_nk(); ++ik) { diff --git a/source/module_elecstate/test_mpi/charge_mpi_test.cpp b/source/module_elecstate/test_mpi/charge_mpi_test.cpp index 92e3801f2a..6ef705a93b 100644 --- a/source/module_elecstate/test_mpi/charge_mpi_test.cpp +++ b/source/module_elecstate/test_mpi/charge_mpi_test.cpp @@ -70,7 +70,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, @@ -116,7 +116,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, @@ -171,7 +171,7 @@ TEST_F(ChargeMpiTest, rho_mpi) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 8de7f56a05..6466f3be63 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -361,7 +361,7 @@ void ESolver_KS::hamilt2density(UnitCell& ucell, const int istep, con // Maybe in the future, density and wavefunctions should use different // parallel algorithms, in which they do not occupy all processors, for // example wavefunctions uses 20 processors while density uses 10. 
- if (GlobalV::MY_STOGROUP == 0) + if (GlobalV::MY_BNDGROUP == 0) { // double drho = this->estate.caldr2(); // EState should be used after it is constructed. @@ -550,7 +550,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i this->pelec->charge->rho, this->pelec->nelec_spin.data()); - if (GlobalV::MY_STOGROUP == 0) + if (GlobalV::MY_BNDGROUP == 0) { // mixing will restart at this->p_chgmix->mixing_restart steps if (drho <= PARAM.inp.mixing_restart && PARAM.inp.mixing_restart > 0.0 diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index f5f9292522..c4dd8d821b 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -198,7 +198,7 @@ void ESolver_SDFT_PW::hamilt2density_single(UnitCell& ucell, int iste // set_diagethr need it this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne; - if (GlobalV::MY_STOGROUP == 0) + if (GlobalV::MY_BNDGROUP == 0) { Symmetry_rho srho; for (int is = 0; is < PARAM.inp.nspin; is++) diff --git a/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp b/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp index 1bb146fd0d..b608a11fee 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp @@ -84,7 +84,7 @@ void Parallel_Grid::init( this->nproc_in_pool = new int[GlobalV::KPAR]; int nprocgroup; - if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_STOGROUP; + if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_BNDGROUP; } else { nprocgroup = GlobalV::NPROC; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp index c0e5d61d3f..f208c2be9b 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp @@ -620,7 +620,7 @@ void Sto_EleCond::sKG(const int& smear_type, } // Parallel for 
bands int allbands_ks = this->nbands_ks - cutib0; - parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_STOGROUP); + parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_BNDGROUP); int perbands_ks = paraks.num_per; int ib0_ks = paraks.start; ib0_ks += this->nbands_ks - allbands_ks; @@ -653,7 +653,7 @@ void Sto_EleCond::sKG(const int& smear_type, //----------------------------------------------------------- // ks conductivity //----------------------------------------------------------- - if (GlobalV::MY_STOGROUP == 0 && allbands_ks > 0) + if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0) { jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data()); } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp index 6684332781..bf717c062e 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp @@ -210,7 +210,7 @@ void Sto_Forces::cal_sto_force_nl( const int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in.get_nbands(); - if (GlobalV::MY_STOGROUP != 0) + if (GlobalV::MY_BNDGROUP != 0) { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp index 62a4c16779..45db859c80 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp @@ -111,7 +111,7 @@ void Sto_Stress_PW::sto_stress_kin(ModuleBase::matrix& sigma, ModuleBase::timer::tick("Sto_Stress_PW", "stress_kin"); int nksbands = psi.get_nbands(); - if (GlobalV::MY_STOGROUP != 0) + if (GlobalV::MY_BNDGROUP != 0) { nksbands = 0; } @@ -160,7 +160,7 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in.get_nbands(); - if 
(GlobalV::MY_STOGROUP != 0) + if (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg") { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 9ebdc62f04..5fdb9f59b3 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -72,7 +72,7 @@ void Stochastic_WF::init_sto_orbitals(const int seed_in) } else { - srand((unsigned)std::abs(seed_in) + (GlobalV::MY_STOGROUP * GlobalV::NPROC_IN_STOGROUP + GlobalV::RANK_IN_BPGROUP) * 10000); + srand((unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000); } this->allocate_chi0(); @@ -88,12 +88,12 @@ void Stochastic_WF::allocate_chi0() // former processor calculate more bands if (firstrankmore) { - igroup = GlobalV::MY_STOGROUP; + igroup = GlobalV::MY_BNDGROUP; } // latter processor calculate more bands else { - igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1; + igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1; } const int nchi = PARAM.inp.nbands_sto; const int npwx = this->npwx; @@ -172,12 +172,12 @@ void Stochastic_WF::init_com_orbitals() // former processor calculate more bands if (firstrankmore) { - igroup = GlobalV::MY_STOGROUP; + igroup = GlobalV::MY_BNDGROUP; } // latter processor calculate more bands else { - igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1; + igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1; } const int ngroup = PARAM.inp.bndpar; const int n_in_pool = GlobalV::NPROC_IN_POOL; @@ -318,7 +318,7 @@ void Stochastic_WF::init_sto_orbitals_Ecut(const int seed_in, MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD); #endif int ichi_start = 0; - for (int i = 0; i < GlobalV::MY_STOGROUP; ++i) + for (int i = 0; i < GlobalV::MY_BNDGROUP; ++i) { ichi_start += nrecv[i]; } diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 
d03b37b848..0b964a9388 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -46,7 +46,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, { ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS"); pHamilt->updateHk(ik); - if (nbands > 0 && GlobalV::MY_STOGROUP == 0) + if (nbands > 0 && GlobalV::MY_BNDGROUP == 0) { /// update psi pointer for each k point psi.fix_k(ik); @@ -89,7 +89,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, // calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS elecstate::ElecStatePW* pes_pw = static_cast*>(pes); - if (GlobalV::MY_STOGROUP == 0) + if (GlobalV::MY_BNDGROUP == 0) { pes_pw->calEBand(); } diff --git a/source/module_hsolver/test/diago_lapack_test.cpp b/source/module_hsolver/test/diago_lapack_test.cpp index cc181c5d28..cb3eb3f403 100644 --- a/source/module_hsolver/test/diago_lapack_test.cpp +++ b/source/module_hsolver/test/diago_lapack_test.cpp @@ -1,8 +1,8 @@ // Author: Zhang Xiaoyang // A modified version of diago_lcao_test.cpp -// #define private public +#define private public #include "module_parameter/parameter.h" -// #undef private +#undef private // Remove some useless functions and dependencies. Serialized the full code // and refactored some function. @@ -187,10 +187,8 @@ class DiagoLapackPrepare void set_env() { - // PARAM.sys.nlocal = nlocal; - PARAM.set_sys_nlocal(nlocal); - // PARAM.input.nbands = nbands; - PARAM.set_input_nbands(nbands); + PARAM.sys.nlocal = nlocal; + PARAM.input.nbands = nbands; } void diago() diff --git a/source/module_io/read_input.cpp b/source/module_io/read_input.cpp index 51ebc25b86..89c3cbc721 100644 --- a/source/module_io/read_input.cpp +++ b/source/module_io/read_input.cpp @@ -129,7 +129,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in // 3. 
check the number of atom types from STRU file // set the global directories - this->set_global_dir(param); + this->set_global_dir(param.inp, param.sys); if (this->check_ntype_flag && this->rank == 0) { check_ntype(param.globalv.global_in_stru, param.input.ntype); @@ -143,7 +143,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in } // 5. set the globalv parameters, some parameters in different processes are different. e.g. rank - this->set_globalv(param); + this->set_globalv(param.inp, param.sys); if (this->check_mode) { diff --git a/source/module_io/read_input.h b/source/module_io/read_input.h index 6166985a0b..09f397e474 100644 --- a/source/module_io/read_input.h +++ b/source/module_io/read_input.h @@ -83,11 +83,12 @@ class ReadInput * @param item input_item */ void add_item(const Input_Item& item); - // INPUT and STRU need to refer to each other in ABACUS, - // so it is necessary to obtain the file paths related to all inputs - void set_global_dir(Parameter& para); - // set globalv parameters - void set_globalv(Parameter& para); + /// @brief set System_para according to input parameters + /// INPUT and STRU need to refer to each other in ABACUS, + /// so it is necessary to obtain the file paths related to all inputs + void set_global_dir(const Input_para& inp, System_para& sys); + // set System_para according to input parameters + void set_globalv(const Input_para& inp, System_para& sys); // system items void item_system(); // items for electronic structure diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp index 70f7c77a4b..0c5f542019 100755 --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -9,46 +9,46 @@ namespace ModuleIO /// @note Here para.inp has been synchronized of all ranks. /// Only para.inp in rank 0 is right. /// So we need to broadcast the results to all ranks. 
-void ReadInput::set_global_dir(Parameter& para) +void ReadInput::set_global_dir(const Input_para& inp, System_para& sys) { /// caculate the global output directory const std::string prefix = "OUT."; - para.sys.global_out_dir = prefix + para.inp.suffix + "/"; - para.sys.global_out_dir = to_dir(para.sys.global_out_dir); + sys.global_out_dir = prefix + inp.suffix + "/"; + sys.global_out_dir = to_dir(sys.global_out_dir); /// get the global output directory - para.sys.global_stru_dir = para.globalv.global_out_dir + "STRU/"; - para.sys.global_stru_dir = to_dir(para.sys.global_stru_dir); + sys.global_stru_dir = sys.global_out_dir + "STRU/"; + sys.global_stru_dir = to_dir(sys.global_stru_dir); /// get the global output directory - para.sys.global_matrix_dir = para.globalv.global_out_dir + "matrix/"; - para.sys.global_matrix_dir = to_dir(para.sys.global_matrix_dir); + sys.global_matrix_dir = sys.global_out_dir + "matrix/"; + sys.global_matrix_dir = to_dir(sys.global_matrix_dir); /// get the global readin directory - para.sys.global_readin_dir = para.inp.read_file_dir + '/'; - para.sys.global_readin_dir = to_dir(para.sys.global_readin_dir); + sys.global_readin_dir = inp.read_file_dir + '/'; + sys.global_readin_dir = to_dir(sys.global_readin_dir); /// get the stru file for md restart case - if (para.inp.calculation == "md" && para.mdp.md_restart) + if (inp.calculation == "md" && inp.mdp.md_restart) { - int istep = current_md_step(para.sys.global_readin_dir); + int istep = current_md_step(sys.global_readin_dir); - if (para.inp.read_file_dir == to_dir("OUT." + para.inp.suffix)) + if (inp.read_file_dir == to_dir("OUT." 
+ inp.suffix)) { - para.sys.global_in_stru = para.sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); + sys.global_in_stru = sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); } else { - para.sys.global_in_stru = para.inp.read_file_dir + "STRU_MD_" + std::to_string(istep); + sys.global_in_stru = inp.read_file_dir + "STRU_MD_" + std::to_string(istep); } } else { - para.sys.global_in_stru = para.inp.stru_file; + sys.global_in_stru = inp.stru_file; } // set the global log file - bool out_alllog = para.inp.out_alllog; + bool out_alllog = inp.out_alllog; #ifdef __MPI // because log_file is different for each rank, so we need to bcast the out_alllog Parallel_Common::bcast_bool(out_alllog); @@ -62,73 +62,67 @@ void ReadInput::set_global_dir(Parameter& para) PARAM.sys.log_file = "running_" + PARAM.inp.calculation + ".log"; } #ifdef __MPI - Parallel_Common::bcast_string(para.sys.global_in_card); - Parallel_Common::bcast_string(para.sys.global_out_dir); - Parallel_Common::bcast_string(para.sys.global_readin_dir); - Parallel_Common::bcast_string(para.sys.global_stru_dir); - Parallel_Common::bcast_string(para.sys.global_matrix_dir); - Parallel_Common::bcast_string(para.sys.global_in_stru); + Parallel_Common::bcast_string(sys.global_in_card); + Parallel_Common::bcast_string(sys.global_out_dir); + Parallel_Common::bcast_string(sys.global_readin_dir); + Parallel_Common::bcast_string(sys.global_stru_dir); + Parallel_Common::bcast_string(sys.global_matrix_dir); + Parallel_Common::bcast_string(sys.global_in_stru); #endif } /// @note Here para.inp has been synchronized of all ranks. /// All para.inp have the same value. 
-void ReadInput::set_globalv(Parameter& para) +void ReadInput::set_globalv(const Input_para& inp, System_para& sys) { /// caculate the gamma_only_pw and gamma_only_local - if (para.inp.gamma_only) + if (inp.gamma_only) { - para.sys.gamma_only_local = true; + sys.gamma_only_local = true; } - if (para.sys.gamma_only_local) + if (sys.gamma_only_local) { - if (para.inp.esolver_type == "tddft") + if (inp.esolver_type == "tddft") { GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; - para.sys.gamma_only_local = false; + sys.gamma_only_local = false; } } /// set deepks_setorb - if (para.inp.deepks_scf || para.inp.deepks_out_labels) + if (inp.deepks_scf || inp.deepks_out_labels) { - para.sys.deepks_setorb = true; + sys.deepks_setorb = true; } /// set the noncolin and lspinorb from nspin - switch (para.inp.nspin) + switch (inp.nspin) { case 4: - if (para.inp.noncolin) + if (inp.noncolin) { - para.sys.domag = true; - para.sys.domag_z = false; + sys.domag = true; + sys.domag_z = false; } else { - para.sys.domag = false; - para.sys.domag_z = true; + sys.domag = false; + sys.domag_z = true; } - para.sys.npol = 2; + sys.npol = 2; break; case 2: case 1: - para.sys.domag = false; - para.sys.domag_z = false; - para.sys.npol = 1; + sys.domag = false; + sys.domag_z = false; + sys.npol = 1; default: break; } /// set ncx,ncy,ncz - para.sys.ncx = para.inp.nx; - para.sys.ncy = para.inp.ny; - para.sys.ncz = para.inp.nz; + sys.ncx = inp.nx; + sys.ncy = inp.ny; + sys.ncz = inp.nz; #ifdef __MPI - Parallel_Common::bcast_bool(para.sys.double_grid); + Parallel_Common::bcast_bool(sys.double_grid); #endif - // calculate the number of nbands_local - para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; - if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar) - { - para.sys.nbands_l++; - } } } // namespace ModuleIO diff --git a/source/module_io/test/read_wfc_to_rho_test.cpp b/source/module_io/test/read_wfc_to_rho_test.cpp index 
0760c6b319..4ed3b21d2c 100644 --- a/source/module_io/test/read_wfc_to_rho_test.cpp +++ b/source/module_io/test/read_wfc_to_rho_test.cpp @@ -284,7 +284,7 @@ int main(int argc, char** argv) GlobalV::MY_RANK, PARAM.inp.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index 435de37722..265c87da51 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -53,7 +53,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, @@ -103,7 +103,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, GlobalV::MY_STOGROUP, GlobalV::NPROC_IN_POOL, diff --git a/source/module_io/write_cube.cpp b/source/module_io/write_cube.cpp index dbb4e469cf..400f454878 100644 --- a/source/module_io/write_cube.cpp +++ b/source/module_io/write_cube.cpp @@ -35,7 +35,7 @@ void ModuleIO::write_vdata_palgrid( // reduce std::vector data_xyz_full(nxyz); // data to be written #ifdef __MPI // reduce to rank 0 - if (my_pool == 0 && GlobalV::MY_STOGROUP == 0) + if (my_pool == 0 && GlobalV::MY_BNDGROUP == 0) { pgrid.reduce(data_xyz_full.data(), data); } diff --git a/source/module_io/write_istate_info.cpp b/source/module_io/write_istate_info.cpp index 0fc52afad2..b6d14711b1 100644 --- a/source/module_io/write_istate_info.cpp +++ b/source/module_io/write_istate_info.cpp @@ -24,7 +24,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase: MPI_Barrier(MPI_COMM_WORLD); if (GlobalV::MY_POOL == ip) { - if 
(GlobalV::RANK_IN_POOL != 0 || GlobalV::MY_STOGROUP != 0 ) { continue; + if (GlobalV::RANK_IN_POOL != 0 || GlobalV::MY_BNDGROUP != 0 ) { continue; } #endif std::ofstream ofsi2(ss.str().c_str(), std::ios::app); diff --git a/source/module_parameter/parameter.cpp b/source/module_parameter/parameter.cpp index 8d5375b3e5..8688b93a07 100644 --- a/source/module_parameter/parameter.cpp +++ b/source/module_parameter/parameter.cpp @@ -14,13 +14,3 @@ void Parameter::set_start_time(const std::time_t& start_time) { sys.start_time = start_time; } - -void Parameter::set_input_nbands(const int& nbands) -{ - input.nbands = nbands; -} - -void Parameter::set_sys_nlocal(const int& nlocal) -{ - sys.nlocal = nlocal; -} diff --git a/source/module_parameter/parameter.h b/source/module_parameter/parameter.h index bdfe2359ea..9a3d4ea9ff 100644 --- a/source/module_parameter/parameter.h +++ b/source/module_parameter/parameter.h @@ -32,12 +32,6 @@ class Parameter void set_pal_param(const int& myrank, const int& nproc, const int& nthread_per_proc); // Set the start time void set_start_time(const std::time_t& start_time); - - // set input.nbands - void set_input_nbands(const int& nbands); - // set sys.nlocal - void set_sys_nlocal(const int& nlocal); - private: // Only ReadInput and CalAtomInfo can modify the value of Parameter. // Do not add extra friend class here!!! 
diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 8ef89dcfdc..660808f627 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -40,8 +40,8 @@ void PSIInit::prepare_init(const int& random_seed) // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect ModuleBase::timer::tick("PSIInit", "prepare_init"); this->psi_initer.reset(); - if (this->init_wfc == "random") - { + if (this->init_wfc == "random" || (PARAM.ks_solver == "bpcg" && PARAM.inp.bndpar > 1)) + { //temporary solution for band parallel bpcg this->psi_initer = std::unique_ptr>(new psi_initializer_random()); } else if (this->init_wfc == "file") @@ -86,7 +86,7 @@ void PSIInit::initialize_psi(Psi>* psi, hamilt::Hamilt* p_hamilt, std::ofstream& ofs_running) { - if (kspw_psi->get_nbands() == 0 || GlobalV::MY_STOGROUP != 0) + if (kspw_psi->get_nbands() == 0 || (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg")) { return; } From 21f7818580f92fcb50de2738f655bd98087ca72b Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Wed, 22 Jan 2025 13:45:22 +0800 Subject: [PATCH 05/17] fix compile --- source/module_base/kernels/math_kernel_op.cpp | 4 +- .../test_parallel/parallel_global_test.cpp | 4 +- source/module_cell/cal_atoms_info.h | 2 +- source/module_cell/test/klist_test_para.cpp | 4 +- .../test/elecstate_print_test.cpp | 27 +++-- .../test_mpi/charge_mpi_test.cpp | 6 +- .../module_hsolver/test/test_hsolver_sdft.cpp | 4 +- source/module_io/read_input.cpp | 44 ++++--- source/module_io/read_input_item_system.cpp | 6 +- source/module_io/read_set_globalv.cpp | 113 +++++++++--------- .../module_io/test/read_wfc_to_rho_test.cpp | 2 +- .../module_io/test/write_istate_info_test.cpp | 4 +- .../test_serial/io_system_variable_test.cpp | 13 +- source/module_psi/psi_init.cpp | 2 +- 14 files changed, 126 insertions(+), 109 deletions(-) diff --git a/source/module_base/kernels/math_kernel_op.cpp 
b/source/module_base/kernels/math_kernel_op.cpp index 59a3c2ace8..5646a089c2 100644 --- a/source/module_base/kernels/math_kernel_op.cpp +++ b/source/module_base/kernels/math_kernel_op.cpp @@ -382,11 +382,13 @@ template struct line_minimize_with_block_op, base_device::DE template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; +template struct axpy_op; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemv_op; template struct gemm_op, base_device::DEVICE_CPU>; template struct gemm_op; template struct dot_real_op, base_device::DEVICE_CPU>; +template struct dot_real_op; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; template struct vector_div_vector_op, base_device::DEVICE_CPU>; @@ -397,8 +399,6 @@ template struct calc_grad_with_block_op, base_device::DEVIC template struct line_minimize_with_block_op, base_device::DEVICE_CPU>; #ifdef __LCAO -template struct axpy_op; -template struct dot_real_op; template struct vector_mul_vector_op; template struct vector_div_constant_op; template struct vector_div_vector_op; diff --git a/source/module_base/test_parallel/parallel_global_test.cpp b/source/module_base/test_parallel/parallel_global_test.cpp index 9c0972ebe8..8b9e5053a4 100644 --- a/source/module_base/test_parallel/parallel_global_test.cpp +++ b/source/module_base/test_parallel/parallel_global_test.cpp @@ -55,7 +55,7 @@ class MPIContext int rank_in_pool; int nstogroup; - int my_stogroup; + int MY_BNDGROUP; int rank_in_stogroup; int nproc_in_stogroup; @@ -173,7 +173,7 @@ TEST_F(ParaGlobal, InitPools) mpi.kpar, mpi.nproc_in_stogroup, mpi.rank_in_stogroup, - mpi.my_stogroup, + mpi.MY_BNDGROUP, mpi.nproc_in_pool, mpi.rank_in_pool, mpi.my_pool), ::testing::ExitedWithCode(1), ""); diff --git a/source/module_cell/cal_atoms_info.h b/source/module_cell/cal_atoms_info.h index 66b759204d..567140484a 100644 --- a/source/module_cell/cal_atoms_info.h +++ 
b/source/module_cell/cal_atoms_info.h @@ -70,7 +70,7 @@ class CalAtomsInfo elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands); // calculate the number of nbands_local para.sys.nbands_l = para.inp.nbands; - if (inp.ks_solver == "bpcg") // only bpcg support band parallel + if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel { para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar) diff --git a/source/module_cell/test/klist_test_para.cpp b/source/module_cell/test/klist_test_para.cpp index 531e5eb64b..744f3e3150 100644 --- a/source/module_cell/test/klist_test_para.cpp +++ b/source/module_cell/test/klist_test_para.cpp @@ -231,7 +231,7 @@ TEST_F(KlistParaTest, Set) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -288,7 +288,7 @@ TEST_F(KlistParaTest, SetAfterVC) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_elecstate/test/elecstate_print_test.cpp b/source/module_elecstate/test/elecstate_print_test.cpp index e05759f6ca..b282d38d6c 100644 --- a/source/module_elecstate/test/elecstate_print_test.cpp +++ b/source/module_elecstate/test/elecstate_print_test.cpp @@ -98,6 +98,7 @@ class ElecStatePrintTest : public ::testing::Test ucell.magnet.tot_magnetization_nc[1] = 4.4; ucell.magnet.tot_magnetization_nc[2] = 5.5; PARAM.input.ks_solver = "dav"; + PARAM.sys.log_file = "test.dat"; } void TearDown() { @@ -120,11 +121,11 @@ TEST_F(ElecStatePrintTest, PrintFormat) TEST_F(ElecStatePrintTest, PrintEigenvalueS2) { PARAM.input.nspin = 2; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue 
elecstate.print_eigenvalue(GlobalV::ofs_running); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS")); EXPECT_THAT(str, testing::HasSubstr("NSPIN == 2")); @@ -137,17 +138,17 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS2) EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000")); EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEigenvalueS4) { PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue elecstate.print_eigenvalue(GlobalV::ofs_running); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS")); EXPECT_THAT(str, testing::HasSubstr("NSPIN == 4")); @@ -158,7 +159,7 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS4) EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000")); EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintBand) @@ -166,43 +167,43 @@ TEST_F(ElecStatePrintTest, PrintBand) PARAM.input.nspin = 1; PARAM.input.nbands = 2; GlobalV::MY_RANK = 0; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue elecstate.print_band(0, 1, 0); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); 
EXPECT_THAT(str, testing::HasSubstr("Energy (eV) & Occupations for spin=1 K-point=1")); EXPECT_THAT(str, testing::HasSubstr("1 13.6057 0.100000")); EXPECT_THAT(str, testing::HasSubstr("2 27.2114 0.200000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEigenvalueWarning) { elecstate.ekb(0, 0) = 1.0e11; PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); testing::internal::CaptureStdout(); EXPECT_EXIT(elecstate.print_eigenvalue(GlobalV::ofs_running), ::testing::ExitedWithCode(1), ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!")); GlobalV::ofs_running.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintBandWarning) { elecstate.ekb(0, 0) = 1.0e11; PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); testing::internal::CaptureStdout(); EXPECT_EXIT(elecstate.print_band(0, 1, 0), ::testing::ExitedWithCode(1), ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!")); GlobalV::ofs_running.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEtot) diff --git a/source/module_elecstate/test_mpi/charge_mpi_test.cpp b/source/module_elecstate/test_mpi/charge_mpi_test.cpp index 6ef705a93b..9fc13130cc 100644 --- a/source/module_elecstate/test_mpi/charge_mpi_test.cpp +++ b/source/module_elecstate/test_mpi/charge_mpi_test.cpp @@ -72,7 +72,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -118,7 +118,7 @@ TEST_F(ChargeMpiTest, 
reduce_diff_pools2) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -173,7 +173,7 @@ TEST_F(ChargeMpiTest, rho_mpi) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 3ed2593d9a..3d69cacf2f 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -252,7 +252,7 @@ class TestHSolverPW_SDFT : public ::testing::Test // stowf.nchip_max = 0; // psi_test_cd.resize(1, 2, 3); // PARAM.input.nelec = 1.0; -// GlobalV::MY_STOGROUP = 0.0; +// GlobalV::MY_BNDGROUP = 0.0; // int istep = 0; // int iter = 0; @@ -291,7 +291,7 @@ class TestHSolverPW_SDFT : public ::testing::Test // psi_test_no.nbands = 0; // psi_test_no.nbasis = 0; // PARAM.input.nelec = 1.0; -// GlobalV::MY_STOGROUP = 0.0; +// GlobalV::MY_BNDGROUP = 0.0; // PARAM.input.nspin = 1; // elecstate_test.charge = new Charge; // elecstate_test.charge->rho = new double*[1]; diff --git a/source/module_io/read_input.cpp b/source/module_io/read_input.cpp index 89c3cbc721..794588ecaa 100644 --- a/source/module_io/read_input.cpp +++ b/source/module_io/read_input.cpp @@ -13,6 +13,7 @@ #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" +#include "module_base/module_device/device.h" namespace ModuleIO { @@ -112,22 +113,11 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in // 1. only rank 0 read the input file if (this->rank == 0) { - // 1. read the input file // We can also easily add other input file formats here this->read_txt_input(param, filename_in); - - // 2. 
check the value of the parameters - for (auto& input_item: this->input_lists) - { - Input_Item* checkvalue_item = &(input_item.second); - if (checkvalue_item->check_value != nullptr) - { - checkvalue_item->check_value(*checkvalue_item, param); - } - } } - // 3. check the number of atom types from STRU file + // 2. check the number of atom types from STRU file // set the global directories this->set_global_dir(param.inp, param.sys); if (this->check_ntype_flag && this->rank == 0) @@ -135,16 +125,42 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in check_ntype(param.globalv.global_in_stru, param.input.ntype); } - // 4. broadcast input parameters + // 3. broadcast input parameters // It must be after the check_ntype, because some parameters need to be filled due to ntype for (auto& bcastfunc: this->bcastfuncs) { bcastfunc(param); } - // 5. set the globalv parameters, some parameters in different processes are different. e.g. rank + // 4. set the globalv parameters, some parameters in different processes are different. e.g. rank, log_file this->set_globalv(param.inp, param.sys); + // 5. check the value of the parameters + // It must be after the check_ntype, because some parameters need to be checked according to ntype + // It must be after the set_globalv, because some parameters need to be checked according to param.sys + if (this->rank == 0) + { + for (auto& input_item: this->input_lists) + { + Input_Item* checkvalue_item = &(input_item.second); + if (checkvalue_item->check_value != nullptr) + { + checkvalue_item->check_value(*checkvalue_item, param); + } + } + } + + // 6. check and reset kpar. 
+ // It must be after bcastfunc, and kpar and bndpar are synchronized + // It must be before wirte_txt_input, because kpar is used in write_txt_input + if (param.inp.device == "gpu" && param.inp.basis_type == "pw") + { + param.input.kpar = base_device::information::get_device_kpar(param.inp.kpar, param.inp.bndpar); + } + + + + if (this->check_mode) { std::cout << "----------------------------------------------------------" << std::endl; diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index 59170608e0..8d21a9f070 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -234,12 +234,8 @@ void ReadInput::item_system() "will be distributed among"; read_sync_int(input.kpar); item.reset_value = [](const Input_Item& item, Parameter& para) { - if (para.inp.device == "gpu" && para.inp.basis_type == "pw") - { - para.input.kpar = base_device::information::get_device_kpar(para.inp.kpar, para.inp.bndpar); - } #ifdef __LCAO - else if (para.inp.basis_type == "lcao") + if (para.inp.basis_type == "lcao") { para.sys.kpar_lcao = para.inp.kpar; para.input.kpar = 1; diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp index 0c5f542019..a26c2a46d2 100755 --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -1,11 +1,67 @@ #include "module_base/global_variable.h" -#include "module_base/module_device/device.h" #include "module_base/tool_quit.h" #include "module_parameter/parameter.h" #include "read_input.h" #include "read_input_tool.h" namespace ModuleIO { +/// @note Here para.inp has been synchronized of all ranks. +/// All para.inp have the same value. 
+void ReadInput::set_globalv(const Input_para& inp, System_para& sys) +{ + /// caculate the gamma_only_pw and gamma_only_local + if (inp.gamma_only) + { + sys.gamma_only_local = true; + } + if (sys.gamma_only_local) + { + if (inp.esolver_type == "tddft") + { + GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; + sys.gamma_only_local = false; + } + } + /// set deepks_setorb + if (inp.deepks_scf || inp.deepks_out_labels) + { + sys.deepks_setorb = true; + } + /// set the noncolin and lspinorb from nspin + switch (inp.nspin) + { + case 4: + if (inp.noncolin) + { + sys.domag = true; + sys.domag_z = false; + } + else + { + sys.domag = false; + sys.domag_z = true; + } + sys.npol = 2; + break; + case 2: + case 1: + sys.domag = false; + sys.domag_z = false; + sys.npol = 1; + default: + break; + } + sys.nqx = static_cast((sqrt(inp.ecutwfc) / sys.dq + 4.0) * inp.cell_factor); + sys.nqxq = static_cast((sqrt(inp.ecutrho) / sys.dq + 4.0) * inp.cell_factor); + /// set ncx,ncy,ncz + sys.ncx = inp.nx; + sys.ncy = inp.ny; + sys.ncz = inp.nz; +#ifdef __MPI + Parallel_Common::bcast_bool(sys.double_grid); +#endif +} + /// @note Here para.inp has been synchronized of all ranks. /// Only para.inp in rank 0 is right. /// So we need to broadcast the results to all ranks. @@ -70,59 +126,4 @@ void ReadInput::set_global_dir(const Input_para& inp, System_para& sys) Parallel_Common::bcast_string(sys.global_in_stru); #endif } - -/// @note Here para.inp has been synchronized of all ranks. -/// All para.inp have the same value. 
-void ReadInput::set_globalv(const Input_para& inp, System_para& sys) -{ - /// caculate the gamma_only_pw and gamma_only_local - if (inp.gamma_only) - { - sys.gamma_only_local = true; - } - if (sys.gamma_only_local) - { - if (inp.esolver_type == "tddft") - { - GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; - sys.gamma_only_local = false; - } - } - /// set deepks_setorb - if (inp.deepks_scf || inp.deepks_out_labels) - { - sys.deepks_setorb = true; - } - /// set the noncolin and lspinorb from nspin - switch (inp.nspin) - { - case 4: - if (inp.noncolin) - { - sys.domag = true; - sys.domag_z = false; - } - else - { - sys.domag = false; - sys.domag_z = true; - } - sys.npol = 2; - break; - case 2: - case 1: - sys.domag = false; - sys.domag_z = false; - sys.npol = 1; - default: - break; - } - /// set ncx,ncy,ncz - sys.ncx = inp.nx; - sys.ncy = inp.ny; - sys.ncz = inp.nz; -#ifdef __MPI - Parallel_Common::bcast_bool(sys.double_grid); -#endif -} } // namespace ModuleIO diff --git a/source/module_io/test/read_wfc_to_rho_test.cpp b/source/module_io/test/read_wfc_to_rho_test.cpp index 4ed3b21d2c..c941c1e2c5 100644 --- a/source/module_io/test/read_wfc_to_rho_test.cpp +++ b/source/module_io/test/read_wfc_to_rho_test.cpp @@ -286,7 +286,7 @@ int main(int argc, char** argv) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index 265c87da51..ed69f2c5c6 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -55,7 +55,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, 
GlobalV::MY_POOL); @@ -105,7 +105,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) GlobalV::KPAR, GlobalV::NPROC_IN_BNDGROUP, GlobalV::RANK_IN_BPGROUP, - GlobalV::MY_STOGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_io/test_serial/io_system_variable_test.cpp b/source/module_io/test_serial/io_system_variable_test.cpp index b02395afc6..499399a69d 100644 --- a/source/module_io/test_serial/io_system_variable_test.cpp +++ b/source/module_io/test_serial/io_system_variable_test.cpp @@ -44,26 +44,29 @@ TEST_F(InputTest, Item_test) { param.input.suffix = "test"; - readinput.set_globalv(param); + readinput.set_global_dir(param.inp, param.sys); + EXPECT_EQ(param.sys.global_out_dir, "OUT.test/"); EXPECT_EQ(param.sys.global_stru_dir, "OUT.test/STRU/"); EXPECT_EQ(param.sys.global_matrix_dir, "OUT.test/matrix/"); + readinput.set_globalv(param.inp, param.sys); + param.input.basis_type = "lcao"; param.input.gamma_only = true; param.input.esolver_type = "tddft"; param.input.nspin = 2; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.gamma_only_local, 0); param.input.deepks_scf = true; param.input.deepks_out_labels = true; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.deepks_setorb, 1); param.input.nspin = 4; param.input.noncolin = true; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.domag, 1); EXPECT_EQ(param.sys.domag_z, 0); EXPECT_EQ(param.sys.npol, 2); @@ -71,7 +74,7 @@ TEST_F(InputTest, Item_test) param.input.nspin = 1; param.input.lspinorb = true; param.input.noncolin = false; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.domag, 0); EXPECT_EQ(param.sys.domag_z, 0); EXPECT_EQ(param.sys.npol, 1); diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 660808f627..085cb381c8 
100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -40,7 +40,7 @@ void PSIInit::prepare_init(const int& random_seed) // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect ModuleBase::timer::tick("PSIInit", "prepare_init"); this->psi_initer.reset(); - if (this->init_wfc == "random" || (PARAM.ks_solver == "bpcg" && PARAM.inp.bndpar > 1)) + if (this->init_wfc == "random" || (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1)) { //temporary solution for band parallel bpcg this->psi_initer = std::unique_ptr>(new psi_initializer_random()); } From c15c8237b601588a67eb1f833a035f2ddd592080 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Sun, 26 Jan 2025 09:59:17 +0800 Subject: [PATCH 06/17] make bpcg support bndpar > 1 --- source/module_base/para_gemm.cpp | 385 ++++++++++++------ source/module_base/para_gemm.h | 16 +- source/module_base/parallel_comm.cpp | 8 +- source/module_base/parallel_comm.h | 6 +- source/module_base/parallel_global.cpp | 22 +- source/module_base/parallel_reduce.cpp | 8 +- source/module_base/parallel_reduce.h | 4 +- .../test_parallel/test_para_gemm.cpp | 105 +++-- source/module_elecstate/elecstate.cpp | 19 +- source/module_elecstate/elecstate_pw_sdft.cpp | 2 +- .../module_elecstate/module_charge/charge.h | 3 +- .../module_charge/charge_mpi.cpp | 21 +- source/module_elecstate/occupy.cpp | 5 +- source/module_esolver/esolver_ks.cpp | 10 +- source/module_esolver/esolver_ks_pw.cpp | 6 +- source/module_esolver/esolver_sdft_pw.cpp | 4 +- .../hamilt_pwdft/onsite_projector.cpp | 3 +- .../hamilt_pwdft/parallel_grid.cpp | 8 +- .../hamilt_stodft/sto_dos.cpp | 2 +- .../hamilt_stodft/sto_elecond.cpp | 8 +- .../hamilt_stodft/sto_forces.cpp | 2 +- .../hamilt_stodft/sto_iter.cpp | 115 +++--- .../hamilt_stodft/sto_stress_pw.cpp | 4 +- .../hamilt_stodft/sto_tool.cpp | 2 +- .../module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 2 +- source/module_hsolver/diago_bpcg.cpp | 152 ++----- 
source/module_hsolver/diago_bpcg.h | 26 +- source/module_hsolver/hsolver_pw.cpp | 4 +- source/module_hsolver/hsolver_pw_sdft.cpp | 21 +- .../module_hsolver/para_linear_transform.cpp | 172 ++++---- source/module_hsolver/para_linear_transform.h | 77 ++-- .../module_hsolver/test/diago_bpcg_test.cpp | 2 +- .../module_hsolver/test/test_hsolver_sdft.cpp | 4 +- .../test/test_para_linear_trans.cpp | 162 +++++--- source/module_io/cal_dos.cpp | 4 +- source/module_io/read_set_globalv.cpp | 10 + source/module_parameter/system_parameter.h | 4 +- source/module_psi/psi_init.cpp | 2 +- 38 files changed, 809 insertions(+), 601 deletions(-) diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp index 0908457108..01eb7edc79 100644 --- a/source/module_base/para_gemm.cpp +++ b/source/module_base/para_gemm.cpp @@ -25,7 +25,7 @@ void PGemmCN::set_dimension( const int LDB_in, const int nrow_in, const int LDC_in, - const bool gatherC_in) + const int mode) { #ifdef __MPI MPI_Comm_rank(comm_col, &col_rank); @@ -45,13 +45,45 @@ void PGemmCN::set_dimension( this->ncolB = ncolB_in; this->nrow = nrow_in; #ifdef __MPI - this->gatherC = gatherC_in; + switch (mode) + { + case 1: + gatherC = true; + divideCrow = false; + break; + case 2: + gatherC = false; + divideCrow = false; + break; + case 3: + gatherC = false; + divideCrow = true; + break; + default: + break; + } requests.resize(col_nproc); - colA_loc.resize(col_nproc); - MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); - for (int ip = 0; ip < col_nproc; ip++) + if (this->divideCrow) + { + colB_loc.resize(col_nproc); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + int sum = 0; + for (int ip = 0; ip < col_nproc; ip++) + { + max_colB = std::max(max_colB, colB_loc[ip]); + sum += colB_loc[ip]; + } + size_C_local = sum * LDC; + } + else { - max_colA = std::max(max_colA, colA_loc[ip]); + colA_loc.resize(col_nproc); + MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 
1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + max_colA = std::max(max_colA, colA_loc[ip]); + } + size_C_local = ncolB * LDC; } if (this->gatherC) @@ -71,160 +103,265 @@ void PGemmCN::set_dimension( } size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1]; } - size_C_local = ncolB * LDC; #endif } template void PGemmCN::multiply(const T alpha, const T* A, const T* B, const T beta, T* C) { - const Device* ctx = {}; #ifdef __MPI - if (col_nproc > 1) + if (this->col_nproc > 1) { - std::vector A_tmp(max_colA * LDA); - for (int ip = 0; ip < col_nproc; ip++) + if (this->divideCrow) { - if (col_rank != ip) - { - int size = ncolA * LDA; - Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); - } + multiply_row(alpha, A, B, beta, C); } + else + { + multiply_col(alpha, A, B, beta, C); + } + } + else +#endif + { + multiply_single(alpha, A, B, beta, C); + } +} + +template +void PGemmCN::multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; +#ifdef __MPI + T real_beta = row_rank == 0 ? 
beta : 0; +#else + T real_beta = beta; +#endif + ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); +#ifdef __MPI + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } +#endif +} + +#ifdef __MPI +template +void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; - T* C_local = C; - std::vector C_tmp; - if (this->gatherC) + std::vector B_tmp(max_colA * LDA); + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) { - C_tmp.resize(size_C_local); - if (std::is_same::value) - { - C_local = nullptr; - resmem_dev_op()(C_local, size_C_local); - } - else - { - C_local = C_tmp.data(); - } - syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + int size = ncolA * LDA; + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], B_tmp.data()); } + } - T* Atmp_device = nullptr; + T* C_local = C; + std::vector C_tmp; + if (this->gatherC) + { + C_tmp.resize(size_C_local); if (std::is_same::value) { - resmem_dev_op()(Atmp_device, max_colA * LDA); + C_local = nullptr; + resmem_dev_op()(C_local, size_C_local); } else { - Atmp_device = A_tmp.data(); + C_local = C_tmp.data(); } + syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + } - int shift = 0; - T real_beta = row_rank == 0 ? beta : 0; - for (int ip = 0; ip < col_nproc; ip++) + T* Atmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Atmp_device, max_colA * LDA); + } + else + { + Atmp_device = B_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? 
beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C_local + shift; + if (col_rank == ip) + { + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolA; + } + else + { + int m = colA_loc[ip]; + int size = m * LDA; + MPI_Status status; + Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, B_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + m, + ncolB, + nrow, + &alpha, + Atmp_device, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += m; + } + } + + if (this->gatherC) + { + T* Cglobal_cpu = nullptr; + T* Clocal_cpu = C_tmp.data(); + ; + if (std::is_same::value) { - T* C_start = C_local + shift; - if (col_rank == ip) - { - ModuleBase::gemm_op()(ctx, - 'C', - 'N', - ncolA, - ncolB, - nrow, - &alpha, - A, - LDA, - B, - LDB, - &real_beta, - C_start, - LDC); - shift += ncolA; - } - else - { - int m = colA_loc[ip]; - int size = m * LDA; - MPI_Status status; - Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp.data()); - MPI_Wait(&requests[ip], &status); - ModuleBase::gemm_op()(ctx, - 'C', - 'N', - m, - ncolB, - nrow, - &alpha, - Atmp_device, - LDA, - B, - LDB, - &real_beta, - C_start, - LDC); - shift += m; - } - } - - if (this->gatherC) - { - T* Cglobal_cpu = nullptr; - T* Clocal_cpu = C_tmp.data();; - if (std::is_same::value) - { - delmem_dev_op()(Atmp_device); - - syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); - delmem_dev_op()(C_local); - - resmem_dev_op()(Cglobal_cpu, size_C_global); - } - else - { - Cglobal_cpu = C; - } - if (this->row_nproc > 1) - { - Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); - } - Parallel_Common::gatherv_data(Clocal_cpu, - size_C_local, - Cglobal_cpu, - recv_counts.data(), - displs.data(), - col_world); - - if (std::is_same::value) - { - syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); - 
delmem_dev_op()(Cglobal_cpu); - } + delmem_dev_op()(Atmp_device); + + syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); + delmem_dev_op()(C_local); + + resmem_dev_op()(Cglobal_cpu, size_C_global); } else { - if (this->row_nproc > 1) - { - Parallel_Common::reduce_dev(C, size_C_local, row_world); - } + Cglobal_cpu = C; + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); + } + Parallel_Common::gatherv_data(Clocal_cpu, + size_C_local, + Cglobal_cpu, + recv_counts.data(), + displs.data(), + col_world); + + if (std::is_same::value) + { + syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); + delmem_dev_op()(Cglobal_cpu); } } else { - T real_beta = row_rank == 0 ? beta : 0; -#else - T real_beta = beta; -#endif - ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); -#ifdef __MPI if (this->row_nproc > 1) { Parallel_Common::reduce_dev(C, size_C_local, row_world); } } -#endif } +template +void PGemmCN::multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; + + std::vector B_tmp(max_colB * LDB); + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) + { + int size = ncolB * LDB; + Parallel_Common::isend_dev(B, size, ip, 0, col_world, &requests[ip], B_tmp.data()); + } + } + + std::vector C_tmp; + + T* Btmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Btmp_device, max_colB * LDB); + } + else + { + Btmp_device = B_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? 
beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C + shift; + if (col_rank == ip) + { + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolB * LDC; + } + else + { + int m = colB_loc[ip]; + int size = m * LDB; + MPI_Status status; + Parallel_Common::recv_dev(Btmp_device, size, ip, 0, col_world, &status, B_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + m, + nrow, + &alpha, + A, + LDA, + Btmp_device, + LDB, + &real_beta, + C_start, + LDC); + shift += m * LDC; + } + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } +} +#endif + template class PGemmCN; template class PGemmCN; template class PGemmCN, base_device::DEVICE_CPU>; diff --git a/source/module_base/para_gemm.h b/source/module_base/para_gemm.h index 69ffd6d146..f00d3cd975 100644 --- a/source/module_base/para_gemm.h +++ b/source/module_base/para_gemm.h @@ -36,7 +36,7 @@ class PGemmCN * @param LDB leading dimension of B in each proc * @param nrow number of rows of A or B * @param LDC leading dimension of C. 
C can be C_local or C_global - * @param gatherC whether gather C_local to C_global + * @param mode 1: gather C_local to C_global, 2:C_local(nrow * ncol_loc), 3:C_global(nrow_loc * ncol) */ void set_dimension( #ifdef __MPI @@ -49,7 +49,7 @@ class PGemmCN const int LDB, const int nrow, const int LDC, - const bool gatherC = true); + const int mode = 1); /** * @brief calculate C = alpha * A^H * B + beta * C @@ -67,7 +67,8 @@ class PGemmCN std::vector colA_loc; ///< [col_nproc] number of columns of A matrix in each proc int max_colA = 0; ///< maximum number of columns of A matrix in all procs - std::vector colB_loc; ///<[col_nproc] number of columns of B matrix in each proc + std::vector colB_loc; ///< [col_nproc] number of columns of B matrix in each proc + int max_colB = 0; ///< maximum number of columns of B matrix in all procs std::vector requests; ///< MPI request std::vector recv_counts; ///< receive counts for gathering C_local to C_global @@ -75,6 +76,7 @@ class PGemmCN int size_C_local = 0; ///< size of C_local, which is a local matrix in each proc int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs bool gatherC = true; ///< whether gather C_local to C_global + bool divideCrow = false; ///< whether divide C_global to C_local #endif int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc @@ -83,6 +85,14 @@ class PGemmCN int LDB = 0; ///< leading dimension of B in each proc int LDC = 0; ///< leading dimension of C, which can be C_local or C_global private: + /// @brief for col_nproc == 1 + void multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C); +#ifdef __MPI + /// @brief for mode = 1 or 2 + void multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C); + /// @brief for mode = 3 + void multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C); +#endif using 
resmem_dev_op = base_device::memory::resize_memory_op; using delmem_dev_op = base_device::memory::delete_memory_op; using syncmem_dev_op = base_device::memory::synchronize_memory_op; diff --git a/source/module_base/parallel_comm.cpp b/source/module_base/parallel_comm.cpp index 7ede7efe4f..5d03447b5a 100644 --- a/source/module_base/parallel_comm.cpp +++ b/source/module_base/parallel_comm.cpp @@ -3,10 +3,10 @@ #include "mpi.h" #include "parallel_global.h" -MPI_Comm POOL_WORLD; -MPI_Comm INTER_POOL; // communicator among different pools -MPI_Comm STO_WORLD; -MPI_Comm PARAPW_WORLD; +MPI_Comm POOL_WORLD; // groups for different plane waves. In this group, only plane waves are different. K-points and bands are the same. +MPI_Comm KP_WORLD; // groups for different k. In this group, only k-points are different. Bands and plane waves are the same. +MPI_Comm BP_WORLD; // groups for different bands. In this group, only bands are different. K-points and plane waves are the same. +MPI_Comm INT_BGROUP; // internal comm groups for same bands. In this group, only bands are the same. K-points and plane waves are different.
MPI_Comm GRID_WORLD; // mohan add 2012-01-13 MPI_Comm DIAG_WORLD; // mohan add 2012-01-13 diff --git a/source/module_base/parallel_comm.h b/source/module_base/parallel_comm.h index c05772fa92..d76b83e111 100644 --- a/source/module_base/parallel_comm.h +++ b/source/module_base/parallel_comm.h @@ -4,9 +4,9 @@ #ifdef __MPI #include "mpi.h" extern MPI_Comm POOL_WORLD; -extern MPI_Comm INTER_POOL; // communicator among different pools -extern MPI_Comm STO_WORLD; -extern MPI_Comm PARAPW_WORLD; +extern MPI_Comm KP_WORLD; // communicator among different pools +extern MPI_Comm INT_BGROUP; +extern MPI_Comm BP_WORLD; extern MPI_Comm GRID_WORLD; // mohan add 2012-01-13 extern MPI_Comm DIAG_WORLD; // mohan add 2012-01-13 diff --git a/source/module_base/parallel_global.cpp b/source/module_base/parallel_global.cpp index b9d87f2700..720fc66ec1 100644 --- a/source/module_base/parallel_global.cpp +++ b/source/module_base/parallel_global.cpp @@ -237,12 +237,12 @@ void Parallel_Global::read_pal_param(int argc, void Parallel_Global::finalize_mpi() { MPI_Comm_free(&POOL_WORLD); - if (INTER_POOL != MPI_COMM_NULL) + if (KP_WORLD != MPI_COMM_NULL) { - MPI_Comm_free(&INTER_POOL); + MPI_Comm_free(&KP_WORLD); } - MPI_Comm_free(&STO_WORLD); - MPI_Comm_free(&PARAPW_WORLD); + MPI_Comm_free(&INT_BGROUP); + MPI_Comm_free(&BP_WORLD); MPI_Comm_free(&GRID_WORLD); MPI_Comm_free(&DIAG_WORLD); MPI_Finalize(); @@ -324,7 +324,7 @@ void Parallel_Global::divide_pools(const int& NPROC, int& MY_POOL) { // note: the order of k-point parallelization and band parallelization is important - // The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL + // The order will not change the behavior of KP_WORLD or BP_WORLD, and MY_POOL // and MY_BNDGROUP will be the same as well. 
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0) { @@ -349,11 +349,11 @@ void Parallel_Global::divide_pools(const int& NPROC, MPI_Comm_dup(bndpar_group.group_comm, &POOL_WORLD); if(kpar_group.inter_comm != MPI_COMM_NULL) { - MPI_Comm_dup(kpar_group.inter_comm, &INTER_POOL); + MPI_Comm_dup(kpar_group.inter_comm, &KP_WORLD); } else { - INTER_POOL = MPI_COMM_NULL; + KP_WORLD = MPI_COMM_NULL; } if(BNDPAR > 1) @@ -361,16 +361,16 @@ void Parallel_Global::divide_pools(const int& NPROC, NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; MY_BNDGROUP = bndpar_group.my_group; - MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &STO_WORLD); - MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD); + MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &INT_BGROUP); + MPI_Comm_dup(bndpar_group.inter_comm, &BP_WORLD); } else { NPROC_IN_BNDGROUP = NPROC; RANK_IN_BPGROUP = MY_RANK; MY_BNDGROUP = 0; - MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD); - MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD); + MPI_Comm_dup(MPI_COMM_WORLD, &INT_BGROUP); + MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &BP_WORLD); } return; } diff --git a/source/module_base/parallel_reduce.cpp b/source/module_base/parallel_reduce.cpp index 1a67bbd3bc..c44bd8fb66 100644 --- a/source/module_base/parallel_reduce.cpp +++ b/source/module_base/parallel_reduce.cpp @@ -110,9 +110,9 @@ void Parallel_Reduce::reduce_pool(double* object, const int n) // (1) the value is same in each pool. // (2) we need to reduce the value from different pool. 
-void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object) +void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object) { - if (kpar == 1) + if (npool == 1) { return; } @@ -124,9 +124,9 @@ void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in // (1) the value is same in each pool. // (2) we need to reduce the value from different pool. -void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n) +void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n) { - if (kpar == 1) + if (npool == 1) { return; } diff --git a/source/module_base/parallel_reduce.h b/source/module_base/parallel_reduce.h index f120cfd1cd..2d612efc86 100644 --- a/source/module_base/parallel_reduce.h +++ b/source/module_base/parallel_reduce.h @@ -31,8 +31,8 @@ void reduce_int_grid(int* object, const int n); // mohan add 2012-01-12 void reduce_double_grid(double* object, const int n); void reduce_double_diag(double* object, const int n); -void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object); -void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n); +void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object); +void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n); void gather_min_int_all(const int& nproc, int& v); void gather_max_double_all(const int& nproc, double& v); diff --git a/source/module_base/test_parallel/test_para_gemm.cpp b/source/module_base/test_parallel/test_para_gemm.cpp index 4b6445d057..78fdc74b1c 100644 --- a/source/module_base/test_parallel/test_para_gemm.cpp +++ b/source/module_base/test_parallel/test_para_gemm.cpp @@ -367,7 +367,49 @@ TYPED_TEST(PgemmTest, odd_case) this->compare_result(ncolA_global, 
ncolB_global, LDC_global); } -TYPED_TEST(PgemmTest, odd_case_not_gather) +TYPED_TEST(PgemmTest, row_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(1, 4); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, col_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(4, 1); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, divide_col) { const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; const int LDA_global = 17, LDB_global = 18, LDC_global = 19; @@ -392,7 +434,7 @@ TYPED_TEST(PgemmTest, odd_case_not_gather) this->LDB, this->nrow, LDC_global, - false); + 2); this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start); @@ -408,34 +450,32 @@ TYPED_TEST(PgemmTest, odd_case_not_gather) } } -TYPED_TEST(PgemmTest, row_parallel) +TYPED_TEST(PgemmTest, divide_row) { const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; const int LDA_global = 17, LDB_global = 18, LDC_global 
= 19; - this->decide_ngroup(1, 4); + this->decide_ngroup(2, 2); this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + std::vector colA_loc(this->nproc_col); + MPI_Allgather(&this->ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, this->col_world); + std::vector displs(this->nproc_col); + displs[0] = 0; + for (int i = 1; i < this->nproc_col; i++) + { + displs[i] = (displs[i - 1] + colA_loc[i - 1]); + } + int start = displs[this->rank_col]; - this->pgemm.set_dimension(this->col_world, - this->row_world, - this->ncolA, - this->LDA, - this->ncolB, - this->LDB, - this->nrow, - LDC_global); - this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); - - this->compare_result(ncolA_global, ncolB_global, LDC_global); -} - -TYPED_TEST(PgemmTest, col_parallel) -{ - const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; - const int LDA_global = 17, LDB_global = 18, LDC_global = 19; - - this->decide_ngroup(4, 1); - this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + int LDC_local = this->ncolA + 2; + std::vector C_loc(LDC_local * ncolB_global, 0.0); + for(int i = 0; i < ncolB_global; i++) + { + for(int j = 0; j < this->ncolA; j++) + { + C_loc[i * LDC_local + j] = this->C_global[i * LDC_global + start + j]; + } + } this->pgemm.set_dimension(this->col_world, this->row_world, @@ -444,10 +484,21 @@ TYPED_TEST(PgemmTest, col_parallel) this->ncolB, this->LDB, this->nrow, - LDC_global); - this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + LDC_local, + 3); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, C_loc.data()); - this->compare_result(ncolA_global, ncolB_global, LDC_global); + + + for (int i = 0; i < ncolB_global; i++) + { + for (int j = 0; j < this->ncolA; j++) + { + EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + 
start + j]), + get_double(C_loc[i * LDC_local + j]), + 1e-10); + } + } } int main(int argc, char** argv) diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index 1efcaff554..4c06c99213 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -163,8 +163,8 @@ void ElecState::calculate_weights() this->klist->isk); } #ifdef __MPI - // qianrui fix a bug on 2021-7-21 - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, this->f_en.demet); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.demet); #endif } else if (Occupy::fixed_occupations) @@ -192,16 +192,11 @@ void ElecState::calEBand() } } this->f_en.eband = eband; - if (GlobalV::KPAR != 1 && PARAM.inp.esolver_type != "sdft") - { - //================================== - // Reduce all the Energy in each cpu - //================================== - this->f_en.eband /= GlobalV::NPROC_IN_POOL; + #ifdef __MPI - Parallel_Reduce::reduce_all(this->f_en.eband); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.eband); #endif - } return; } @@ -253,8 +248,8 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge // init nelec_spin with nelec and nupdown this->init_nelec_spin(); // initialize ekb and wg - this->ekb.create(nk_in, PARAM.inp.nbands); - this->wg.create(nk_in, PARAM.inp.nbands); + this->ekb.create(nk_in, PARAM.globalv.nbands_l); + this->wg.create(nk_in, PARAM.globalv.nbands_l); } } // namespace elecstate diff --git a/source/module_elecstate/elecstate_pw_sdft.cpp b/source/module_elecstate/elecstate_pw_sdft.cpp index 42952a7670..269efc06f2 100644 --- a/source/module_elecstate/elecstate_pw_sdft.cpp +++ b/source/module_elecstate/elecstate_pw_sdft.cpp @@ -19,7 +19,7 @@ void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) 
setmem_var_op()(this->rho[is], 0, this->charge->nrxx); } - if (GlobalV::MY_BNDGROUP == 0 || PARAM.inp.ks_solver == "bpcg") + if (PARAM.globalv.ks_run) { for (int ik = 0; ik < psi.get_nk(); ++ik) { diff --git a/source/module_elecstate/module_charge/charge.h b/source/module_elecstate/module_charge/charge.h index 926bffebea..3d18ee7134 100644 --- a/source/module_elecstate/module_charge/charge.h +++ b/source/module_elecstate/module_charge/charge.h @@ -134,7 +134,7 @@ class Charge /** * @brief Reduce among different pools - * If NPROC_IN_POOLs are all the same, use GlobalV::INTER_POOL + * If NPROC_IN_POOLs are all the same, use GlobalV::KP_WORLD * else, gather rho in a POOL, and then reduce among different POOLs * * @param array_rho f(rho): an array [nrxx] @@ -161,7 +161,6 @@ class Charge bool allocate_rho_final_scf; // LiuXh add 20180606 #ifdef __MPI private: - bool use_intel_pool = false; //use INTER_POOL when NPROC_IN_POOLs are all the same int *rec = nullptr; //The number of elements each process should receive into the receive buffer. int *dis = nullptr; //The displacement (relative to recvbuf) for each process in the receive buffer. 
#endif diff --git a/source/module_elecstate/module_charge/charge_mpi.cpp b/source/module_elecstate/module_charge/charge_mpi.cpp index 61f01ea58b..e9731ee99e 100644 --- a/source/module_elecstate/module_charge/charge_mpi.cpp +++ b/source/module_elecstate/module_charge/charge_mpi.cpp @@ -8,13 +8,8 @@ #ifdef __MPI void Charge::init_chgmpi() { - if (INTER_POOL != MPI_COMM_NULL) + if (KP_WORLD == MPI_COMM_NULL) { - this->use_intel_pool = true; - } - else - { - this->use_intel_pool = false; delete[] rec; rec = new int[GlobalV::NPROC_IN_POOL]; delete[] dis; @@ -33,9 +28,9 @@ void Charge::reduce_diff_pools(double* array_rho) const { ModuleBase::TITLE("Charge", "reduce_diff_pools"); ModuleBase::timer::tick("Charge", "reduce_diff_pools"); - if (this->use_intel_pool) + if (KP_WORLD != MPI_COMM_NULL) { - MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, INTER_POOL); + MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, KP_WORLD); } else { @@ -97,11 +92,7 @@ void Charge::reduce_diff_pools(double* array_rho) const //================================== // Reduce all the rho in each cpu //================================== - if (PARAM.inp.esolver_type == "sdft") { // qinarui add it temporarily. 
- MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, STO_WORLD); - } else { - MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); -} + MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, INT_BGROUP); //===================================== // Change the order of rho in each cpu @@ -118,6 +109,10 @@ void Charge::reduce_diff_pools(double* array_rho) const delete[] array_tot; delete[] array_tmp; } + if(PARAM.globalv.all_ks_run && PARAM.inp.bndpar > 1) + { + MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } ModuleBase::timer::tick("Charge", "reduce_diff_pools"); } diff --git a/source/module_elecstate/occupy.cpp b/source/module_elecstate/occupy.cpp index c4cc61213d..4ed3c6a8ac 100644 --- a/source/module_elecstate/occupy.cpp +++ b/source/module_elecstate/occupy.cpp @@ -228,7 +228,7 @@ void Occupy::gweights(const int nks, if (is != -1 && is != isk[ik]) continue; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { //================================ // Calculate the gaussian weights @@ -417,7 +417,8 @@ double Occupy::sumkg(const ModuleBase::matrix& ekb, // GlobalV::ofs_running << "\n sum2 before reduce = " << sum2 << std::endl; #ifdef __MPI - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, sum2); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, sum2); #endif // GlobalV::ofs_running << "\n sum2 after reduce = " << sum2 << std::endl; diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 6466f3be63..5a0c7920d7 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -361,7 +361,7 @@ void ESolver_KS::hamilt2density(UnitCell& ucell, const int istep, con // Maybe in the future, density and wavefunctions 
should use different // parallel algorithms, in which they do not occupy all processors, for // example wavefunctions uses 20 processors while density uses 10. - if (GlobalV::MY_BNDGROUP == 0) + if (PARAM.globalv.ks_run) { // double drho = this->estate.caldr2(); // EState should be used after it is constructed. @@ -550,7 +550,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i this->pelec->charge->rho, this->pelec->nelec_spin.data()); - if (GlobalV::MY_BNDGROUP == 0) + if (PARAM.globalv.ks_run) { // mixing will restart at this->p_chgmix->mixing_restart steps if (drho <= PARAM.inp.mixing_restart && PARAM.inp.mixing_restart > 0.0 @@ -634,9 +634,9 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } #ifdef __MPI - MPI_Bcast(&drho, 1, MPI_DOUBLE, 0, PARAPW_WORLD); - MPI_Bcast(&this->conv_esolver, 1, MPI_DOUBLE, 0, PARAPW_WORLD); - MPI_Bcast(pelec->charge->rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&drho, 1, MPI_DOUBLE, 0, BP_WORLD); + MPI_Bcast(&this->conv_esolver, 1, MPI_DOUBLE, 0, BP_WORLD); + MPI_Bcast(pelec->charge->rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD); #endif // update potential diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 2110cd76fc..0961675029 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -212,7 +212,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max); + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(PARAM.inp.pw_seed); this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" @@ -228,7 +228,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p //! 
9) setup occupations if (PARAM.inp.ocp) { - this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec); + this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.globalv.nbands_l, PARAM.inp.nelec); } } @@ -563,7 +563,7 @@ void ESolver_KS_PW::update_pot(UnitCell& ucell, const int istep, cons this->pelec->pot->update_from_charge(this->pelec->charge, &ucell); this->pelec->f_en.descf = this->pelec->cal_delta_escf(); #ifdef __MPI - MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, BP_WORLD); #endif } else diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index c4dd8d821b..a4c20d37d8 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -198,7 +198,7 @@ void ESolver_SDFT_PW::hamilt2density_single(UnitCell& ucell, int iste // set_diagethr need it this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne; - if (GlobalV::MY_BNDGROUP == 0) + if (PARAM.globalv.ks_run) { Symmetry_rho srho; for (int is = 0; is < PARAM.inp.nspin; is++) @@ -217,7 +217,7 @@ void ESolver_SDFT_PW::hamilt2density_single(UnitCell& ucell, int iste #endif } #ifdef __MPI - MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, BP_WORLD); #endif ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density"); } diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index 09982d8e06..348ada6633 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -571,7 +571,8 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psiwhichpro[ipool][iz], iz, STO_WORLD); + MPI_Send(zpiece, ncxy, MPI_DOUBLE, this->whichpro[ipool][iz], iz, INT_BGROUP); } } @@ -309,7 +309,7 @@ void 
Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh // and the receive tag is iz. else if(proc == GlobalV::RANK_IN_POOL ) { - MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, STO_WORLD,&ierror); + MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, INT_BGROUP,&ierror); for(int ir=0; irwhichpro[ipool][iz], iz, STO_WORLD); + MPI_Send(zpiece, ncxy, MPI_DOUBLE, this->whichpro[ipool][iz], iz, INT_BGROUP); } } }// MY_POOL == 0 @@ -335,7 +335,7 @@ void Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh // processor 0. the tag is 'iz' if(proc == GlobalV::RANK_IN_BPGROUP ) { - MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, STO_WORLD,&ierror); + MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, INT_BGROUP,&ierror); for(int ir=0; ir>& kspsi_all, ks_fact->nrecv, ks_fact->displs, MPI_DOUBLE_COMPLEX, - PARAPW_WORLD); + BP_WORLD); } #endif @@ -629,10 +629,10 @@ void Sto_EleCond::sKG(const int& smear_type, int allbands_sto = perbands_sto; int allbands = perbands; #ifdef __MPI - MPI_Allreduce(&perbands, &allbands, 1, MPI_INT, MPI_SUM, PARAPW_WORLD); + MPI_Allreduce(&perbands, &allbands, 1, MPI_INT, MPI_SUM, BP_WORLD); allbands_sto = allbands - allbands_ks; - info_gatherv ks_fact(perbands_ks, PARAM.inp.bndpar, 1, PARAPW_WORLD); - info_gatherv sto_npwx(perbands_sto, PARAM.inp.bndpar, npwx, PARAPW_WORLD); + info_gatherv ks_fact(perbands_ks, PARAM.inp.bndpar, 1, BP_WORLD); + info_gatherv sto_npwx(perbands_sto, PARAM.inp.bndpar, npwx, BP_WORLD); #endif const int bandsinfo[6]{perbands_ks, perbands_sto, perbands, allbands_ks, allbands_sto, allbands}; double* en_all = nullptr; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp index bf717c062e..9b942d7af2 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp @@ -210,7 +210,7 @@ void Sto_Forces::cal_sto_force_nl( const int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; 
int nksbands = psi_in.get_nbands(); - if (GlobalV::MY_BNDGROUP != 0) + if (!PARAM.globalv.ks_run) { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index bd029a401d..06a535155a 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -1,14 +1,16 @@ #include "sto_iter.h" +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/para_gemm.h" #include "module_base/parallel_reduce.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" +#include "module_elecstate/kernels/elecstate_op.h" #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_hsolver/para_linear_transform.h" #include "module_parameter/parameter.h" -#include "module_base/kernels/math_kernel_op.h" -#include "module_elecstate/kernels/elecstate_op.h" template Stochastic_Iter::Stochastic_Iter() @@ -56,8 +58,10 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, { ModuleBase::TITLE("Stochastic_Iter", "orthog"); ModuleBase::timer::tick("Stochastic_Iter", "orthog"); + const int nbands_l = psi.get_nbands(); + const int nbands = PARAM.inp.nbands; // orthogonal part - if (PARAM.inp.nbands > 0) + if (nbands > 0) { const int nchipk = stowf.nchip[ik]; const int npw = psi.get_current_ngk(); @@ -66,49 +70,28 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, stowf.chiortho->fix_k(ik); T *wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); cpymem_complex_op()(wfgout, wfgin, npwx * nchipk); - // for (int ig = 0; ig < npwx * nchipk; ++ig) - // { - // wfgout[ig] = wfgin[ig]; - // } // orthogonal part T* sum = nullptr; - resmem_complex_op()(sum, PARAM.inp.nbands * nchipk); - char transC = 'C'; - char transN = 'N'; - + resmem_complex_op()(sum, nbands * nchipk); // sum(b - ModuleBase::gemm_op()(ctx, - transC, - transN, - 
PARAM.inp.nbands, - nchipk, - npw, - &ModuleBase::ONE, - &psi(ik, 0, 0), - npwx, - wfgout, - npwx, - &ModuleBase::ZERO, - sum, - PARAM.inp.nbands); - Parallel_Reduce::reduce_pool(sum, PARAM.inp.nbands * nchipk); - + ModuleBase::PGemmCN pmmcn; +#ifdef __MPI + pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2); +#else + pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2); +#endif + pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum); + // psi -= psi * sum - ModuleBase::gemm_op()(ctx, - transN, - transN, - npw, - nchipk, - PARAM.inp.nbands, - &ModuleBase::NEG_ONE, - &psi(ik, 0, 0), - npwx, - sum, - PARAM.inp.nbands, - &ModuleBase::ONE, - wfgout, - npwx); + hsolver::PLinearTransform pltrans; +#ifdef __MPI + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true); +#else + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true); +#endif + pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout); + delmem_complex_op()(sum); } ModuleBase::timer::tick("Stochastic_Iter", "orthog"); @@ -152,7 +135,7 @@ void Stochastic_Iter::checkemm(const int& ik, for (int ichi = 0; ichi < ntest; ++ichi) { - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { pchi = &stowf.chiortho->operator()(ik, ichi, 0); } @@ -329,12 +312,12 @@ void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pe this->check_precision(targetne, 10 * PARAM.inp.scf_thr, "Ne"); // Set wf.wg - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* en = &pes->ekb(ikk, 0); - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { pes->wg(ikk, iksb) = stofunc.fd(en[iksb]) * this->pkv->wk[ikk]; } @@ -366,7 +349,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& } } T* pchi; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { stowf.chiortho->fix_k(ik); pchi = stowf.chiortho->get_pointer(); @@ 
-435,12 +418,12 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) p_che->calcoef_real(nroot_fd); sto_ne = vTMv(p_che->coef_real, spolyv, norder); } - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* en = &pes->ekb(ikk, 0); - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { KS_ne += stofunc.fd(en[iksb]) * this->pkv->wk[ikk]; } @@ -448,7 +431,11 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) } KS_ne /= GlobalV::NPROC_IN_POOL; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, INT_BGROUP); + if(PARAM.globalv.all_ks_run) + { + MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif @@ -497,13 +484,13 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, stodemet = -vTMv(p_che->coef_real, spolyv, norder); } - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* enb = &pes->ekb(ikk, 0); // number of electrons in KS orbitals - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { pes->f_en.demet += stofunc.fdlnfd(enb[iksb]) * this->pkv->wk[ikk]; } @@ -511,7 +498,11 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, } pes->f_en.demet /= GlobalV::NPROC_IN_POOL; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, INT_BGROUP); + if(PARAM.globalv.all_ks_run) + { + MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } MPI_Allreduce(MPI_IN_PLACE, &stodemet, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif pes->f_en.demet += 
stodemet; @@ -576,12 +567,8 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, resmem_complex_op()(porter, nrxx); std::vector sto_rho(nspin); - for(int is = 0; is < nspin; ++is) - { - sto_rho[is] = pes->charge->rho[is]; - } std::vector _tmprho; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { // If there are KS orbitals, we need to allocate another memory for sto_rho _tmprho.resize(nrxx * nspin); @@ -590,6 +577,13 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, sto_rho[is] = _tmprho.data() + is * nrxx; } } + else + { + for (int is = 0; is < nspin; ++is) + { + sto_rho[is] = pes->charge->rho[is]; + } + } // pes->rho is a device memory, and when using cpu and double, we donot need to allocate memory for pes->rho if (PARAM.inp.device != "gpu" && PARAM.inp.precision != "single") { @@ -661,11 +655,6 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); - for(int is = 0; is < nspin; ++is) - { - MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); - } #endif double factor = targetne / (KS_ne + sto_ne); if (std::abs(factor - 1) > 1e-10) @@ -680,7 +669,7 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, for (int is = 0; is < 1; ++is) { - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { #ifdef _OPENMP #pragma omp parallel for @@ -715,7 +704,7 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WFfix_k(ik); T* out = stowf.shchi->get_pointer(); T* pchi; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { stowf.chiortho->fix_k(ik); pchi = stowf.chiortho->get_pointer(); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp index 45db859c80..4d77635139 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp +++ 
b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp @@ -111,7 +111,7 @@ void Sto_Stress_PW::sto_stress_kin(ModuleBase::matrix& sigma, ModuleBase::timer::tick("Sto_Stress_PW", "stress_kin"); int nksbands = psi.get_nbands(); - if (GlobalV::MY_BNDGROUP != 0) + if (!PARAM.globalv.ks_run) { nksbands = 0; } @@ -160,7 +160,7 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in.get_nbands(); - if (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg") + if (!PARAM.globalv.ks_run) { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp index 8b350c7777..98f346f188 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp @@ -138,7 +138,7 @@ psi::Psi>* gatherchi(psi::Psi>& chi, nrecv_sto, displs_sto, MPI_COMPLEX, - PARAPW_WORLD); + BP_WORLD); ModuleBase::timer::tick("sKG", "bands_gather"); } #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 5fdb9f59b3..b00564f1ef 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -315,7 +315,7 @@ void Stochastic_WF::init_sto_orbitals_Ecut(const int seed_in, int* nrecv = new int[PARAM.inp.bndpar]; const int nchiper = this->nchip[0]; #ifdef __MPI - MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD); + MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, BP_WORLD); #endif int ichi_start = 0; for (int i = 0; i < GlobalV::MY_BNDGROUP; ++i) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 36f77d372d..bd9b9326f2 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -11,6 +11,7 @@ #include "module_base/blas_connector.h" #include 
"module_base/global_function.h" #include "module_base/kernels/math_kernel_op.h" +#include "para_linear_transform.h" namespace hsolver { @@ -34,9 +35,10 @@ DiagoBPCG::~DiagoBPCG() { } template -void DiagoBPCG::init_iter(const int nband, const int nbasis, const int ndim) { +void DiagoBPCG::init_iter(const int nband, const int nband_l, const int nbasis, const int ndim) { // Specify the problem size n_basis, n_band, while lda is n_basis this->n_band = nband; + this->n_band_l = nband_l; this->n_basis = nbasis; this->n_dim = ndim; @@ -48,30 +50,41 @@ void DiagoBPCG::init_iter(const int nband, const int nbasis, const in this->hsub = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_band})); - this->hpsi = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->work = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->hgrad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->grad_old = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); + this->hpsi = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->work = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->hgrad = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->grad_old = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis})); - this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); + this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); +#ifdef __MPI + this->pmmcn.set_dimension(BP_WORLD, POOL_WORLD, n_band_l, n_basis, n_band_l, n_basis, n_dim, n_band); + this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis, BP_WORLD, false); +#else + this->pmmcn.set_dimension(n_band_l, n_basis, n_band_l, n_basis, n_dim, n_band); + 
this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis, false); +#endif } template bool DiagoBPCG::test_error(const ct::Tensor& err_in, const std::vector& ethr_band) { const Real * _err_st = err_in.data(); + bool not_conv = false; if (err_in.device_type() == ct::DeviceType::GpuDevice) { ct::Tensor h_err_in = err_in.to_device(); _err_st = h_err_in.data(); } - for (int ii = 0; ii < this->n_band; ii++) { + for (int ii = 0; ii < this->n_band_l; ii++) { if (_err_st[ii] > ethr_band[ii]) { - return true; + not_conv = true; } } - return false; +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, ¬_conv, 1, MPI_C_BOOL, MPI_LOR, BP_WORLD); +#endif + return not_conv; } // Finally, the last one! @@ -82,7 +95,7 @@ void DiagoBPCG::line_minimize( ct::Tensor& psi_out, ct::Tensor& hpsi_out) { - line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_basis, this->n_basis, this->n_band); + line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_basis, this->n_basis, this->n_band_l); } @@ -94,28 +107,8 @@ void DiagoBPCG::orth_cholesky( ct::Tensor& hpsi_out, ct::Tensor& hsub_out) { - // hsub_out = psi_out * transc(psi_out) - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); - // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - psi_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_out.data(), this->n_band * this->n_band); + this->pmmcn.multiply(1.0, psi_out.data(), psi_out.data(), 0.0, hsub_out.data()); // set hsub matrix to lower format; ct::kernels::set_matrix()( @@ -150,7 
+143,7 @@ void DiagoBPCG::calc_grad_with_block( grad_old_out.data(), this->n_basis, this->n_basis, - this->n_band); + this->n_band_l); } template @@ -165,51 +158,12 @@ void DiagoBPCG::orth_projection( ct::Tensor& hsub_in, ct::Tensor& grad_out) { - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); - // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_in.data(), - this->n_basis, //lda - grad_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_in.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_in.data(), this->n_band * this->n_band); - - // set_matrix_op()('L', hsub_in->data(), this->n_band); - option = ct::EinsumOption( - /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); - - // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_dim, //m - this->n_band, //n - this->n_band, //k - this->neg_one, //-1.0 - psi_in.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->one, //1.0 - grad_out.data(), - this->n_basis); //ldc - - // * This type of non inner product like operation does not need reduce! 
- + this->pmmcn.multiply(1.0, psi_in.data(), grad_out.data(), 0.0, hsub_in.data()); + + // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x + // n_band) + this->plintrans.act(-1.0, psi_in.data(), hsub_in.data(), 1.0, grad_out.data()); return; } @@ -219,28 +173,8 @@ void DiagoBPCG::rotate_wf( ct::Tensor& psi_out, ct::Tensor& workspace_in) { - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in); - // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); - // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_basis, //m - this->n_band, //n - this->n_band, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->zero, //0.0 - workspace_in.data(), - this->n_basis); //ldc - - // * This type of non inner product like operation does not need reduce! 
- + this->plintrans.act(1.0, psi_out.data(), hsub_in.data(), 0.0, workspace_in.data()); syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); return; @@ -263,30 +197,8 @@ void DiagoBPCG::diag_hsub( ct::Tensor& hsub_out, ct::Tensor& eigenvalue_out) { - // calculate all-band hsub - // Note: ctx is nothing but the devices used in this class (Device * ctx = nullptr;), - // it controls the ops to use the corresponding device to calculate results - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); - // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - hpsi_in.data(), - this->n_basis, //lda - psi_in.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_out.data(), this->n_band * this->n_band); + this->pmmcn.multiply(1.0, hpsi_in.data(), psi_in.data(), 0.0, hsub_out.data()); ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index 90907de5e9..243ac8b44f 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -1,18 +1,18 @@ #ifndef DIAGO_BPCG_H_ #define DIAGO_BPCG_H_ +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/module_device/memory_op.h" +#include "module_base/module_device/types.h" +#include "module_base/para_gemm.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" - -#include "module_base/module_device/types.h" -#include "module_base/module_device/memory_op.h" - -#include "module_base/kernels/math_kernel_op.h" 
#include "module_hsolver/kernels/dngvd_op.h" -#include +#include "module_hsolver/para_linear_transform.h" #include #include +#include namespace hsolver { @@ -50,11 +50,12 @@ class DiagoBPCG * This function allocates all the related variables, such as hpsi, hsub, before the diag call. * It is called by the HsolverPW::initDiagh() function. * - * @param nband The number of bands. + * @param nband The number of bands of all processes. + * @param nband_l The number of bands of current process. * @param nbasis The number of basis functions. Leading dimension of psi. * @param ndim The number of valid dimension of psi. */ - void init_iter(const int nband, const int nbasis, const int ndim); + void init_iter(const int nband, const int nband_l, const int nbasis, const int ndim); using HPsiFunc = std::function; @@ -74,14 +75,20 @@ class DiagoBPCG const std::vector& ethr_band); private: - /// the number of rows of the input psi + /// the number of bands of all processes int n_band = 0; + /// the number of bands of current process + int n_band_l = 0; /// the number of cols of the input psi int n_basis = 0; /// valid dimension of psi int n_dim = 0; /// max iter steps for all-band cg loop int nline = 4; + /// parallel matrix multiplication + ModuleBase::PGemmCN pmmcn; + PLinearTransform plintrans; + ct::DataType r_type = ct::DataType::DT_INVALID; ct::DataType t_type = ct::DataType::DT_INVALID; @@ -185,7 +192,6 @@ class DiagoBPCG * psi_out[dim: n_basis x n_band, column major, lda = n_basis_max], * * @param hsub_in Subspace matrix input, dim [n_basis, n_band] with column major. - * @param workspace_in Workspace matrix, dim [n_basis, n_band] with column major.. * @param psi_out output wavefunction matrix with dim [n_basis, n_band], column major. 
*/ void rotate_wf( diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 05ccc8acd0..02f44f857b 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -477,7 +477,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, } else if (this->method == "bpcg") { - const int nband = psi.get_nbands(); + const int nband_l = psi.get_nbands(); const int nbasis = psi.get_nbasis(); const int ndim = psi.get_current_ngk(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec @@ -496,7 +496,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ModuleBase::timer::tick("DavSubspace", "hpsi_func"); }; DiagoBPCG bpcg(pre_condition.data()); - bpcg.init_iter(nband, nbasis, ndim); + bpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 0b964a9388..b2df935ad6 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -46,7 +46,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, { ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS"); pHamilt->updateHk(ik); - if (nbands > 0 && GlobalV::MY_BNDGROUP == 0) + if (nbands > 0 && PARAM.globalv.ks_run) { /// update psi pointer for each k point psi.fix_k(ik); @@ -58,10 +58,10 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, } #ifdef __MPI - if (nbands > 0 && PARAM.inp.bndpar > 1) + if (nbands > 0 && !PARAM.globalv.all_ks_run) { - Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); - MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); + Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, &psi_cpu(ik, 0, 0)); + MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, BP_WORLD); } #endif 
ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS"); @@ -89,17 +89,10 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, // calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS elecstate::ElecStatePW* pes_pw = static_cast*>(pes); - if (GlobalV::MY_BNDGROUP == 0) + pes_pw->calEBand(); + if(!PARAM.globalv.all_ks_run) { - pes_pw->calEBand(); - } - if (nbands > 0) - { -#ifdef __MPI - pes->f_en.eband /= GlobalV::NPROC_IN_POOL; - MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.eband, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); - MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD); -#endif + pes->f_en.eband /= PARAM.inp.bndpar; } stoiter.sum_stoeband(stowf, pes_pw, pHamilt, wfc_basis); diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp index 4180d11ff1..f3a0c60d00 100644 --- a/source/module_hsolver/para_linear_transform.cpp +++ b/source/module_hsolver/para_linear_transform.cpp @@ -1,126 +1,148 @@ #include "para_linear_transform.h" -#include + #include +#include namespace hsolver { template -void para_linear_transform_op::operator()(T* A, - const T alpha, - const T beta, - const T* U_global, - const int& nrow, - const int& LDA, - const int& ncol_loc, - const int& ncol_glo, +void PLinearTransform::set_dimension(const int nrowA, + const int ncolA, + const int ncolB, + const int LDA, #ifdef __MPI - MPI_Comm col_world, + MPI_Comm col_world, #endif - const int rank_col, - const int nproc_col - -) + const bool localU) { - const Device* ctx = {}; + this->nrowA = nrowA; + this->ncolA = ncolA; + this->ncolB = ncolB; + this->LDA = LDA; #ifdef __MPI + this->col_world = col_world; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); if (nproc_col > 1) { - std::vector colA_loc(nproc_col); - MPI_Allgather(&ncol_loc, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); - std::vector start_col(nproc_col); - start_col[0] = 0; + this->localU = localU; + colA_loc.resize(nproc_col); + 
MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + start_colA.resize(nproc_col); + start_colA[0] = 0; for (int ip = 1; ip < nproc_col; ++ip) { - start_col[ip] = start_col[ip - 1] + colA_loc[ip - 1]; + start_colA[ip] = start_colA[ip - 1] + colA_loc[ip - 1]; } - int max_col = *std::max_element(colA_loc.begin(), colA_loc.end()); - std::vector requests(nproc_col); + this->ncolA_glo = start_colA[nproc_col - 1] + colA_loc[nproc_col - 1]; + this->max_colA = *std::max_element(colA_loc.begin(), colA_loc.end()); - std::vector A_tmp(max_col * LDA); + std::vector colB_loc(nproc_col); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + start_colB.resize(nproc_col); + start_colB[0] = 0; + for (int ip = 1; ip < nproc_col; ++ip) + { + start_colB[ip] = start_colB[ip - 1] + colB_loc[ip - 1]; + } + this->max_colB = *std::max_element(colB_loc.begin(), colB_loc.end()); + } +#else + nproc_col = 1; + rank_col = 0; +#endif +} +template +void PLinearTransform::act(const T alpha, const T* A, const T* U, const T beta, T* B) +{ + const Device* ctx = {}; +#ifdef __MPI + if (nproc_col > 1) + { + std::vector requests(nproc_col); + std::vector A_tmp(max_colA * LDA); T* A_tmp_device = A_tmp.data(); if (std::is_same::value) { A_tmp_device = nullptr; - resmem_dev_op()(A_tmp_device, max_col * LDA); + resmem_dev_op()(A_tmp_device, max_colA * LDA); } - T* A_tmp2 = nullptr; - resmem_dev_op()(A_tmp2, ncol_loc * LDA); - syncmem_dev_op()(A_tmp2, A, ncol_loc * LDA); - T* A_sum = nullptr; - resmem_dev_op()(A_sum, ncol_loc * LDA); - setmem_dev_op()(A_sum, 0.0, ncol_loc * LDA); + T* B_tmp = nullptr; + resmem_dev_op()(B_tmp, ncolB * LDA); + syncmem_dev_op()(B_tmp, B, ncolB * LDA); + setmem_dev_op()(B, 0.0, ncolB * LDA); + + T* U_tmp = nullptr; + resmem_dev_op()(U_tmp, max_colA * max_colB); // Send for (int ip = 0; ip < nproc_col; ++ip) { if (rank_col != ip) { - int size = LDA * ncol_loc; + int size = LDA * ncolA; Parallel_Common::isend_dev(A, size, ip, 0, 
col_world, &requests[ip], A_tmp.data()); } } // Receive - T* U_local = nullptr; - resmem_dev_op()(U_local, max_col * ncol_loc); - const int start = start_col[rank_col]; + const int start = this->localU ? 0 : start_colB[rank_col]; for (int ip = 0; ip < nproc_col; ++ip) { T real_beta = ip == 0 ? beta : 0; - const int start_row = start_col[ip]; - const int ncol_ip = colA_loc[ip]; - // get U_local - for (int i = 0; i < ncol_loc; ++i) - { - const T* U_glo_tmp = U_global + start_row + (i + start) * ncol_glo; - syncmem_dev_op()(U_local + i * ncol_ip, U_glo_tmp, ncol_ip); - } + const int ncolA_ip = colA_loc[ip]; + // get U_tmp + + const int start_row = start_colA[ip]; + for (int i = 0; i < ncolB; ++i) + { + const T* U_part = U + start_row + (i + start) * ncolA_glo; + syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip); + } if (ip == rank_col) { ModuleBase::gemm_op()(ctx, 'N', 'N', - nrow, - ncol_loc, - ncol_ip, + nrowA, + ncolB, + ncolA_ip, &alpha, A, LDA, - U_local, - ncol_ip, + U_tmp, + ncolA_ip, &real_beta, - A_tmp2, + B_tmp, LDA); } else { - int size = LDA * ncol_ip; + int size = LDA * ncolA_ip; MPI_Status status; Parallel_Common::recv_dev(A_tmp_device, size, ip, 0, col_world, &status, A_tmp.data()); MPI_Wait(&requests[ip], &status); ModuleBase::gemm_op()(ctx, 'N', 'N', - nrow, - ncol_loc, - ncol_ip, + nrowA, + ncolB, + ncolA_ip, &alpha, A_tmp_device, LDA, - U_local, - ncol_ip, + U_tmp, + ncolA_ip, &real_beta, - A_tmp2, + B_tmp, LDA); } // sum all the results T one = 1.0; - ModuleBase::axpy_op()(ctx, ncol_loc * LDA, &one, A_tmp2, 1, A_sum, 1); + ModuleBase::axpy_op()(ctx, ncolB * LDA, &one, B_tmp, 1, B, 1); } - syncmem_dev_op()(A, A_sum, ncol_loc * LDA); - delmem_dev_op()(U_local); - delmem_dev_op()(A_tmp2); - delmem_dev_op()(A_sum); + delmem_dev_op()(U_tmp); + delmem_dev_op()(B_tmp); if (std::is_same::value) { delmem_dev_op()(A_tmp_device); @@ -129,33 +151,29 @@ void para_linear_transform_op::operator()(T* A, else #endif { - T* A_tmp = nullptr; - 
resmem_dev_op()(A_tmp, LDA * ncol_glo); - syncmem_dev_op()(A_tmp, A, LDA * ncol_loc); ModuleBase::gemm_op()(ctx, 'N', 'N', - nrow, - ncol_glo, - ncol_glo, + nrowA, + ncolB, + ncolA, &alpha, - A_tmp, + A, LDA, - U_global, - ncol_glo, + U, + ncolA, &beta, - A, + B, LDA); - delmem_dev_op()(A_tmp); } }; -template struct para_linear_transform_op; -template struct para_linear_transform_op, base_device::DEVICE_CPU>; -template struct para_linear_transform_op, base_device::DEVICE_CPU>; +template struct PLinearTransform; +template struct PLinearTransform, base_device::DEVICE_CPU>; +template struct PLinearTransform, base_device::DEVICE_CPU>; #if ((defined __CUDA) || (defined __ROCM)) -template struct para_linear_transform_op; -template struct para_linear_transform_op, base_device::DEVICE_GPU>; -template struct para_linear_transform_op, base_device::DEVICE_GPU>; +template struct PLinearTransform; +template struct PLinearTransform, base_device::DEVICE_GPU>; +template struct PLinearTransform, base_device::DEVICE_GPU>; #endif } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/para_linear_transform.h b/source/module_hsolver/para_linear_transform.h index b94dd763ab..42cb02fb47 100644 --- a/source/module_hsolver/para_linear_transform.h +++ b/source/module_hsolver/para_linear_transform.h @@ -4,52 +4,73 @@ #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" #include "module_base/parallel_device.h" +#include #ifdef __MPI #include "mpi.h" #endif namespace hsolver { +/** + * @brief B = alpha * A * U + beta * B + * A and B are local matrice + * U can be a local matrix or a global matrix + */ template -struct para_linear_transform_op +class PLinearTransform { + public: using syncmem_dev_op = base_device::memory::synchronize_memory_op; using resmem_dev_op = base_device::memory::resize_memory_op; using setmem_dev_op = base_device::memory::set_memory_op; using delmem_dev_op = base_device::memory::delete_memory_op; + 
int nproc_col = 1; + int rank_col = 0; + int nrowA = 0; + int ncolA = 0; + int ncolB = 0; + int LDA = 0; + bool localU = false; +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_NULL; + std::vector colA_loc; + std::vector start_colA; + std::vector start_colB; + int max_colA = 0; + int ncolA_glo = 0; + int max_colB = 0; +#endif + + /** + * @brief set the dimension of A, B, and U + * A: LDA * nrow, U_global: ncolA_global * ncolB_global, U_local: ncolA_global * ncolB + * B: LDA * ncolB + */ + void set_dimension(const int nrowA, + const int ncolA, + const int ncolB, + const int LDA, +#ifdef __MPI + MPI_Comm col_world, +#endif + const bool localU); + /** - * @brief A_global = alpha * A_global * U_global + beta * A_global - * A is a local matrix with nrow rows and ncol_loc columns - * U_global is a matrix with ncol_glo rows and ncol_glo columns - * @example rotate wave functions: A = A * U - * orthogonalize wave functions: A = A - A * U + * @brief B = alpha * A * U + beta * B + * A is a local matrix with nrow rows and ncolA_loc columns + * B is a local matrix with nrow rows and ncolB_loc columns + * U can be a local matrix or a global matrix + * @example rotate wave functions: B = A * U + * orthogonalize wave functions: B = - A * U + B * - * @param A : input/output matrix * @param alpha : alpha - * @param beta : beta + * @param A : input matrix * @param U_global : input matrix - * @param nrow : number of rows of A - * @param LDA : leading dimension of A - * @param ncol_loc : number of columns of A - * @param ncol_glo : number of columns and rows of U_global - * @param col_world : column communicator world - * @param rank_col : rank of col_world - * @param nproc_col : number of processes in col_world + * @param beta : beta + * @param B : input/output matrix * */ - void operator()(T* A, - const T alpha, - const T beta, - const T* U_global, - const int& nrow, - const int& LDA, - const int& ncol_loc, - const int& ncol_glo, -#ifdef __MPI - MPI_Comm col_world, -#endif - const int 
rank_col, - const int nproc_col); + void act(const T alpha, const T* A, const T* U_global, const T beta, T* B); }; } // namespace hsolver #endif \ No newline at end of file diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index e6af8b5b5e..369fcb4ec6 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -154,7 +154,7 @@ class DiagoBPCGPrepare hpsi_out, ld_psi); }; const int ndim = psi_local.get_current_ngk(); - bpcg.init_iter(nband, npw, ndim); + bpcg.init_iter(nband, nband, npw, ndim); std::vector ethr_band(nband, 1e-5); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 3d69cacf2f..d1f7bbca13 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -339,10 +339,10 @@ int main(int argc, char** argv) MPI_Comm_size(MPI_COMM_WORLD, &GlobalV::NPROC); MPI_Comm_rank(MPI_COMM_WORLD, &GlobalV::MY_RANK); - MPI_Comm_split(MPI_COMM_WORLD, 0, 1, &PARAPW_WORLD); + MPI_Comm_split(MPI_COMM_WORLD, 0, 1, &BP_WORLD); int result = RUN_ALL_TESTS(); - MPI_Comm_free(&PARAPW_WORLD); + MPI_Comm_free(&BP_WORLD); MPI_Finalize(); return result; diff --git a/source/module_hsolver/test/test_para_linear_trans.cpp b/source/module_hsolver/test/test_para_linear_trans.cpp index 3acac64435..50b533a9fd 100644 --- a/source/module_hsolver/test/test_para_linear_trans.cpp +++ b/source/module_hsolver/test/test_para_linear_trans.cpp @@ -5,12 +5,20 @@ #include #endif -void random_data(std::vector& A_global, std::vector& U_global, double& alpha, double& beta) +void random_data(std::vector& A_global, + std::vector& B_global, + std::vector& U_global, + double& alpha, + double& beta) { for (auto& val: A_global) { val = std::rand() / 
(RAND_MAX + 1.0); } + for (auto& val: B_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } for (auto& val: U_global) { val = std::rand() / (RAND_MAX + 1.0); @@ -19,6 +27,7 @@ void random_data(std::vector& A_global, std::vector& U_global, d beta = std::rand() / (RAND_MAX + 1.0); } void random_data(std::vector>& A_global, + std::vector>& B_global, std::vector>& U_global, std::complex& alpha, std::complex& beta) @@ -27,6 +36,10 @@ void random_data(std::vector>& A_global, { val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); } + for (auto& val: B_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } for (auto& val: U_global) { val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); @@ -54,58 +67,67 @@ class ParaLinearTransformTest : public ::testing::Test void TearDown() override { } - void prepare(const int nrow, const int ncol_glo, const int LDA) + void prepare(const int nrow, const int ncolA_glo, const int ncolB_glo, const int LDA) { int rank = 0; int nproc = 1; int colA_start = 0; - this->ncol_glo = ncol_glo; - this->ncol_loc = ncol_glo; + int colB_start = 0; + this->ncolA_glo = ncolA_glo; + this->ncolB_glo = ncolB_glo; + this->ncolA_loc = ncolA_glo; + this->ncolB_loc = ncolB_glo; #ifdef __MPI MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &nproc); - this->ncol_loc = ncol_glo / nproc; - if (rank < ncol_glo % nproc) + this->ncolA_loc = ncolA_glo / nproc; + this->ncolB_loc = ncolB_glo / nproc; + if (rank < ncolA_glo % nproc) { - ncol_loc++; + ncolA_loc++; + ncolB_loc++; } std::vector ncolA_ip(nproc); - MPI_Allgather(&ncol_loc, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); + std::vector ncolB_ip(nproc); + MPI_Allgather(&ncolA_loc, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); + MPI_Allgather(&ncolB_loc, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); for (int i = 0; i < rank; ++i) { colA_start += 
ncolA_ip[i]; + colB_start += ncolB_ip[i]; } #endif - A_global.resize(LDA * ncol_glo); - A_global_ref.resize(LDA * ncol_glo); - U_global.resize(ncol_glo * ncol_glo); + A_global.resize(LDA * ncolA_glo); + B_global.resize(LDA * ncolB_glo); + B_global_ref.resize(LDA * ncolB_glo); + U_global.resize(ncolA_glo * ncolB_glo); if (rank == 0) { - random_data(A_global, U_global, alpha, beta); - A_global_ref = A_global; - std::vector A_global_tmp = A_global; + random_data(A_global, B_global, U_global, alpha, beta); + B_global_ref = B_global; const base_device::DEVICE_CPU* ctx = {}; ModuleBase::gemm_op()(ctx, 'N', 'N', nrow, - ncol_glo, - ncol_glo, + ncolB_glo, + ncolA_glo, &alpha, - A_global_tmp.data(), + A_global.data(), LDA, U_global.data(), - ncol_glo, + ncolA_glo, &beta, - A_global_ref.data(), + B_global_ref.data(), LDA); } if (std::is_same::value) { #ifdef __MPI MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); - MPI_Bcast(A_global_ref.data(), A_global_ref.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global_ref.data(), B_global_ref.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); MPI_Bcast(&beta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); #endif @@ -114,42 +136,53 @@ class ParaLinearTransformTest : public ::testing::Test { #ifdef __MPI MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); - MPI_Bcast(A_global_ref.data(), A_global_ref.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global_ref.data(), B_global_ref.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); MPI_Bcast(&alpha, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); MPI_Bcast(&beta, 1, 
MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); #endif } - A.resize(LDA * ncol_loc); - A_ref.resize(LDA * ncol_loc); - for (int i = 0; i < LDA * ncol_loc; ++i) + A.resize(LDA * ncolA_loc); + B.resize(LDA * ncolB_loc); + B_ref.resize(LDA * ncolB_loc); + for (int i = 0; i < LDA * ncolA_loc; ++i) { A[i] = A_global[colA_start * LDA + i]; - A_ref[i] = A_global_ref[colA_start * LDA + i]; + } + for (int i = 0; i < LDA * ncolB_loc; ++i) + { + B[i] = B_global[colB_start * LDA + i]; + B_ref[i] = B_global_ref[colB_start * LDA + i]; } } - std::vector A; - std::vector A_ref; + std::vector A, B; + std::vector B_ref; std::vector A_global; + std::vector B_global; std::vector U_global; - std::vector A_global_ref; - int ncol_glo = 1; - int ncol_loc = 1; + std::vector B_global_ref; + int ncolA_glo = 1; + int ncolB_glo = 1; + int ncolA_loc = 1; + int ncolB_loc = 1; T alpha; T beta; + hsolver::PLinearTransform lt; }; typedef ::testing::Types> MyTypes; TYPED_TEST_SUITE(ParaLinearTransformTest, MyTypes); -TYPED_TEST(ParaLinearTransformTest, cpucase) +TYPED_TEST(ParaLinearTransformTest, globalU) { - const int nrow = 7; - const int ncol_glo = 13; + const int nrowA = 7; + const int ncolA_glo = 13; + const int ncolB_glo = 11; const int LDA = 9; - this->prepare(nrow, ncol_glo, LDA); + this->prepare(nrowA, ncolA_glo, ncolB_glo, LDA); int rank_col = 0, nproc_col = 1; #ifdef __MPI MPI_Comm col_world = MPI_COMM_WORLD; @@ -157,28 +190,61 @@ TYPED_TEST(ParaLinearTransformTest, cpucase) MPI_Comm_size(col_world, &nproc_col); #endif - hsolver::para_linear_transform_op()(this->A.data(), - this->alpha, - this->beta, - this->U_global.data(), - nrow, - LDA, - this->ncol_loc, - ncol_glo, + this->lt.set_dimension(nrowA, + this->ncolA_loc, + this->ncolB_loc, + LDA, #ifdef __MPI - col_world, + col_world, #endif - rank_col, - nproc_col); + false); + this->lt.act(this->alpha, this->A.data(), this->U_global.data(), this->beta, this->B.data()); + + for (int i = 0; i < this->ncolB_loc; ++i) + { + for (int j = 0; j < 
nrowA; ++j) + { + EXPECT_NEAR(get_double(this->B[j + i * LDA]), get_double(this->B_ref[j + i * LDA]), 1e-10); + } + } +} +#ifdef __MPI +TYPED_TEST(ParaLinearTransformTest, localU) +{ + const int nrowA = 7; + const int ncolA_glo = 13; + const int ncolB_glo = 11; + const int LDA = 9; + + this->prepare(nrowA, ncolA_glo, ncolB_glo, LDA); + int rank_col = 0, nproc_col = 1; + + MPI_Comm col_world = MPI_COMM_WORLD; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); + std::vector ncolB_ip(nproc_col); + std::vector start_colB(nproc_col); + MPI_Allgather(&this->ncolB_loc, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, col_world); + start_colB[0] = 0; + for (int i = 1; i < nproc_col; ++i) + { + start_colB[i] = start_colB[i - 1] + ncolB_ip[i - 1]; + } + int start = start_colB[rank_col]; - for (int i = 0; i < this->ncol_loc; ++i) + this->lt.set_dimension(nrowA, this->ncolA_loc, this->ncolB_loc, LDA, col_world, true); + + this->lt.act(this->alpha, this->A.data(), this->U_global.data() + start * ncolA_glo, this->beta, this->B.data()); + + for (int i = 0; i < this->ncolB_loc; ++i) { - for (int j = 0; j < nrow; ++j) + for (int j = 0; j < nrowA; ++j) { - EXPECT_NEAR(get_double(this->A[j + i * LDA]), get_double(this->A_ref[j + i * LDA]), 1e-10); + EXPECT_NEAR(get_double(this->B[j + i * LDA]), get_double(this->B_ref[j + i * LDA]), 1e-10); } } } +#endif int main(int argc, char** argv) { diff --git a/source/module_io/cal_dos.cpp b/source/module_io/cal_dos.cpp index db9c374811..8966a0e95c 100644 --- a/source/module_io/cal_dos.cpp +++ b/source/module_io/cal_dos.cpp @@ -4,6 +4,7 @@ #include "module_base/global_function.h" #include "module_base/global_variable.h" #include "module_base/parallel_reduce.h" +#include "module_parameter/parameter.h" bool ModuleIO::calculate_dos(const int& is, const std::string& fa, // file address for DOS @@ -102,7 +103,8 @@ bool ModuleIO::calculate_dos(const int& is, } } #ifdef __MPI - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, 
GlobalV::NPROC_IN_POOL, count); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, count); #endif count = count / static_cast(nkstot); sum += count; diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp index a26c2a46d2..d127837abf 100755 --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -60,6 +60,16 @@ void ReadInput::set_globalv(const Input_para& inp, System_para& sys) #ifdef __MPI Parallel_Common::bcast_bool(sys.double_grid); #endif + /// set ks_run + if (GlobalV::MY_BNDGROUP == 0 || inp.ks_solver == "bpcg") + { + sys.ks_run = true; + } + if (inp.ks_solver != "bpcg" && inp.bndpar > 1) + { + sys.all_ks_run = false; + } + } /// @note Here para.inp has been synchronized of all ranks. diff --git a/source/module_parameter/system_parameter.h b/source/module_parameter/system_parameter.h index d1258ffc6b..e41ed6b823 100755 --- a/source/module_parameter/system_parameter.h +++ b/source/module_parameter/system_parameter.h @@ -54,6 +54,8 @@ struct System_para double uramping = -10.0 / 13.6; /// U-Ramping method (Ry) std::vector hubbard_u = {}; ///< Hubbard Coulomb interaction parameter U (Ry) int kpar_lcao = 1; ///< global number of pools for LCAO diagonalization only - int nbands_l = 0; ///< number of bands of each band parallel calculation + int nbands_l = 0; ///< number of bands of each band parallel calculation, same to nbands when bndpar=1 + bool ks_run = false; ///< true if current process runs KS calculation + bool all_ks_run = true; ///< true if only all processes run KS calculation }; #endif \ No newline at end of file diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 085cb381c8..fcd07622c2 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -86,7 +86,7 @@ void PSIInit::initialize_psi(Psi>* psi, hamilt::Hamilt* p_hamilt, std::ofstream& ofs_running) { 
- if (kspw_psi->get_nbands() == 0 || (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg")) + if (kspw_psi->get_nbands() == 0 || (!PARAM.globalv.ks_run)) { return; } From 7ddd28fa6a5229e890f2f136194de299790fedd5 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Mon, 27 Jan 2025 15:07:33 +0800 Subject: [PATCH 07/17] fix BPCG --- source/module_cell/cal_atoms_info.h | 7 +- source/module_elecstate/elecstate_print.cpp | 4 +- .../hamilt_stodft/sto_iter.cpp | 65 +++++++++++++++---- source/module_hsolver/diago_bpcg.cpp | 26 +++++--- .../module_hsolver/para_linear_transform.cpp | 24 +++---- source/module_io/read_set_globalv.cpp | 5 -- source/module_io/write_istate_info.cpp | 2 +- tests/integrate/102_PW_BPCG_BP/INPUT | 33 ++++++++++ tests/integrate/102_PW_BPCG_BP/KPT | 4 ++ tests/integrate/102_PW_BPCG_BP/README | 10 +++ tests/integrate/102_PW_BPCG_BP/STRU | 23 +++++++ tests/integrate/102_PW_BPCG_BP/result.ref | 8 +++ tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT | 38 +++++++++++ tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT | 4 ++ tests/integrate/184_PW_BPCG_SDFT_5D10S/README | 9 +++ tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU | 19 ++++++ tests/integrate/184_PW_BPCG_SDFT_5D10S/jd | 1 + .../184_PW_BPCG_SDFT_5D10S/result.ref | 5 ++ 18 files changed, 244 insertions(+), 43 deletions(-) create mode 100644 tests/integrate/102_PW_BPCG_BP/INPUT create mode 100644 tests/integrate/102_PW_BPCG_BP/KPT create mode 100644 tests/integrate/102_PW_BPCG_BP/README create mode 100644 tests/integrate/102_PW_BPCG_BP/STRU create mode 100644 tests/integrate/102_PW_BPCG_BP/result.ref create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/README create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/jd create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref diff --git 
a/source/module_cell/cal_atoms_info.h b/source/module_cell/cal_atoms_info.h index 567140484a..eeb10823c9 100644 --- a/source/module_cell/cal_atoms_info.h +++ b/source/module_cell/cal_atoms_info.h @@ -73,11 +73,16 @@ class CalAtomsInfo if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel { para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; - if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar) + if (GlobalV::MY_BNDGROUP < para.inp.nbands % para.inp.bndpar) { para.sys.nbands_l++; } } + // temporary code + if (GlobalV::MY_BNDGROUP == 0 || para.inp.ks_solver == "bpcg") + { + para.sys.ks_run = true; + } return; } }; diff --git a/source/module_elecstate/elecstate_print.cpp b/source/module_elecstate/elecstate_print.cpp index 92c49427cc..8a5fdce8a5 100644 --- a/source/module_elecstate/elecstate_print.cpp +++ b/source/module_elecstate/elecstate_print.cpp @@ -247,7 +247,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter) { // check the band energy. 
bool wrong = false; - for (int ib = 0; ib < PARAM.inp.nbands; ++ib) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ++ib) { if (std::abs(this->ekb(ik, ib)) > 1.0e10) { @@ -269,7 +269,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter) GlobalV::ofs_running << " Energy (eV) & Occupations for spin=" << this->klist->isk[ik] + 1 << " K-point=" << ik + 1 << std::endl; GlobalV::ofs_running << std::setiosflags(std::ios::showpoint); - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { GlobalV::ofs_running << " " << std::setw(6) << ib + 1 << std::setw(15) << this->ekb(ik, ib) * ModuleBase::Ry_to_eV; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 06a535155a..4c0aff7469 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -58,7 +58,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, { ModuleBase::TITLE("Stochastic_Iter", "orthog"); ModuleBase::timer::tick("Stochastic_Iter", "orthog"); - const int nbands_l = psi.get_nbands(); + int nbands_l = psi.get_nbands(); const int nbands = PARAM.inp.nbands; // orthogonal part if (nbands > 0) @@ -74,24 +74,63 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, // orthogonal part T* sum = nullptr; resmem_complex_op()(sum, nbands * nchipk); - // sum(b - ModuleBase::PGemmCN pmmcn; + + if(PARAM.globalv.all_ks_run) + { + // sum(b + ModuleBase::PGemmCN pmmcn; #ifdef __MPI - pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2); + pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2); #else - pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2); + pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2); #endif - pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum); - - // psi -= psi * sum - 
hsolver::PLinearTransform pltrans; + pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum); + + // psi -= psi * sum + hsolver::PLinearTransform pltrans; #ifdef __MPI - pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true); + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true); #else - pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true); + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true); #endif - pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout); - + pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout); + } + else + { + // sum(b + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nbands, + nchipk, + npw, + &ModuleBase::ONE, + &psi(ik, 0, 0), + npwx, + wfgout, + npwx, + &ModuleBase::ZERO, + sum, + nbands); + Parallel_Reduce::reduce_pool(sum, nbands * nchipk); + + // psi -= psi * sum + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + npw, + nchipk, + nbands, + &ModuleBase::NEG_ONE, + &psi(ik, 0, 0), + npwx, + sum, + nbands, + &ModuleBase::ONE, + wfgout, + npwx); + } + delmem_complex_op()(sum); } ModuleBase::timer::tick("Stochastic_Iter", "orthog"); diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index bd9b9326f2..743141ff5a 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -12,6 +12,7 @@ #include "module_base/global_function.h" #include "module_base/kernels/math_kernel_op.h" #include "para_linear_transform.h" +#include "module_parameter/parameter.h" namespace hsolver { @@ -44,9 +45,9 @@ void DiagoBPCG::init_iter(const int nband, const int nband_l, const i // All column major tensors - this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band})); + this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band_l})); this->eigen = std::move(ct::Tensor(r_type, device_type, {this->n_band})); - this->err_st = std::move(ct::Tensor(r_type, device_type, {this->n_band})); + this->err_st = std::move(ct::Tensor(r_type, device_type, 
{this->n_band_l})); this->hsub = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_band})); @@ -175,7 +176,7 @@ void DiagoBPCG::rotate_wf( { // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) this->plintrans.act(1.0, psi_out.data(), hsub_in.data(), 0.0, workspace_in.data()); - syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); + syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band_l * this->n_basis); return; } @@ -187,7 +188,7 @@ void DiagoBPCG::calc_hpsi_with_block( ct::Tensor& hpsi_out) { // calculate all-band hpsi - hpsi_func(psi_in, hpsi_out.data(), this->n_basis, this->n_band); + hpsi_func(psi_in, hpsi_out.data(), this->n_basis, this->n_band_l); } template @@ -256,7 +257,7 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, { const int current_scf_iter = hsolver::DiagoIterAssist::SCF_ITER; // Get the pointer of the input psi - this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis})); + this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band_l, this->n_basis})); // Update the precondition array this->calc_prec(); @@ -264,9 +265,9 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, // Improving the initial guess of the wave function psi through a subspace diagonalization. this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); - setmem_complex_op()(this->grad_old.template data(), 0, this->n_basis * this->n_band); + setmem_complex_op()(this->grad_old.template data(), 0, this->n_basis * this->n_band_l); - setmem_var_op()(this->beta.template data(), std::numeric_limits::infinity(), this->n_band); + setmem_var_op()(this->beta.template data(), std::numeric_limits::infinity(), this->n_band_l); int ntry = 0; int max_iter = current_scf_iter > 1 ? 
@@ -290,7 +291,7 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, this->orth_projection(this->psi, this->hsub, this->grad); // this->grad_old = this->grad; - syncmem_complex_op()(this->grad_old.template data(), this->grad.template data(), n_basis * n_band); + syncmem_complex_op()(this->grad_old.template data(), this->grad.template data(), n_basis * n_band_l); // Calculate H|grad> matrix this->calc_hpsi_with_block(hpsi_func, this->grad.template data(), /*this->grad_wrapper[0],*/ this->hgrad); @@ -311,7 +312,14 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen); - syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data(), this->n_band); + int start_nband = 0; +#ifdef __MPI + if (PARAM.inp.bndpar > 1) + { + start_nband = this->plintrans.start_colB[GlobalV::MY_BNDGROUP]; + } +#endif + syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data() + start_nband, this->n_band_l); return; } diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp index f3a0c60d00..57b6be41a3 100644 --- a/source/module_hsolver/para_linear_transform.cpp +++ b/source/module_hsolver/para_linear_transform.cpp @@ -6,13 +6,13 @@ namespace hsolver { template void PLinearTransform::set_dimension(const int nrowA, - const int ncolA, - const int ncolB, - const int LDA, + const int ncolA, + const int ncolB, + const int LDA, #ifdef __MPI - MPI_Comm col_world, + MPI_Comm col_world, #endif - const bool localU) + const bool localU) { this->nrowA = nrowA; this->ncolA = ncolA; @@ -91,13 +91,13 @@ void PLinearTransform::act(const T alpha, const T* A, const T* U, con T real_beta = ip == 0 ? 
beta : 0; const int ncolA_ip = colA_loc[ip]; // get U_tmp - - const int start_row = start_colA[ip]; - for (int i = 0; i < ncolB; ++i) - { - const T* U_part = U + start_row + (i + start) * ncolA_glo; - syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip); - } + + const int start_row = start_colA[ip]; + for (int i = 0; i < ncolB; ++i) + { + const T* U_part = U + start_row + (i + start) * ncolA_glo; + syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip); + } if (ip == rank_col) { diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp index d127837abf..064bfa640e 100755 --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -61,15 +61,10 @@ void ReadInput::set_globalv(const Input_para& inp, System_para& sys) Parallel_Common::bcast_bool(sys.double_grid); #endif /// set ks_run - if (GlobalV::MY_BNDGROUP == 0 || inp.ks_solver == "bpcg") - { - sys.ks_run = true; - } if (inp.ks_solver != "bpcg" && inp.bndpar > 1) { sys.all_ks_run = false; } - } /// @note Here para.inp has been synchronized of all ranks. 
diff --git a/source/module_io/write_istate_info.cpp b/source/module_io/write_istate_info.cpp index b6d14711b1..34a4dba414 100644 --- a/source/module_io/write_istate_info.cpp +++ b/source/module_io/write_istate_info.cpp @@ -41,7 +41,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase: << std::setw(25) << "Kpoint = " << ik_global << std::setw(25) << "(" << kv.kvec_d[ik].x << " " << kv.kvec_d[ik].y << " " << kv.kvec_d[ik].z << ")" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { ofsi2.precision(16); ofsi2 << std::setw(6) << ib + 1 << std::setw(25) diff --git a/tests/integrate/102_PW_BPCG_BP/INPUT b/tests/integrate/102_PW_BPCG_BP/INPUT new file mode 100644 index 0000000000..43dd00ab9f --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/INPUT @@ -0,0 +1,33 @@ +INPUT_PARAMETERS +#Parameters (General) +suffix autotest +pseudo_dir ../../PP_ORB +pw_seed 1 + +gamma_only 0 +calculation scf +symmetry 1 +out_level ie +smearing_method gaussian +smearing_sigma 0.02 + +#Parameters (3.PW) +ecutwfc 40 +scf_thr 1e-7 +scf_nmax 20 +bndpar 2 + +#Parameters (LCAO) +basis_type pw +ks_solver bpcg +device cpu +chg_extrap second-order +out_dm 0 +pw_diag_thr 0.00001 + +cal_force 1 +cal_stress 1 + +mixing_type broyden +mixing_beta 0.4 +mixing_gg0 1.5 diff --git a/tests/integrate/102_PW_BPCG_BP/KPT b/tests/integrate/102_PW_BPCG_BP/KPT new file mode 100644 index 0000000000..28006d5e2d --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +2 2 2 0 0 0 diff --git a/tests/integrate/102_PW_BPCG_BP/README b/tests/integrate/102_PW_BPCG_BP/README new file mode 100644 index 0000000000..8b809e0a25 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/README @@ -0,0 +1,10 @@ +This test for: +*GaAs-deformation +*PW +*bndpar 2 +*kpoints 2*2*2 +*sg15 pseudopotential +*smearing_method gauss +*ks_solver bpcg +*mixing_type broyden-kerker +*mixing_beta 0.4 diff --git 
a/tests/integrate/102_PW_BPCG_BP/STRU b/tests/integrate/102_PW_BPCG_BP/STRU new file mode 100644 index 0000000000..b03baadd25 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/STRU @@ -0,0 +1,23 @@ +ATOMIC_SPECIES +As 1 As_dojo.upf upf201 +Ga 1 Ga_dojo.upf upf201 + +LATTICE_CONSTANT +1 // add lattice constant, 10.58 ang + +LATTICE_VECTORS +5.33 5.33 0.0 +0.0 5.33 5.33 +5.33 0.0 5.33 +ATOMIC_POSITIONS +Direct //Cartesian or Direct coordinate. + +As +0 +1 +0.300000 0.3300000 0.27000000 0 0 0 + +Ga //Element Label +0 +1 //number of atom +0.00000 0.00000 0.000000 0 0 0 diff --git a/tests/integrate/102_PW_BPCG_BP/result.ref b/tests/integrate/102_PW_BPCG_BP/result.ref new file mode 100644 index 0000000000..4815e05cb4 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/result.ref @@ -0,0 +1,8 @@ +etotref -4869.74705201 +etotperatomref -2434.87352600 +totalforceref 5.19522000 +totalstressref 37241.49490600 +pointgroupref C_1 +spacegroupref C_1 +nksibzref 8 +totaltimeref 10.37 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT b/tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT new file mode 100644 index 0000000000..dc7807efb9 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT @@ -0,0 +1,38 @@ +INPUT_PARAMETERS +#Parameters (1.General) +suffix autotest +calculation scf +esolver_type sdft +ks_solver bpcg +method_sto 1 + +symmetry 0 +pseudo_dir ../../PP_ORB + +bndpar 2 + +nbands 5 +nbands_sto 11 + +nche_sto 120 +seed_sto 20000 +kpar 1 +cal_force 1 +cal_stress 1 + +#Parameters (2.Iteration) +ecutwfc 20 +scf_thr 1e-4 +scf_nmax 20 + + +#Parameters (3.Basis) +basis_type pw + +#Parameters (4.Smearing) +smearing_method fd +smearing_sigma 0.6 + +#Parameters (5.Mixing) +mixing_type broyden +mixing_beta 0.4 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT b/tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT new file mode 100644 index 0000000000..c289c0158a --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +1 1 1 0 0 0 
diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/README b/tests/integrate/184_PW_BPCG_SDFT_5D10S/README new file mode 100644 index 0000000000..b150d66930 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/README @@ -0,0 +1,9 @@ +This test for: +*SDFT +*Si +*kpoints 1*1*1 +*10 stochastic orbitals + 5 KS orbitals +*mixing_type broyden +*mixing_beta 0.4 +*seed_sto > 0 +*bndpar 2 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU b/tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU new file mode 100644 index 0000000000..92de2f5eee --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU @@ -0,0 +1,19 @@ +ATOMIC_SPECIES +Si 28 Si.pz-vbc.UPF + +LATTICE_CONSTANT +8 // add lattice constant + +LATTICE_VECTORS +1 0 0 +0 1 0 +0 0 1 + +ATOMIC_POSITIONS +Direct + +Si // Element type +0.0 // magnetism +2 +0.00 0.00 0.00 1 1 1 +0.5 0.5 0.5 1 1 1 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/jd b/tests/integrate/184_PW_BPCG_SDFT_5D10S/jd new file mode 100644 index 0000000000..e87c7a3930 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/jd @@ -0,0 +1 @@ +test parallel method bndpar with BPCG. Must be run by MPI with 4 cores, otherwise please ignore this test. 
diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref b/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref new file mode 100644 index 0000000000..aa4f0a9afa --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref @@ -0,0 +1,5 @@ +etotref -323.3079867689583580 +etotperatomref -161.6539933845 +totalforceref 5.025088 +totalstressref 1830.665700 +totaltimeref 2.55 From 65e4ba199c539060978ae49266e96a9c3118837a Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Mon, 27 Jan 2025 16:47:40 +0800 Subject: [PATCH 08/17] fix bug in sDFT-BPCG --- source/module_elecstate/module_charge/charge_mpi.cpp | 2 +- source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_elecstate/module_charge/charge_mpi.cpp b/source/module_elecstate/module_charge/charge_mpi.cpp index e9731ee99e..b2a93aab60 100644 --- a/source/module_elecstate/module_charge/charge_mpi.cpp +++ b/source/module_elecstate/module_charge/charge_mpi.cpp @@ -119,7 +119,7 @@ void Charge::reduce_diff_pools(double* array_rho) const void Charge::rho_mpi() { ModuleBase::TITLE("Charge", "rho_mpi"); - if (GlobalV::KPAR <= 1) { + if (GlobalV::KPAR * PARAM.inp.bndpar <= 1) { return; } ModuleBase::timer::tick("Charge", "rho_mpi"); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 4c0aff7469..369bd52c41 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -668,7 +668,7 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, delmem_complex_op()(porter); #ifdef __MPI - if(GlobalV::KPAR > 1) + if(GlobalV::KPAR * PARAM.inp.bndpar > 1) { for (int is = 0; is < nspin; ++is) { From 30bc4ade45ec7beecbaf053947e861f684e7207b Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Mon, 27 Jan 2025 21:04:00 +0800 Subject: [PATCH 09/17] make sdft+bpcg support GPU --- source/module_base/para_gemm.cpp | 7 ++-- 
tests/integrate/102_PW_BPCG_GPU/INPUT | 1 + tests/integrate/102_PW_BPCG_GPU/result.ref | 10 ++--- .../integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT | 39 +++++++++++++++++++ tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT | 4 ++ .../187_PW_SDFT_MALL_BPCG_GPU/README | 12 ++++++ .../integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU | 19 +++++++++ tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd | 1 + .../187_PW_SDFT_MALL_BPCG_GPU/result.ref | 5 +++ tests/integrate/CASES_CPU.txt | 2 + tests/integrate/CASES_GPU.txt | 1 + 11 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd create mode 100644 tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp index 01eb7edc79..a7c6ef7d7d 100644 --- a/source/module_base/para_gemm.cpp +++ b/source/module_base/para_gemm.cpp @@ -241,7 +241,8 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con { T* Cglobal_cpu = nullptr; T* Clocal_cpu = C_tmp.data(); - ; + std::vector cpu_tmp; + if (std::is_same::value) { delmem_dev_op()(Atmp_device); @@ -249,7 +250,8 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); delmem_dev_op()(C_local); - resmem_dev_op()(Cglobal_cpu, size_C_global); + cpu_tmp.resize(size_C_global); + Cglobal_cpu = cpu_tmp.data(); } else { @@ -269,7 +271,6 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con if (std::is_same::value) { syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); - delmem_dev_op()(Cglobal_cpu); } } else diff --git a/tests/integrate/102_PW_BPCG_GPU/INPUT b/tests/integrate/102_PW_BPCG_GPU/INPUT index 
1f24602487..57da09d770 100644 --- a/tests/integrate/102_PW_BPCG_GPU/INPUT +++ b/tests/integrate/102_PW_BPCG_GPU/INPUT @@ -15,6 +15,7 @@ smearing_sigma 0.02 ecutwfc 40 scf_thr 1e-7 scf_nmax 100 +bndpar 2 #Parameters (LCAO) basis_type pw diff --git a/tests/integrate/102_PW_BPCG_GPU/result.ref b/tests/integrate/102_PW_BPCG_GPU/result.ref index 41371e10fd..4c3b56aed5 100644 --- a/tests/integrate/102_PW_BPCG_GPU/result.ref +++ b/tests/integrate/102_PW_BPCG_GPU/result.ref @@ -1,8 +1,8 @@ -etotref -4869.7470518350019120 -etotperatomref -2434.8735259175 -totalforceref 5.207670 -totalstressref 37241.465646 +etotref -4869.7470520365577613 +etotperatomref -2434.8735260183 +totalforceref 5.202524 +totalstressref 37241.827525 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 4.25 +totaltimeref 4.25 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT new file mode 100644 index 0000000000..e25e9f0f49 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT @@ -0,0 +1,39 @@ +INPUT_PARAMETERS +#Parameters (1.General) +suffix autotest +calculation scf +esolver_type sdft +method_sto 2 +device gpu +ks_solver bpcg + +symmetry 0 +pseudo_dir ../../PP_ORB + +kpar 1 +bndpar 2 + +nbands 9 +nbands_sto all + +nche_sto 120 +seed_sto 20000 +cal_force 1 +cal_stress 1 + +#Parameters (2.Iteration) +ecutwfc 30 +scf_thr 1e-6 +scf_nmax 20 + + +#Parameters (3.Basis) +basis_type pw + +#Parameters (4.Smearing) +smearing_method fd +smearing_sigma 0.6 + +#Parameters (5.Mixing) +mixing_type plain +mixing_beta 0.7 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT new file mode 100644 index 0000000000..c289c0158a --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +1 1 1 0 0 0 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README new file mode 100644 index 
0000000000..c1f7d02359 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README @@ -0,0 +1,12 @@ +This test for: +*SDFT +*Si +*kpoints 1*1*1 +*10 KS + complete orbitals +*ks_solver bpcg +*mixing_type plain +*mixing_beta 0.7 +*method_sto 2 +*seed_sto > 0 +*bndpar 2 +*kpar 1 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU new file mode 100644 index 0000000000..28a94e8c0c --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU @@ -0,0 +1,19 @@ +ATOMIC_SPECIES +Si 28 Si.pz-vbc.UPF + +LATTICE_CONSTANT +5 // add lattice constant + +LATTICE_VECTORS +0.5 0 0.5 +0.5 0.5 0 +0 0.5 0.5 + +ATOMIC_POSITIONS +Direct + +Si // Element type +0.0 // magnetism +2 +0.10 0.00 0.20 1 1 1 +0.5 0.5 0.5 1 1 1 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd new file mode 100644 index 0000000000..44886f5d13 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd @@ -0,0 +1 @@ +test parallel method bndpar with BPCG and GPU diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref new file mode 100644 index 0000000000..b4dd3812d2 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref @@ -0,0 +1,5 @@ +etotref -96.9361135992121490 +etotperatomref -48.4680567996 +totalforceref 248.977364 +totalstressref 230452.899030 +totaltimeref 7.19 diff --git a/tests/integrate/CASES_CPU.txt b/tests/integrate/CASES_CPU.txt index db834ee8e0..10815ba5e5 100644 --- a/tests/integrate/CASES_CPU.txt +++ b/tests/integrate/CASES_CPU.txt @@ -24,6 +24,7 @@ 101_PW_GTH_CF_CS_Si 102_PW_DA_davidson 102_PW_BPCG +102_PW_BPCG_BP 102_PW_CG 102_PW_DS_davsubspace 102_PW_DS_davsubspace_sca @@ -123,6 +124,7 @@ 184_PW_BNDKPAR_SDFT_MALL 184_PW_BNDPAR_SDFT_10S 184_PW_BNDPAR_SDFT_5D10S +184_PW_BPCG_SDFT_5D10S 184_PW_KPAR_SDFT_ALL 185_PW_SDFT_10D10S_METHD2 185_PW_SDFT_10S_METHD2 diff --git
a/tests/integrate/CASES_GPU.txt b/tests/integrate/CASES_GPU.txt index 3490ccc9a0..44fbcccdad 100644 --- a/tests/integrate/CASES_GPU.txt +++ b/tests/integrate/CASES_GPU.txt @@ -3,6 +3,7 @@ 102_PW_BPCG_GPU 187_PW_SDFT_ALL_GPU 187_PW_SDFT_MALL_GPU +187_PW_SDFT_MALL_BPCG_GPU 187_PW_MD_SDFT_ALL_GPU 930_NO_BI2SE2CU2O2_GPU 930_NO_BI2SE2CU2O2_k_GPU From 0664852419b4fd2b9484f99a3fc2ea0a074496bb Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Mon, 27 Jan 2025 21:51:45 +0800 Subject: [PATCH 10/17] update results --- tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref | 5 ----- .../{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/INPUT | 0 .../{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/KPT | 0 .../README | 0 .../{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/STRU | 0 .../{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/jd | 0 tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref | 5 +++++ tests/integrate/CASES_CPU.txt | 2 +- 8 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref rename tests/integrate/{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/INPUT (100%) rename tests/integrate/{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/KPT (100%) rename tests/integrate/{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/README (100%) rename tests/integrate/{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/STRU (100%) rename tests/integrate/{184_PW_BPCG_SDFT_5D10S => 184_PW_BPCG_SDFT_5D11S}/jd (100%) create mode 100644 tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref b/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref deleted file mode 100644 index aa4f0a9afa..0000000000 --- a/tests/integrate/184_PW_BPCG_SDFT_5D10S/result.ref +++ /dev/null @@ -1,5 +0,0 @@ -etotref -323.3079867689583580 -etotperatomref -161.6539933845 -totalforceref 5.025088 -totalstressref 1830.665700 -totaltimeref 2.55 diff --git 
a/tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT b/tests/integrate/184_PW_BPCG_SDFT_5D11S/INPUT similarity index 100% rename from tests/integrate/184_PW_BPCG_SDFT_5D10S/INPUT rename to tests/integrate/184_PW_BPCG_SDFT_5D11S/INPUT diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT b/tests/integrate/184_PW_BPCG_SDFT_5D11S/KPT similarity index 100% rename from tests/integrate/184_PW_BPCG_SDFT_5D10S/KPT rename to tests/integrate/184_PW_BPCG_SDFT_5D11S/KPT diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/README b/tests/integrate/184_PW_BPCG_SDFT_5D11S/README similarity index 100% rename from tests/integrate/184_PW_BPCG_SDFT_5D10S/README rename to tests/integrate/184_PW_BPCG_SDFT_5D11S/README diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU b/tests/integrate/184_PW_BPCG_SDFT_5D11S/STRU similarity index 100% rename from tests/integrate/184_PW_BPCG_SDFT_5D10S/STRU rename to tests/integrate/184_PW_BPCG_SDFT_5D11S/STRU diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D10S/jd b/tests/integrate/184_PW_BPCG_SDFT_5D11S/jd similarity index 100% rename from tests/integrate/184_PW_BPCG_SDFT_5D10S/jd rename to tests/integrate/184_PW_BPCG_SDFT_5D11S/jd diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref new file mode 100644 index 0000000000..92a930773a --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref @@ -0,0 +1,5 @@ +etotref -323.2716997598804483 +etotperatomref -161.6358498799 +totalforceref 5.765664 +totalstressref 1782.352362 +totaltimeref 2.80 diff --git a/tests/integrate/CASES_CPU.txt b/tests/integrate/CASES_CPU.txt index 10815ba5e5..09897e4bcd 100644 --- a/tests/integrate/CASES_CPU.txt +++ b/tests/integrate/CASES_CPU.txt @@ -124,7 +124,7 @@ 184_PW_BNDKPAR_SDFT_MALL 184_PW_BNDPAR_SDFT_10S 184_PW_BNDPAR_SDFT_5D10S -184_PW_BPCG_SDFT_5D10S +184_PW_BPCG_SDFT_5D11S 184_PW_KPAR_SDFT_ALL 185_PW_SDFT_10D10S_METHD2 185_PW_SDFT_10S_METHD2 From 32a204aafd2663d747ca85d107e93fa10213e5c7 Mon Sep 17 
00:00:00 2001 From: Qianruipku Date: Mon, 27 Jan 2025 21:41:39 +0700 Subject: [PATCH 11/17] fix bug in BPCG --- source/module_hsolver/diago_bpcg.cpp | 4 ++-- tests/integrate/102_PW_BPCG_BP/result.ref | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 743141ff5a..92fbb16b58 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -96,7 +96,7 @@ void DiagoBPCG::line_minimize( ct::Tensor& psi_out, ct::Tensor& hpsi_out) { - line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_basis, this->n_basis, this->n_band_l); + line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_dim, this->n_basis, this->n_band_l); } @@ -142,7 +142,7 @@ void DiagoBPCG::calc_grad_with_block( hpsi_in.data(), grad_out.data(), grad_old_out.data(), - this->n_basis, + this->n_dim, this->n_basis, this->n_band_l); } diff --git a/tests/integrate/102_PW_BPCG_BP/result.ref b/tests/integrate/102_PW_BPCG_BP/result.ref index 4815e05cb4..8e4563fef6 100644 --- a/tests/integrate/102_PW_BPCG_BP/result.ref +++ b/tests/integrate/102_PW_BPCG_BP/result.ref @@ -1,8 +1,8 @@ -etotref -4869.74705201 -etotperatomref -2434.87352600 -totalforceref 5.19522000 -totalstressref 37241.49490600 +etotref -4869.7470520063843651 +etotperatomref -2434.8735260032 +totalforceref 5.194830 +totalstressref 37241.448435 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 10.37 +totaltimeref 5.42 From 6059cf03693abfc835c5734b4a9b7e485ed666f3 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 28 Jan 2025 00:22:26 +0800 Subject: [PATCH 12/17] fix tests --- source/module_elecstate/test/elecstate_base_test.cpp | 1 + source/module_elecstate/test/elecstate_print_test.cpp | 1 + source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp | 4 ++++ source/module_hsolver/diago_bpcg.cpp | 3 +-- 
source/module_hsolver/test/diago_bpcg_test.cpp | 1 + source/module_io/test/write_istate_info_test.cpp | 2 ++ 6 files changed, 10 insertions(+), 2 deletions(-) diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 9e393f7323..9a8cd34d66 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -128,6 +128,7 @@ class MockElecState : public ElecState PARAM.input.nupdown = 0.0; PARAM.sys.two_fermi = false; PARAM.input.nbands = 6; + PARAM.sys.nbands_l = 6; PARAM.sys.nlocal = 6; PARAM.input.esolver_type = "ksdft"; PARAM.input.lspinorb = false; diff --git a/source/module_elecstate/test/elecstate_print_test.cpp b/source/module_elecstate/test/elecstate_print_test.cpp index 5754dcdd2a..1ab1337540 100644 --- a/source/module_elecstate/test/elecstate_print_test.cpp +++ b/source/module_elecstate/test/elecstate_print_test.cpp @@ -164,6 +164,7 @@ TEST_F(ElecStatePrintTest, PrintBand) { PARAM.input.nspin = 1; PARAM.input.nbands = 2; + PARAM.sys.nbands_l = 2; GlobalV::MY_RANK = 0; GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 369bd52c41..f3b53081db 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -673,6 +673,10 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, for (int is = 0; is < nspin; ++is) { pes->charge->reduce_diff_pools(sto_rho[is]); + if (!PARAM.globalv.all_ks_run && PARAM.inp.bndpar > 1) + { + MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } } } #endif diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 92fbb16b58..24a1a65468 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -12,7 
+12,6 @@ #include "module_base/global_function.h" #include "module_base/kernels/math_kernel_op.h" #include "para_linear_transform.h" -#include "module_parameter/parameter.h" namespace hsolver { @@ -314,7 +313,7 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, int start_nband = 0; #ifdef __MPI - if (PARAM.inp.bndpar > 1) + if (this->plintrans.nproc_col > 1) { start_nband = this->plintrans.start_colB[GlobalV::MY_BNDGROUP]; } diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 369fcb4ec6..7f2931c507 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -283,6 +283,7 @@ int main(int argc, char **argv) int nproc_in_pool, kpar=1, mypool, rank_in_pool; setupmpi(argc,argv,nproc, myrank); divide_pools(nproc, myrank, nproc_in_pool, kpar, mypool, rank_in_pool); + MPI_Comm_split(MPI_COMM_WORLD,myrank,0,&BP_WORLD); GlobalV::NPROC_IN_POOL = nproc; #else MPI_Init(&argc, &argv); diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index ed69f2c5c6..9e95d17fa4 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -46,6 +46,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) // preconditions GlobalV::KPAR = 1; PARAM.input.nbands = 4; + PARAM.sys.nbands_l = 4; PARAM.input.nspin = 1; PARAM.sys.global_out_dir = "./"; // mpi setting @@ -96,6 +97,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) // preconditions GlobalV::KPAR = 1; PARAM.input.nbands = 4; + PARAM.sys.nbands_l = 4; PARAM.input.nspin = 2; PARAM.sys.global_out_dir = "./"; // mpi setting From 4941bad99109a3334750417a1c96e27def7ab8d0 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 28 Jan 2025 11:36:36 +0800 Subject: [PATCH 13/17] fix test --- source/module_hsolver/diago_bpcg.cpp | 11 ++++++++--- tests/integrate/102_PW_BPCG/result.ref | 10 +++++----- 2 files changed, 13 insertions(+), 8 
deletions(-) mode change 100644 => 100755 source/module_hsolver/diago_bpcg.cpp diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp old mode 100644 new mode 100755 index 24a1a65468..2a03fe13b9 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -70,11 +70,16 @@ void DiagoBPCG::init_iter(const int nband, const int nband_l, const i template bool DiagoBPCG::test_error(const ct::Tensor& err_in, const std::vector& ethr_band) { - const Real * _err_st = err_in.data(); + Real* _err_st = err_in.data(); bool not_conv = false; + std::vector tmp_cpu; if (err_in.device_type() == ct::DeviceType::GpuDevice) { - ct::Tensor h_err_in = err_in.to_device(); - _err_st = h_err_in.data(); + // ct::Tensor h_err_in = err_in.to_device(); + // _err_st = h_err_in.data(); + // qianrui change it, because it can not pass the valgrind test + tmp_cpu.resize(this->n_band_l); + _err_st = tmp_cpu.data(); + syncmem_var_d2h_op()(_err_st, err_in.data(), this->n_band_l); } for (int ii = 0; ii < this->n_band_l; ii++) { if (_err_st[ii] > ethr_band[ii]) { diff --git a/tests/integrate/102_PW_BPCG/result.ref b/tests/integrate/102_PW_BPCG/result.ref index 4815e05cb4..6ad0913957 100644 --- a/tests/integrate/102_PW_BPCG/result.ref +++ b/tests/integrate/102_PW_BPCG/result.ref @@ -1,8 +1,8 @@ -etotref -4869.74705201 -etotperatomref -2434.87352600 -totalforceref 5.19522000 -totalstressref 37241.49490600 +etotref -4869.7470520063843651 +etotperatomref -2434.8735260032 +totalforceref 5.194830 +totalstressref 37241.448435 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 10.37 +totaltimeref 5.53 From 6970f4c14c867c0dffedacb4d1b6cd9fcd026d66 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Tue, 28 Jan 2025 13:56:58 +0800 Subject: [PATCH 14/17] update results --- tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT | 2 +- tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff 
--git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT index e25e9f0f49..8d35011db8 100644 --- a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT @@ -13,7 +13,7 @@ pseudo_dir ../../PP_ORB kpar 1 bndpar 2 -nbands 9 +nbands 7 nbands_sto all nche_sto 120 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref index b4dd3812d2..bdf5186378 100644 --- a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref @@ -1,5 +1,5 @@ -etotref -96.9361135992121490 -etotperatomref -48.4680567996 -totalforceref 248.977364 -totalstressref 230452.899030 -totaltimeref 7.19 +etotref -96.93611634 +etotperatomref -48.46805817 +totalforceref 248.97836000 +totalstressref 230453.08774800 +totaltimeref 6.37 \ No newline at end of file From a1b041fa20f13b37a56fd009f4fbcaca247db1f7 Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Thu, 30 Jan 2025 23:57:42 +0800 Subject: [PATCH 15/17] update results --- source/module_base/para_gemm.cpp | 7 +- source/module_base/parallel_device.cpp | 16 ++ source/module_base/parallel_device.h | 24 ++- .../module_hsolver/para_linear_transform.cpp | 4 +- source/module_io/read_input_item_system.cpp | 2 +- source/module_psi/psi_init.cpp | 149 +++++++++++------- tests/integrate/102_PW_BPCG_GPU/result.ref | 10 +- .../integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT | 2 +- .../187_PW_SDFT_MALL_BPCG_GPU/result.ref | 10 +- 9 files changed, 154 insertions(+), 70 deletions(-) diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp index a7c6ef7d7d..2181a5b84d 100644 --- a/source/module_base/para_gemm.cpp +++ b/source/module_base/para_gemm.cpp @@ -153,12 +153,17 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con const Device* ctx = {}; std::vector B_tmp(max_colA * LDA); + std::vector isend_tmp; + if 
(std::is_same::value) + { + isend_tmp.resize(max_colA * LDA); + } for (int ip = 0; ip < col_nproc; ip++) { if (col_rank != ip) { int size = ncolA * LDA; - Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], B_tmp.data()); + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data()); } } diff --git a/source/module_base/parallel_device.cpp b/source/module_base/parallel_device.cpp index d7373674d6..b5ade6c56f 100644 --- a/source/module_base/parallel_device.cpp +++ b/source/module_base/parallel_device.cpp @@ -18,6 +18,22 @@ void isend_data(const std::complex* buf, int count, int dest, int tag, MP { MPI_Isend(buf, count, MPI_COMPLEX, dest, tag, comm, request); } +void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_DOUBLE, dest, tag, comm); +} +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_DOUBLE_COMPLEX, dest, tag, comm); +} +void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_FLOAT, dest, tag, comm); +} +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_COMPLEX, dest, tag, comm); +} void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) { MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, status); diff --git a/source/module_base/parallel_device.h b/source/module_base/parallel_device.h index 776de4e755..5b225de9dc 100644 --- a/source/module_base/parallel_device.h +++ b/source/module_base/parallel_device.h @@ -11,6 +11,10 @@ void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); void isend_data(const 
std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm); void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); @@ -39,15 +43,31 @@ struct object_cpu_point }; /** - * @brief isend data in Device + * @brief send data in Device * */ template -void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* tmp_space = nullptr) +void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr) { object_cpu_point o; T* object_cpu = o.get(object, count, tmp_space); o.sync_d2h(object_cpu, object, count); + send_data(object_cpu, count, dest, tag, comm); + o.del(object_cpu); + return; +} + +/** + * @brief isend data in Device + * @note the data in send_space must not be modified until it has been received + * + */ +template +void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, send_space); + o.sync_d2h(object_cpu, object, count); isend_data(object_cpu, count, dest, tag, comm, request); o.del(object_cpu); return; diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp index 57b6be41a3..17e267101f 100644 --- a/source/module_hsolver/para_linear_transform.cpp +++ b/source/module_hsolver/para_linear_transform.cpp @@ -60,10 +60,12
@@ void PLinearTransform::act(const T alpha, const T* A, const T* U, con { std::vector requests(nproc_col); std::vector A_tmp(max_colA * LDA); + std::vector isend_tmp; T* A_tmp_device = A_tmp.data(); if (std::is_same::value) { A_tmp_device = nullptr; + isend_tmp.resize(max_colA * LDA); resmem_dev_op()(A_tmp_device, max_colA * LDA); } T* B_tmp = nullptr; @@ -80,7 +82,7 @@ void PLinearTransform::act(const T alpha, const T* A, const T* U, con if (rank_col != ip) { int size = LDA * ncolA; - Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data()); } } diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index 8d21a9f070..bf88ee08b6 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -257,7 +257,7 @@ void ReadInput::item_system() "will be distributed among each group"; read_sync_int(input.bndpar); item.reset_value = [](const Input_Item& item, Parameter& para) { - if (para.input.esolver_type != "sdft") + if (para.input.esolver_type != "sdft" && para.input.ks_solver != "bpcg") { para.input.bndpar = 1; } diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index fcd07622c2..d89b8ab477 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -2,6 +2,7 @@ #include "module_base/macros.h" #include "module_base/memory.h" +#include "module_base/parallel_device.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_hsolver/diago_iter_assist.h" @@ -40,8 +41,8 @@ void PSIInit::prepare_init(const int& random_seed) // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect ModuleBase::timer::tick("PSIInit", "prepare_init"); this->psi_initer.reset(); - if (this->init_wfc == "random" || (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1)) - { 
//temporary solution for band parallel bpcg + if (this->init_wfc == "random") + { this->psi_initer = std::unique_ptr>(new psi_initializer_random()); } else if (this->init_wfc == "file") @@ -97,30 +98,34 @@ void PSIInit::initialize_psi(Psi>* psi, ModuleBase::timer::tick("PSIInit", "initialize_psi"); const int nbands_start = this->psi_initer->nbands_start(); - const int nbands = psi->get_nbands(); + const int nbands_l = psi->get_nbands(); const int nbasis = psi->get_nbasis(); - const bool not_equal = (nbands_start != nbands); + const bool not_equal = (nbands_start != nbands_l); Psi* psi_cpu = reinterpret_cast*>(psi); Psi* psi_device = kspw_psi; - if (not_equal) - { - psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); - psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) - : reinterpret_cast*>(psi_cpu); - } - else if (PARAM.inp.precision == "single") + bool fill = PARAM.inp.ks_solver != "bpcg" || GlobalV::MY_BNDGROUP == 0; + if (fill) { - if (PARAM.inp.device == "cpu") + if (not_equal) { - psi_cpu = reinterpret_cast*>(kspw_psi); - psi_device = kspw_psi; + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); + psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) + : reinterpret_cast*>(psi_cpu); } - else + else if (PARAM.inp.precision == "single") { - psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); - psi_device = kspw_psi; + if (PARAM.inp.device == "cpu") + { + psi_cpu = reinterpret_cast*>(kspw_psi); + psi_device = kspw_psi; + } + else + { + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); + psi_device = kspw_psi; + } } } @@ -134,58 +139,90 @@ void PSIInit::initialize_psi(Psi>* psi, //! Update Hamiltonian from other kpoint to the given one p_hamilt->updateHk(ik); - - //! 
initialize psi_cpu - this->psi_initer->init_psig(psi_cpu->get_pointer(), ik); - if (psi_device->get_pointer() != psi_cpu->get_pointer()) + if (fill) { - syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); - } - - std::vector::type> etatom(nbands_start, 0.0); + //! initialize psi_cpu + this->psi_initer->init_psig(psi_cpu->get_pointer(), ik); + if (psi_device->get_pointer() != psi_cpu->get_pointer()) + { + syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); + } - if (this->ks_solver == "cg") - { - if (not_equal) + if (this->ks_solver == "cg") { - // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different - hsolver::DiagoIterAssist::diagH_subspace_init(p_hamilt, - psi_device->get_pointer(), - nbands_start, - nbasis, - *(kspw_psi), - etatom.data()); + std::vector::type> etatom(nbands_start, 0.0); + if (not_equal) + { + // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be + // different + hsolver::DiagoIterAssist::diagH_subspace_init(p_hamilt, + psi_device->get_pointer(), + nbands_start, + nbasis, + *(kspw_psi), + etatom.data()); + } + else + { + // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same + hsolver::DiagoIterAssist::diagH_subspace(p_hamilt, + *psi_device, + *kspw_psi, + etatom.data(), + nbands_start); + } } - else + else // dav, bpcg { - // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same - hsolver::DiagoIterAssist::diagH_subspace(p_hamilt, - *psi_device, - *kspw_psi, - etatom.data(), - nbands_start); + if (psi_device->get_pointer() != kspw_psi->get_pointer()) + { + syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands_l * nbasis); + } } } - else // dav, bpcg +#ifdef __MPI + if (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1) { - if (psi_device->get_pointer() != kspw_psi->get_pointer()) + 
std::vector sendcounts(PARAM.inp.bndpar); + std::vector displs(PARAM.inp.bndpar); + MPI_Allgather(&nbands_l, 1, MPI_INT, sendcounts.data(), 1, MPI_INT, BP_WORLD); + displs[0] = 0; + sendcounts[0] *= nbasis; + for (int i = 1; i < PARAM.inp.bndpar; i++) { - syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis); + sendcounts[i] *= nbasis; + displs[i] = displs[i - 1] + sendcounts[i - 1]; } + if (GlobalV::MY_BNDGROUP == 0) + { + for (int ip = 1; ip < PARAM.inp.bndpar; ++ip) + { + Parallel_Common::send_data(psi_cpu->get_pointer() + displs[ip], sendcounts[ip], ip, 0, BP_WORLD); + } + } + else + { + MPI_Status status; + Parallel_Common::recv_dev(kspw_psi->get_pointer(), nbands_l * nbasis, 0, 0, BP_WORLD, &status); + } } +#endif } // end k-point loop - if (not_equal) + if (fill) { - delete psi_cpu; - if(PARAM.inp.device == "gpu") + if (not_equal) { - delete psi_device; + delete psi_cpu; + if (PARAM.inp.device == "gpu") + { + delete psi_device; + } + } + else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu") + { + delete psi_cpu; } - } - else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu") - { - delete psi_cpu; } ModuleBase::timer::tick("PSIInit", "initialize_psi"); @@ -203,7 +240,11 @@ void PSIInit::initialize_lcao_in_pw(Psi* psi_local, std::ofstream& } } -void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx) +void allocate_psi(Psi>*& psi, + const int& nks, + const std::vector& ngk, + const int& nbands, + const int& npwx) { assert(npwx > 0); assert(nks > 0); diff --git a/tests/integrate/102_PW_BPCG_GPU/result.ref b/tests/integrate/102_PW_BPCG_GPU/result.ref index 4c3b56aed5..a4c2417f26 100644 --- a/tests/integrate/102_PW_BPCG_GPU/result.ref +++ b/tests/integrate/102_PW_BPCG_GPU/result.ref @@ -1,8 +1,8 @@ -etotref -4869.7470520365577613 -etotperatomref -2434.8735260183 -totalforceref 5.202524 -totalstressref 37241.827525 +etotref -4869.7470518349809936 
+etotperatomref -2434.8735259175 +totalforceref 5.207670 +totalstressref 37241.465646 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 4.25 +totaltimeref 10.28 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT index 8d35011db8..4eab8220a6 100644 --- a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT @@ -13,7 +13,7 @@ pseudo_dir ../../PP_ORB kpar 1 bndpar 2 -nbands 7 +nbands 11 nbands_sto all nche_sto 120 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref index bdf5186378..077de5b358 100644 --- a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref @@ -1,5 +1,5 @@ -etotref -96.93611634 -etotperatomref -48.46805817 -totalforceref 248.97836000 -totalstressref 230453.08774800 -totaltimeref 6.37 \ No newline at end of file +etotref -96.9361190965003630 +etotperatomref -48.4680595483 +totalforceref 248.979444 +totalstressref 230453.604050 +totaltimeref 6.44 From 1eeb767be041c51f1cb7077ae073385efc60f3df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:35:06 +0000 Subject: [PATCH 16/17] [pre-commit.ci lite] apply automatic fixes --- source/module_io/cal_dos.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/module_io/cal_dos.cpp b/source/module_io/cal_dos.cpp index 8966a0e95c..6f18b12af9 100644 --- a/source/module_io/cal_dos.cpp +++ b/source/module_io/cal_dos.cpp @@ -41,12 +41,12 @@ bool ModuleIO::calculate_dos(const int& is, if (de_ev <= 0) { ModuleBase::WARNING("ModuleIO::calculate_dos", "de <= 0 "); - return 0; + return false; } else if (emax_ev < emin_ev) { ModuleBase::WARNING("ModuleIO::calculate_dos", "emax_ev < emin_ev"); - return 0; + return false; } // mohan fixed bug 2010-1-18 @@ -57,7 
+57,7 @@ bool ModuleIO::calculate_dos(const int& is, { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "npoints", npoints); ModuleBase::WARNING("ModuleIO::calculate_dos", "npoints <= 0"); - return 0; + return false; } if (GlobalV::MY_RANK == 0) { @@ -159,5 +159,5 @@ bool ModuleIO::calculate_dos(const int& is, ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "sum up the states", sum); delete[] e_mod; - return 1; + return true; } From 8b86c7f5674c1f740e6c21291faf26d28903018b Mon Sep 17 00:00:00 2001 From: Qianruipku Date: Fri, 31 Jan 2025 10:04:44 +0800 Subject: [PATCH 17/17] update --- tests/integrate/102_PW_BPCG_BP/result.ref | 2 +- tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) mode change 100644 => 100755 tests/integrate/102_PW_BPCG_BP/result.ref diff --git a/tests/integrate/102_PW_BPCG_BP/result.ref b/tests/integrate/102_PW_BPCG_BP/result.ref old mode 100644 new mode 100755 index 8e4563fef6..b97d5ed713 --- a/tests/integrate/102_PW_BPCG_BP/result.ref +++ b/tests/integrate/102_PW_BPCG_BP/result.ref @@ -1,7 +1,7 @@ etotref -4869.7470520063843651 etotperatomref -2434.8735260032 totalforceref 5.194830 -totalstressref 37241.448435 +totalstressref 37241.453346 pointgroupref C_1 spacegroupref C_1 nksibzref 8 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref index 92a930773a..460f04c8f6 100644 --- a/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref @@ -1,5 +1,5 @@ -etotref -323.2716997598804483 -etotperatomref -161.6358498799 -totalforceref 5.765664 -totalstressref 1782.352362 +etotref -323.28029702 +etotperatomref -161.64014851 +totalforceref 5.605396 +totalstressref 1781.685471 totaltimeref 2.80