diff --git a/source/Makefile.Objects b/source/Makefile.Objects index bf7321d5bc..7f1cdfa58c 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -333,6 +333,7 @@ OBJS_HSOLVER=diago_cg.o\ diago_david.o\ diago_dav_subspace.o\ diago_bpcg.o\ + para_linear_transform.o\ hsolver.o\ hsolver_pw.o\ hsolver_lcaopw.o\ diff --git a/source/driver.cpp b/source/driver.cpp index 250ac12707..efd76ca276 100644 --- a/source/driver.cpp +++ b/source/driver.cpp @@ -152,9 +152,9 @@ void Driver::reading() GlobalV::MY_RANK, PARAM.inp.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_base/global_file.cpp b/source/module_base/global_file.cpp index b912996f4e..4da5e94045 100644 --- a/source/module_base/global_file.cpp +++ b/source/module_base/global_file.cpp @@ -153,36 +153,32 @@ void ModuleBase::Global_File::make_dir_out( #endif } - std::stringstream ss,ss1; - // mohan add 2010-09-12 if(out_alllog) { - ss << "running_" << calculation << "_" << rank + 1; - open_log(GlobalV::ofs_running, ss.str(), calculation, restart); + open_log(GlobalV::ofs_running, PARAM.globalv.log_file, calculation, restart); #if defined(__CUDA) || defined(__ROCM) - open_log(GlobalV::ofs_device, "device" + std::to_string(rank), calculation, restart); + open_log(GlobalV::ofs_device, "device" + std::to_string(rank) + ".log", calculation, restart); #endif } else { if(rank==0) { - ss << "running_" << calculation; - open_log(GlobalV::ofs_running, ss.str(), calculation, restart); + open_log(GlobalV::ofs_running, PARAM.globalv.log_file, calculation, restart); #if defined(__CUDA) || defined(__ROCM) - open_log(GlobalV::ofs_device, "device", calculation, restart); + open_log(GlobalV::ofs_device, "device.log", calculation, restart); #endif } } if(rank==0) { - open_log(GlobalV::ofs_warning, "warning", calculation, restart); + open_log(GlobalV::ofs_warning, "warning.log", calculation, restart); } #ifdef GATHER_INFO - open_log(GlobalV::ofs_info, "math_info_" + std::to_string(rank), calculation, restart); + open_log(GlobalV::ofs_info, "math_info_" + std::to_string(rank) + ".log", calculation, restart); #endif return; @@ -206,7 +202,7 @@ void ModuleBase::Global_File::open_log(std::ofstream &ofs, const std::string &fn // PARAM.globalv.global_out_dir : (default dir to store "*.log" file) //---------------------------------------------------------- std::stringstream ss; - ss << PARAM.globalv.global_out_dir << fn << ".log"; + ss << PARAM.globalv.global_out_dir << fn; if(calculation == "md" && restart) { diff --git a/source/module_base/global_variable.cpp b/source/module_base/global_variable.cpp index 616e8c3656..71f41c1b4b 100644 --- a/source/module_base/global_variable.cpp +++ b/source/module_base/global_variable.cpp @@ -18,14 +18,13 @@ namespace GlobalV int NPROC = 1; ///< global number of process int KPAR = 1; ///< global number of pools -int KPAR_LCAO = 1; ///< global number of pools for LCAO diagonalization only int MY_RANK = 0; ///< global index of process int MY_POOL = 0; ///< global index of pool (count in pool) -int MY_STOGROUP = 0; +int MY_BNDGROUP = 0; int NPROC_IN_POOL = 1; ///< local number of process in a pool -int NPROC_IN_STOGROUP = 1; +int NPROC_IN_BNDGROUP = 1; int RANK_IN_POOL = 0; ///< global index of pool (count in process), my_rank in each pool -int RANK_IN_STOGROUP = 0; +int RANK_IN_BPGROUP = 0; int DRANK = -1; ///< mohan add 2012-01-13, must be -1, so we can recognize who ///< didn't in DIAG_WORLD int DSIZE = KPAR; diff --git a/source/module_base/global_variable.h b/source/module_base/global_variable.h index 56629444d3..7014912457 100644 --- a/source/module_base/global_variable.h +++ b/source/module_base/global_variable.h @@ -28,23 +28,21 @@ namespace GlobalV // NAME : DCOLOR( color of each group) // NAME : GRANK( index of grid world) // NAME : GSIZE( number of processors in each grid world) -// NAME : KPAR_LCAO ( global number of pools for LCAO diagonalization only) //======================================================================== extern int NPROC; extern int KPAR; extern int MY_RANK; extern int MY_POOL; -extern int MY_STOGROUP; +extern int MY_BNDGROUP; extern int NPROC_IN_POOL; -extern int NPROC_IN_STOGROUP; +extern int NPROC_IN_BNDGROUP; extern int RANK_IN_POOL; -extern int RANK_IN_STOGROUP; +extern int RANK_IN_BPGROUP; extern int DRANK; extern int DSIZE; extern int DCOLOR; extern int GRANK; extern int GSIZE; -extern int KPAR_LCAO; //========================================================== // NAME : ofs_running( contain information during runnnig) diff --git a/source/module_base/kernels/math_kernel_op.cpp b/source/module_base/kernels/math_kernel_op.cpp index 59a3c2ace8..5646a089c2 100644 --- a/source/module_base/kernels/math_kernel_op.cpp +++ b/source/module_base/kernels/math_kernel_op.cpp @@ -382,11 +382,13 @@ template struct line_minimize_with_block_op, base_device::DE template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; +template struct axpy_op; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemv_op; template struct gemm_op, base_device::DEVICE_CPU>; template struct gemm_op; template struct dot_real_op, base_device::DEVICE_CPU>; +template struct dot_real_op; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; template struct vector_div_vector_op, base_device::DEVICE_CPU>; @@ -397,8 +399,6 @@ template struct calc_grad_with_block_op, base_device::DEVIC template struct line_minimize_with_block_op, base_device::DEVICE_CPU>; #ifdef __LCAO -template struct axpy_op; -template struct dot_real_op; template struct vector_mul_vector_op; template struct vector_div_constant_op; template struct vector_div_vector_op; diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp index 0908457108..2181a5b84d 100644 --- a/source/module_base/para_gemm.cpp +++ b/source/module_base/para_gemm.cpp @@ -25,7 +25,7 @@ void PGemmCN::set_dimension( const int LDB_in, const int nrow_in, const int LDC_in, - const bool gatherC_in) + const int mode) { #ifdef __MPI MPI_Comm_rank(comm_col, &col_rank); @@ -45,13 +45,45 @@ void PGemmCN::set_dimension( this->ncolB = ncolB_in; this->nrow = nrow_in; #ifdef __MPI - this->gatherC = gatherC_in; + switch (mode) + { + case 1: + gatherC = true; + divideCrow = false; + break; + case 2: + gatherC = false; + divideCrow = false; + break; + case 3: + gatherC = false; + divideCrow = true; + break; + default: + break; + } requests.resize(col_nproc); - colA_loc.resize(col_nproc); - MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); - for (int ip = 0; ip < col_nproc; ip++) + if (this->divideCrow) { - max_colA = std::max(max_colA, colA_loc[ip]); + colB_loc.resize(col_nproc); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + int sum = 0; + for (int ip = 0; ip < col_nproc; ip++) + { + max_colB = std::max(max_colB, colB_loc[ip]); + sum += colB_loc[ip]; + } + size_C_local = sum * LDC; + } + else + { + colA_loc.resize(col_nproc); + MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + max_colA = std::max(max_colA, colA_loc[ip]); + } + size_C_local = ncolB * LDC; } if (this->gatherC) @@ -71,160 +103,271 @@ void PGemmCN::set_dimension( } size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1]; } - size_C_local = ncolB * LDC; #endif } template void PGemmCN::multiply(const T alpha, const T* A, const T* B, const T beta, T* C) { - const Device* ctx = {}; #ifdef __MPI - if (col_nproc > 1) + if (this->col_nproc > 1) { - std::vector A_tmp(max_colA * LDA); - for (int ip = 0; ip < col_nproc; ip++) + if (this->divideCrow) + { + multiply_row(alpha, A, B, beta, C); + } + else { - if (col_rank != ip) - { - int size = ncolA * LDA; - Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); - } + multiply_col(alpha, A, B, beta, C); } + } + else +#endif + { + multiply_single(alpha, A, B, beta, C); + } +} - T* C_local = C; - std::vector C_tmp; - if (this->gatherC) +template +void PGemmCN::multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; +#ifdef __MPI + T real_beta = row_rank == 0 ? beta : 0; +#else + T real_beta = beta; +#endif + ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); +#ifdef __MPI + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } +#endif +} + +#ifdef __MPI +template +void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; + + std::vector B_tmp(max_colA * LDA); + std::vector isend_tmp; + if (std::is_same::value) + { + isend_tmp.resize(max_colA * LDA); + } + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) { - C_tmp.resize(size_C_local); - if (std::is_same::value) - { - C_local = nullptr; - resmem_dev_op()(C_local, size_C_local); - } - else - { - C_local = C_tmp.data(); - } - syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + int size = ncolA * LDA; + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data()); } + } - T* Atmp_device = nullptr; + T* C_local = C; + std::vector C_tmp; + if (this->gatherC) + { + C_tmp.resize(size_C_local); if (std::is_same::value) { - resmem_dev_op()(Atmp_device, max_colA * LDA); + C_local = nullptr; + resmem_dev_op()(C_local, size_C_local); } else { - Atmp_device = A_tmp.data(); + C_local = C_tmp.data(); } + syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + } - int shift = 0; - T real_beta = row_rank == 0 ? beta : 0; - for (int ip = 0; ip < col_nproc; ip++) + T* Atmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Atmp_device, max_colA * LDA); + } + else + { + Atmp_device = B_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C_local + shift; + if (col_rank == ip) { - T* C_start = C_local + shift; - if (col_rank == ip) - { - ModuleBase::gemm_op()(ctx, - 'C', - 'N', - ncolA, - ncolB, - nrow, - &alpha, - A, - LDA, - B, - LDB, - &real_beta, - C_start, - LDC); - shift += ncolA; - } - else - { - int m = colA_loc[ip]; - int size = m * LDA; - MPI_Status status; - Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp.data()); - MPI_Wait(&requests[ip], &status); - ModuleBase::gemm_op()(ctx, - 'C', - 'N', - m, - ncolB, - nrow, - &alpha, - Atmp_device, - LDA, - B, - LDB, - &real_beta, - C_start, - LDC); - shift += m; - } - } - - if (this->gatherC) - { - T* Cglobal_cpu = nullptr; - T* Clocal_cpu = C_tmp.data();; - if (std::is_same::value) - { - delmem_dev_op()(Atmp_device); - - syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); - delmem_dev_op()(C_local); - - resmem_dev_op()(Cglobal_cpu, size_C_global); - } - else - { - Cglobal_cpu = C; - } - if (this->row_nproc > 1) - { - Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); - } - Parallel_Common::gatherv_data(Clocal_cpu, - size_C_local, - Cglobal_cpu, - recv_counts.data(), - displs.data(), - col_world); - - if (std::is_same::value) - { - syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); - delmem_dev_op()(Cglobal_cpu); - } + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolA; } else { - if (this->row_nproc > 1) - { - Parallel_Common::reduce_dev(C, size_C_local, row_world); - } + int m = colA_loc[ip]; + int size = m * LDA; + MPI_Status status; + Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, B_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + m, + ncolB, + nrow, + &alpha, + Atmp_device, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += m; + } + } + + if (this->gatherC) + { + T* Cglobal_cpu = nullptr; + T* Clocal_cpu = C_tmp.data(); + std::vector cpu_tmp; + + if (std::is_same::value) + { + delmem_dev_op()(Atmp_device); + + syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); + delmem_dev_op()(C_local); + + cpu_tmp.resize(size_C_global); + Cglobal_cpu = cpu_tmp.data(); + } + else + { + Cglobal_cpu = C; + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); + } + Parallel_Common::gatherv_data(Clocal_cpu, + size_C_local, + Cglobal_cpu, + recv_counts.data(), + displs.data(), + col_world); + + if (std::is_same::value) + { + syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); } } else { - T real_beta = row_rank == 0 ? beta : 0; -#else - T real_beta = beta; -#endif - ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); -#ifdef __MPI if (this->row_nproc > 1) { Parallel_Common::reduce_dev(C, size_C_local, row_world); } } -#endif } +template +void PGemmCN::multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; + + std::vector B_tmp(max_colB * LDB); + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) + { + int size = ncolB * LDB; + Parallel_Common::isend_dev(B, size, ip, 0, col_world, &requests[ip], B_tmp.data()); + } + } + + std::vector C_tmp; + + T* Btmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Btmp_device, max_colB * LDB); + } + else + { + Btmp_device = B_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C + shift; + if (col_rank == ip) + { + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolB * LDC; + } + else + { + int m = colB_loc[ip]; + int size = m * LDB; + MPI_Status status; + Parallel_Common::recv_dev(Btmp_device, size, ip, 0, col_world, &status, B_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + m, + nrow, + &alpha, + A, + LDA, + Btmp_device, + LDB, + &real_beta, + C_start, + LDC); + shift += m * LDC; + } + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } +} +#endif + template class PGemmCN; template class PGemmCN; template class PGemmCN, base_device::DEVICE_CPU>; diff --git a/source/module_base/para_gemm.h b/source/module_base/para_gemm.h index 69ffd6d146..f00d3cd975 100644 --- a/source/module_base/para_gemm.h +++ b/source/module_base/para_gemm.h @@ -36,7 +36,7 @@ class PGemmCN * @param LDB leading dimension of B in each proc * @param nrow number of rows of A or B * @param LDC leading dimension of C. C can be C_local or C_global - * @param gatherC whether gather C_local to C_global + * @param mode 1: gather C_local to C_global, 2:C_local(nrow * ncol_loc), 3:C_global(nrow_loc * ncol) */ void set_dimension( #ifdef __MPI @@ -49,7 +49,7 @@ class PGemmCN const int LDB, const int nrow, const int LDC, - const bool gatherC = true); + const int mode = 1); /** * @brief calculate C = alpha * A^H * B + beta * C @@ -67,7 +67,8 @@ class PGemmCN std::vector colA_loc; ///< [col_nproc] number of columns of A matrix in each proc int max_colA = 0; ///< maximum number of columns of A matrix in all procs - std::vector colB_loc; ///<[col_nproc] number of columns of B matrix in each proc + std::vector colB_loc; ///< [col_nproc] number of columns of B matrix in each proc + int max_colB = 0; ///< maximum number of columns of B matrix in all procs std::vector requests; ///< MPI request std::vector recv_counts; ///< receive counts for gathering C_local to C_global @@ -75,6 +76,7 @@ class PGemmCN int size_C_local = 0; ///< size of C_local, which is a local matrix in each proc int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs bool gatherC = true; ///< whether gather C_local to C_global + bool divideCrow = false; ///< whether divide C_global to C_local #endif int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc @@ -83,6 +85,14 @@ class PGemmCN int LDB = 0; ///< leading dimension of B in each proc int LDC = 0; ///< leading dimension of C, which can be C_local or C_global private: + /// @brief for col_nproc == 1 + void multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C); +#ifdef __MPI + /// @brief for mode = 1 or 2 + void multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C); + /// @brief for mode = 3 + void multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C); +#endif using resmem_dev_op = base_device::memory::resize_memory_op; using delmem_dev_op = base_device::memory::delete_memory_op; using syncmem_dev_op = base_device::memory::synchronize_memory_op; diff --git a/source/module_base/parallel_comm.cpp b/source/module_base/parallel_comm.cpp index 7ede7efe4f..5d03447b5a 100644 --- a/source/module_base/parallel_comm.cpp +++ b/source/module_base/parallel_comm.cpp @@ -3,10 +3,10 @@ #include "mpi.h" #include "parallel_global.h" -MPI_Comm POOL_WORLD; -MPI_Comm INTER_POOL; // communicator among different pools -MPI_Comm STO_WORLD; -MPI_Comm PARAPW_WORLD; +MPI_Comm POOL_WORLD; //groups for different plane waves. In this group, only plane waves are different. K-points and bands are the same. +MPI_Comm KP_WORLD; // groups for differnt k. In this group, only k-points are different. Bands and plane waves are the same. +MPI_Comm BP_WORLD; // groups for differnt bands. In this group, only bands are different. K-points and plane waves are the same. +MPI_Comm INT_BGROUP; // internal comm groups for same bands. In this group, only bands are the same. K-points and plane waves are different. MPI_Comm GRID_WORLD; // mohan add 2012-01-13 MPI_Comm DIAG_WORLD; // mohan add 2012-01-13 diff --git a/source/module_base/parallel_comm.h b/source/module_base/parallel_comm.h index c05772fa92..d76b83e111 100644 --- a/source/module_base/parallel_comm.h +++ b/source/module_base/parallel_comm.h @@ -4,9 +4,9 @@ #ifdef __MPI #include "mpi.h" extern MPI_Comm POOL_WORLD; -extern MPI_Comm INTER_POOL; // communicator among different pools -extern MPI_Comm STO_WORLD; -extern MPI_Comm PARAPW_WORLD; +extern MPI_Comm KP_WORLD; // communicator among different pools +extern MPI_Comm INT_BGROUP; +extern MPI_Comm BP_WORLD; extern MPI_Comm GRID_WORLD; // mohan add 2012-01-13 extern MPI_Comm DIAG_WORLD; // mohan add 2012-01-13 diff --git a/source/module_base/parallel_device.cpp b/source/module_base/parallel_device.cpp index d7373674d6..b5ade6c56f 100644 --- a/source/module_base/parallel_device.cpp +++ b/source/module_base/parallel_device.cpp @@ -18,6 +18,22 @@ void isend_data(const std::complex* buf, int count, int dest, int tag, MP { MPI_Isend(buf, count, MPI_COMPLEX, dest, tag, comm, request); } +void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_DOUBLE, dest, tag, comm); +} +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_DOUBLE_COMPLEX, dest, tag, comm); +} +void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_FLOAT, dest, tag, comm); +} +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm) +{ + MPI_Send(buf, count, MPI_COMPLEX, dest, tag, comm); +} void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) { MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, status); diff --git a/source/module_base/parallel_device.h b/source/module_base/parallel_device.h index 776de4e755..5b225de9dc 100644 --- a/source/module_base/parallel_device.h +++ b/source/module_base/parallel_device.h @@ -11,6 +11,10 @@ void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm); +void send_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm); void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); @@ -39,15 +43,31 @@ struct object_cpu_point }; /** - * @brief isend data in Device + * @brief send data in Device * */ template -void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* tmp_space = nullptr) +void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr) { object_cpu_point o; T* object_cpu = o.get(object, count, tmp_space); o.sync_d2h(object_cpu, object, count); + send_data(object_cpu, count, dest, tag, comm); + o.del(object_cpu); + return; +} + +/** + * @brief isend data in Device + * @note before the date in send_space is recieved, it should not be modified + * + */ +template +void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, send_space); + o.sync_d2h(object_cpu, object, count); isend_data(object_cpu, count, dest, tag, comm, request); o.del(object_cpu); return; diff --git a/source/module_base/parallel_global.cpp b/source/module_base/parallel_global.cpp index 4081fd7207..720fc66ec1 100644 --- a/source/module_base/parallel_global.cpp +++ b/source/module_base/parallel_global.cpp @@ -237,12 +237,12 @@ void Parallel_Global::read_pal_param(int argc, void Parallel_Global::finalize_mpi() { MPI_Comm_free(&POOL_WORLD); - if (INTER_POOL != MPI_COMM_NULL) + if (KP_WORLD != MPI_COMM_NULL) { - MPI_Comm_free(&INTER_POOL); + MPI_Comm_free(&KP_WORLD); } - MPI_Comm_free(&STO_WORLD); - MPI_Comm_free(&PARAPW_WORLD); + MPI_Comm_free(&INT_BGROUP); + MPI_Comm_free(&BP_WORLD); MPI_Comm_free(&GRID_WORLD); MPI_Comm_free(&DIAG_WORLD); MPI_Finalize(); @@ -253,9 +253,9 @@ void Parallel_Global::init_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, - int& MY_STOGROUP, + int& NPROC_IN_BNDGROUP, + int& RANK_IN_BPGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL) @@ -268,9 +268,9 @@ void Parallel_Global::init_pools(const int& NPROC, MY_RANK, BNDPAR, KPAR, - NPROC_IN_STOGROUP, - RANK_IN_STOGROUP, - MY_STOGROUP, + NPROC_IN_BNDGROUP, + RANK_IN_BPGROUP, + MY_BNDGROUP, NPROC_IN_POOL, RANK_IN_POOL, MY_POOL); @@ -316,16 +316,16 @@ void Parallel_Global::divide_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, - int& MY_STOGROUP, + int& NPROC_IN_BNDGROUP, + int& RANK_IN_BPGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL) { // note: the order of k-point parallelization and band parallelization is important - // The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL - // and MY_STOGROUP will be the same as well. + // The order will not change the behavior of KP_WORLD or BP_WORLD, and MY_POOL + // and MY_BNDGROUP will be the same as well. if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0) { std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups (" @@ -349,28 +349,28 @@ void Parallel_Global::divide_pools(const int& NPROC, MPI_Comm_dup(bndpar_group.group_comm, &POOL_WORLD); if(kpar_group.inter_comm != MPI_COMM_NULL) { - MPI_Comm_dup(kpar_group.inter_comm, &INTER_POOL); + MPI_Comm_dup(kpar_group.inter_comm, &KP_WORLD); } else { - INTER_POOL = MPI_COMM_NULL; + KP_WORLD = MPI_COMM_NULL; } if(BNDPAR > 1) { - NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; - RANK_IN_STOGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; - MY_STOGROUP = bndpar_group.my_group; - MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_STOGROUP, &STO_WORLD); - MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD); + NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group; + RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group; + MY_BNDGROUP = bndpar_group.my_group; + MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &INT_BGROUP); + MPI_Comm_dup(bndpar_group.inter_comm, &BP_WORLD); } else { - NPROC_IN_STOGROUP = NPROC; - RANK_IN_STOGROUP = MY_RANK; - MY_STOGROUP = 0; - MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD); - MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD); + NPROC_IN_BNDGROUP = NPROC; + RANK_IN_BPGROUP = MY_RANK; + MY_BNDGROUP = 0; + MPI_Comm_dup(MPI_COMM_WORLD, &INT_BGROUP); + MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &BP_WORLD); } return; } diff --git a/source/module_base/parallel_global.h b/source/module_base/parallel_global.h index 1fcf933f7b..71e933a33e 100644 --- a/source/module_base/parallel_global.h +++ b/source/module_base/parallel_global.h @@ -48,9 +48,9 @@ void init_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, - int& MY_STOGROUP, + int& NPROC_IN_BNDGROUP, + int& RANK_IN_BPGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL); @@ -59,9 +59,9 @@ void divide_pools(const int& NPROC, const int& MY_RANK, const int& BNDPAR, const int& KPAR, - int& NPROC_IN_STOGROUP, - int& RANK_IN_STOGROUP, - int& MY_STOGROUP, + int& NPROC_IN_BNDGROUP, + int& RANK_IN_BPGROUP, + int& MY_BNDGROUP, int& NPROC_IN_POOL, int& RANK_IN_POOL, int& MY_POOL); diff --git a/source/module_base/parallel_reduce.cpp b/source/module_base/parallel_reduce.cpp index 1a67bbd3bc..c44bd8fb66 100644 --- a/source/module_base/parallel_reduce.cpp +++ b/source/module_base/parallel_reduce.cpp @@ -110,9 +110,9 @@ void Parallel_Reduce::reduce_pool(double* object, const int n) // (1) the value is same in each pool. // (2) we need to reduce the value from different pool. -void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object) +void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object) { - if (kpar == 1) + if (npool == 1) { return; } @@ -124,9 +124,9 @@ void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in // (1) the value is same in each pool. // (2) we need to reduce the value from different pool. -void Parallel_Reduce::reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n) +void Parallel_Reduce::reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n) { - if (kpar == 1) + if (npool == 1) { return; } diff --git a/source/module_base/parallel_reduce.h b/source/module_base/parallel_reduce.h index f120cfd1cd..2d612efc86 100644 --- a/source/module_base/parallel_reduce.h +++ b/source/module_base/parallel_reduce.h @@ -31,8 +31,8 @@ void reduce_int_grid(int* object, const int n); // mohan add 2012-01-12 void reduce_double_grid(double* object, const int n); void reduce_double_diag(double* object, const int n); -void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double& object); -void reduce_double_allpool(const int& kpar, const int& nproc_in_pool, double* object, const int n); +void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object); +void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n); void gather_min_int_all(const int& nproc, int& v); void gather_max_double_all(const int& nproc, double& v); diff --git a/source/module_base/test/global_file_test.cpp b/source/module_base/test/global_file_test.cpp index adc7f30fc7..365782cd80 100644 --- a/source/module_base/test/global_file_test.cpp +++ b/source/module_base/test/global_file_test.cpp @@ -80,10 +80,10 @@ TEST_F(GlobalFile,mkdiratom) TEST_F(GlobalFile,openlog) { std::ofstream ofs; - ModuleBase::Global_File::open_log(ofs,"Si","md",true); + ModuleBase::Global_File::open_log(ofs,"Si.log","md",true); EXPECT_TRUE(ofs.is_open()); ofs.close(); - ModuleBase::Global_File::open_log(ofs,"Si","md",false); + ModuleBase::Global_File::open_log(ofs,"Si.log","md",false); EXPECT_TRUE(ofs.is_open()); ofs.close(); std::string sss = "Si.log"; diff --git a/source/module_base/test_parallel/parallel_global_test.cpp b/source/module_base/test_parallel/parallel_global_test.cpp index 9c0972ebe8..8b9e5053a4 100644 --- a/source/module_base/test_parallel/parallel_global_test.cpp +++ b/source/module_base/test_parallel/parallel_global_test.cpp @@ -55,7 +55,7 @@ class MPIContext int rank_in_pool; int nstogroup; - int my_stogroup; + int MY_BNDGROUP; int rank_in_stogroup; int nproc_in_stogroup; @@ -173,7 +173,7 @@ TEST_F(ParaGlobal, InitPools) mpi.kpar, mpi.nproc_in_stogroup, mpi.rank_in_stogroup, - mpi.my_stogroup, + mpi.MY_BNDGROUP, mpi.nproc_in_pool, mpi.rank_in_pool, mpi.my_pool), ::testing::ExitedWithCode(1), ""); diff --git a/source/module_base/test_parallel/test_para_gemm.cpp b/source/module_base/test_parallel/test_para_gemm.cpp index 4b6445d057..78fdc74b1c 100644 --- a/source/module_base/test_parallel/test_para_gemm.cpp +++ b/source/module_base/test_parallel/test_para_gemm.cpp @@ -367,7 +367,49 @@ TYPED_TEST(PgemmTest, odd_case) this->compare_result(ncolA_global, ncolB_global, LDC_global); } -TYPED_TEST(PgemmTest, odd_case_not_gather) +TYPED_TEST(PgemmTest, row_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(1, 4); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, col_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(4, 1); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, divide_col) { const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; const int LDA_global = 17, LDB_global = 18, LDC_global = 19; @@ -392,7 +434,7 @@ TYPED_TEST(PgemmTest, odd_case_not_gather) this->LDB, this->nrow, LDC_global, - false); + 2); this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start); @@ -408,34 +450,32 @@ TYPED_TEST(PgemmTest, odd_case_not_gather) } } -TYPED_TEST(PgemmTest, row_parallel) +TYPED_TEST(PgemmTest, divide_row) { const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; const int LDA_global = 17, LDB_global = 18, LDC_global = 19; - this->decide_ngroup(1, 4); + this->decide_ngroup(2, 2); this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + std::vector colA_loc(this->nproc_col); + MPI_Allgather(&this->ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, this->col_world); + std::vector displs(this->nproc_col); + displs[0] = 0; + for (int i = 1; i < this->nproc_col; i++) + { + displs[i] = (displs[i - 1] + colA_loc[i - 1]); + } + int start = displs[this->rank_col]; - this->pgemm.set_dimension(this->col_world, - this->row_world, - this->ncolA, - this->LDA, - this->ncolB, - this->LDB, - this->nrow, - LDC_global); - this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); - - this->compare_result(ncolA_global, ncolB_global, LDC_global); -} - -TYPED_TEST(PgemmTest, col_parallel) -{ - const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; - const int LDA_global = 17, LDB_global = 18, LDC_global = 19; - - this->decide_ngroup(4, 1); - this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + int LDC_local = this->ncolA + 2; + std::vector C_loc(LDC_local * ncolB_global, 0.0); + for(int i = 0; i < ncolB_global; i++) + { + for(int j = 0; j < this->ncolA; j++) + { + C_loc[i * LDC_local + j] = this->C_global[i * LDC_global + start + j]; + } + } this->pgemm.set_dimension(this->col_world, this->row_world, @@ -444,10 +484,21 @@ TYPED_TEST(PgemmTest, col_parallel) this->ncolB, this->LDB, this->nrow, - LDC_global); - this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + LDC_local, + 3); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, C_loc.data()); - this->compare_result(ncolA_global, ncolB_global, LDC_global); + + + for (int i = 0; i < ncolB_global; i++) + { + for (int j = 0; j < this->ncolA; j++) + { + EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + start + j]), + get_double(C_loc[i * LDC_local + j]), + 1e-10); + } + } } int main(int argc, char** argv) diff --git a/source/module_cell/cal_atoms_info.h b/source/module_cell/cal_atoms_info.h index 3b59d57540..eeb10823c9 100644 --- a/source/module_cell/cal_atoms_info.h +++ b/source/module_cell/cal_atoms_info.h @@ -68,6 +68,21 @@ class CalAtomsInfo nelec_spin[1] = (para.inp.nelec - para.inp.nupdown ) / 2.0; } elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands); + // calculate the number of nbands_local + para.sys.nbands_l = para.inp.nbands; + if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel + { + para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; + if (GlobalV::MY_BNDGROUP < para.inp.nbands % para.inp.bndpar) + { + para.sys.nbands_l++; + } + } + // temporary code + if (GlobalV::MY_BNDGROUP == 0 || para.inp.ks_solver == "bpcg") + { + para.sys.ks_run = true; + } return; } }; diff --git a/source/module_cell/test/klist_test_para.cpp b/source/module_cell/test/klist_test_para.cpp index 3bd41c85a6..744f3e3150 100644 --- a/source/module_cell/test/klist_test_para.cpp +++ b/source/module_cell/test/klist_test_para.cpp @@ -229,9 +229,9 @@ TEST_F(KlistParaTest, Set) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -286,9 +286,9 @@ TEST_F(KlistParaTest, SetAfterVC) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index 1efcaff554..4c06c99213 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -163,8 +163,8 @@ void ElecState::calculate_weights() this->klist->isk); } #ifdef __MPI - // qianrui fix a bug on 2021-7-21 - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, this->f_en.demet); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.demet); #endif } else if (Occupy::fixed_occupations) @@ -192,16 +192,11 @@ void ElecState::calEBand() } } this->f_en.eband = eband; - if (GlobalV::KPAR != 1 && PARAM.inp.esolver_type != "sdft") - { - //================================== - // Reduce all the Energy in each cpu - //================================== - this->f_en.eband /= GlobalV::NPROC_IN_POOL; + #ifdef __MPI - Parallel_Reduce::reduce_all(this->f_en.eband); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.eband); #endif - } return; } @@ -253,8 +248,8 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge // init nelec_spin with nelec and nupdown this->init_nelec_spin(); // initialize ekb and wg - this->ekb.create(nk_in, PARAM.inp.nbands); - this->wg.create(nk_in, PARAM.inp.nbands); + this->ekb.create(nk_in, PARAM.globalv.nbands_l); + this->wg.create(nk_in, PARAM.globalv.nbands_l); } } // namespace elecstate diff --git a/source/module_elecstate/elecstate_print.cpp b/source/module_elecstate/elecstate_print.cpp index a223b4eac6..508e79ada7 100644 --- a/source/module_elecstate/elecstate_print.cpp +++ b/source/module_elecstate/elecstate_print.cpp @@ -173,23 +173,11 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) { ModuleBase::WARNING_QUIT("print_eigenvalue", "Eigenvalues are too large!"); } - std::stringstream ss; - if(PARAM.inp.out_alllog) - { - ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << "_" << GlobalV::MY_RANK + 1 << ".log"; - } - else - { - ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << ".log"; - } - std::string filename = ss.str(); + + std::string filename = PARAM.globalv.global_out_dir + PARAM.globalv.log_file; std::vector ngk_tot = this->klist->ngk; #ifdef __MPI - if(!PARAM.inp.out_alllog) - { - Parallel_Common::bcast_string(filename); - } MPI_Allreduce(MPI_IN_PLACE, ngk_tot.data(), nks, MPI_INT, MPI_SUM, POOL_WORLD); #endif @@ -216,7 +204,7 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) #ifdef __MPI MPI_Barrier(MPI_COMM_WORLD); #endif - bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_STOGROUP == 0); + bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_BNDGROUP == 0); if (GlobalV::MY_POOL == ip && ip_flag) { const int start_ik = nks_np * is; @@ -258,7 +246,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter) { // check the band energy. bool wrong = false; - for (int ib = 0; ib < PARAM.inp.nbands; ++ib) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ++ib) { if (std::abs(this->ekb(ik, ib)) > 1.0e10) { @@ -280,7 +268,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter) GlobalV::ofs_running << " Energy (eV) & Occupations for spin=" << this->klist->isk[ik] + 1 << " K-point=" << ik + 1 << std::endl; GlobalV::ofs_running << std::setiosflags(std::ios::showpoint); - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { GlobalV::ofs_running << " " << std::setw(6) << ib + 1 << std::setw(15) << this->ekb(ik, ib) * ModuleBase::Ry_to_eV; diff --git a/source/module_elecstate/elecstate_pw_sdft.cpp b/source/module_elecstate/elecstate_pw_sdft.cpp index bef6277adb..269efc06f2 100644 --- a/source/module_elecstate/elecstate_pw_sdft.cpp +++ b/source/module_elecstate/elecstate_pw_sdft.cpp @@ -19,7 +19,7 @@ void ElecStatePW_SDFT::psiToRho(const psi::Psi& psi) setmem_var_op()(this->rho[is], 0, this->charge->nrxx); } - if (GlobalV::MY_STOGROUP == 0) + if (PARAM.globalv.ks_run) { for (int ik = 0; ik < psi.get_nk(); ++ik) { diff --git a/source/module_elecstate/module_charge/charge.h b/source/module_elecstate/module_charge/charge.h index 926bffebea..3d18ee7134 100644 --- a/source/module_elecstate/module_charge/charge.h +++ b/source/module_elecstate/module_charge/charge.h @@ -134,7 +134,7 @@ class Charge /** * @brief Reduce among different pools - * If NPROC_IN_POOLs are all the same, use GlobalV::INTER_POOL + * If NPROC_IN_POOLs are all the same, use GlobalV::KP_WORLD * else, gather rho in a POOL, and then reduce among different POOLs * * @param array_rho f(rho): an array [nrxx] @@ -161,7 +161,6 @@ class Charge bool allocate_rho_final_scf; // LiuXh add 20180606 #ifdef __MPI private: - bool use_intel_pool = false; //use INTER_POOL when NPROC_IN_POOLs are all the same int *rec = nullptr; //The number of elements each process should receive into the receive buffer. int *dis = nullptr; //The displacement (relative to recvbuf) for each process in the receive buffer. #endif diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index d162a458ff..930a0c2b7a 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -61,7 +61,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, std::stringstream ssc; ssc << PARAM.globalv.global_readin_dir << "SPIN" << is + 1 << "_CHG.cube"; if (ModuleIO::read_vdata_palgrid(pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->rho[is], @@ -150,7 +150,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, // mohan update 2012-02-10, sunliang update 2023-03-09 if (ModuleIO::read_vdata_palgrid( pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->kin_r[is], @@ -223,7 +223,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, std::stringstream ssc; ssc << PARAM.globalv.global_readin_dir << "SPIN" << is + 1 << "_CHG.cube"; if (ModuleIO::read_vdata_palgrid(pgrid, - (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_STOGROUP : GlobalV::MY_RANK), + (PARAM.inp.esolver_type == "sdft" ? GlobalV::RANK_IN_BPGROUP : GlobalV::MY_RANK), GlobalV::ofs_running, ssc.str(), this->rho[is], diff --git a/source/module_elecstate/module_charge/charge_mpi.cpp b/source/module_elecstate/module_charge/charge_mpi.cpp index f55841be6e..c94a8f5133 100644 --- a/source/module_elecstate/module_charge/charge_mpi.cpp +++ b/source/module_elecstate/module_charge/charge_mpi.cpp @@ -8,13 +8,8 @@ #ifdef __MPI void Charge::init_chgmpi() { - if (INTER_POOL != MPI_COMM_NULL) + if (KP_WORLD == MPI_COMM_NULL) { - this->use_intel_pool = true; - } - else - { - this->use_intel_pool = false; delete[] rec; rec = new int[GlobalV::NPROC_IN_POOL]; delete[] dis; @@ -33,9 +28,9 @@ void Charge::reduce_diff_pools(double* array_rho) const { ModuleBase::TITLE("Charge", "reduce_diff_pools"); ModuleBase::timer::tick("Charge", "reduce_diff_pools"); - if (this->use_intel_pool) + if (KP_WORLD != MPI_COMM_NULL) { - MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, INTER_POOL); + MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, KP_WORLD); } else { @@ -97,11 +92,7 @@ void Charge::reduce_diff_pools(double* array_rho) const //================================== // Reduce all the rho in each cpu //================================== - if (PARAM.inp.esolver_type == "sdft") { // qinarui add it temporarily. - MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, STO_WORLD); - } else { - MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); -} + MPI_Allreduce(array_tot_aux, array_tot, this->rhopw->nxyz, MPI_DOUBLE, MPI_SUM, INT_BGROUP); //===================================== // Change the order of rho in each cpu @@ -118,13 +109,17 @@ void Charge::reduce_diff_pools(double* array_rho) const delete[] array_tot; delete[] array_tmp; } + if(PARAM.globalv.all_ks_run && PARAM.inp.bndpar > 1) + { + MPI_Allreduce(MPI_IN_PLACE, array_rho, this->nrxx, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } ModuleBase::timer::tick("Charge", "reduce_diff_pools"); } void Charge::rho_mpi() { ModuleBase::TITLE("Charge", "rho_mpi"); - if (GlobalV::KPAR <= 1) { + if (GlobalV::KPAR * PARAM.inp.bndpar <= 1) { return; } ModuleBase::timer::tick("Charge", "rho_mpi"); diff --git a/source/module_elecstate/occupy.cpp b/source/module_elecstate/occupy.cpp index eedf1466ce..580f90de81 100644 --- a/source/module_elecstate/occupy.cpp +++ b/source/module_elecstate/occupy.cpp @@ -230,7 +230,7 @@ void Occupy::gweights(const int nks, continue; } - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { //================================ // Calculate the gaussian weights @@ -420,7 +420,8 @@ double Occupy::sumkg(const ModuleBase::matrix& ekb, // GlobalV::ofs_running << "\n sum2 before reduce = " << sum2 << std::endl; #ifdef __MPI - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, sum2); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, sum2); #endif // GlobalV::ofs_running << "\n sum2 after reduce = " << sum2 << std::endl; diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 9e393f7323..9a8cd34d66 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -128,6 +128,7 @@ class MockElecState : public ElecState PARAM.input.nupdown = 0.0; PARAM.sys.two_fermi = false; PARAM.input.nbands = 6; + PARAM.sys.nbands_l = 6; PARAM.sys.nlocal = 6; PARAM.input.esolver_type = "ksdft"; PARAM.input.lspinorb = false; diff --git a/source/module_elecstate/test/elecstate_print_test.cpp b/source/module_elecstate/test/elecstate_print_test.cpp index dae3b5904f..1ab1337540 100644 --- a/source/module_elecstate/test/elecstate_print_test.cpp +++ b/source/module_elecstate/test/elecstate_print_test.cpp @@ -96,6 +96,7 @@ class ElecStatePrintTest : public ::testing::Test ucell.magnet.tot_magnetization_nc[1] = 4.4; ucell.magnet.tot_magnetization_nc[2] = 5.5; PARAM.input.ks_solver = "dav"; + PARAM.sys.log_file = "test.dat"; } void TearDown() { @@ -118,11 +119,11 @@ TEST_F(ElecStatePrintTest, PrintFormat) TEST_F(ElecStatePrintTest, PrintEigenvalueS2) { PARAM.input.nspin = 2; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue elecstate.print_eigenvalue(GlobalV::ofs_running); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS")); EXPECT_THAT(str, testing::HasSubstr("NSPIN == 2")); @@ -135,17 +136,17 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS2) EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000")); EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEigenvalueS4) { PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue elecstate.print_eigenvalue(GlobalV::ofs_running); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS")); EXPECT_THAT(str, testing::HasSubstr("NSPIN == 4")); @@ -156,51 +157,52 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS4) EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000")); EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintBand) { PARAM.input.nspin = 1; PARAM.input.nbands = 2; + PARAM.sys.nbands_l = 2; GlobalV::MY_RANK = 0; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); // print eigenvalue elecstate.print_band(0, 1, 0); GlobalV::ofs_running.close(); - ifs.open("running_scf.log", std::ios::in); + ifs.open("test.dat", std::ios::in); std::string str((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); EXPECT_THAT(str, testing::HasSubstr("Energy (eV) & Occupations for spin=1 K-point=1")); EXPECT_THAT(str, testing::HasSubstr("1 13.6057 0.100000")); EXPECT_THAT(str, testing::HasSubstr("2 27.2114 0.200000")); ifs.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEigenvalueWarning) { elecstate.ekb(0, 0) = 1.0e11; PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); testing::internal::CaptureStdout(); EXPECT_EXIT(elecstate.print_eigenvalue(GlobalV::ofs_running), ::testing::ExitedWithCode(1), ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!")); GlobalV::ofs_running.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintBandWarning) { elecstate.ekb(0, 0) = 1.0e11; PARAM.input.nspin = 4; - GlobalV::ofs_running.open("running_scf.log", std::ios::out); + GlobalV::ofs_running.open("test.dat", std::ios::out); testing::internal::CaptureStdout(); EXPECT_EXIT(elecstate.print_band(0, 1, 0), ::testing::ExitedWithCode(1), ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!")); GlobalV::ofs_running.close(); - std::remove("running_scf.log"); + std::remove("test.dat"); } TEST_F(ElecStatePrintTest, PrintEtot) diff --git a/source/module_elecstate/test_mpi/charge_mpi_test.cpp b/source/module_elecstate/test_mpi/charge_mpi_test.cpp index fda923cb86..81d0fff751 100644 --- a/source/module_elecstate/test_mpi/charge_mpi_test.cpp +++ b/source/module_elecstate/test_mpi/charge_mpi_test.cpp @@ -63,9 +63,9 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -109,9 +109,9 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -164,9 +164,9 @@ TEST_F(ChargeMpiTest, rho_mpi) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 53b55cc67d..9ee009193c 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -361,7 +361,7 @@ void ESolver_KS::hamilt2density(UnitCell& ucell, const int istep, con // Maybe in the future, density and wavefunctions should use different // parallel algorithms, in which they do not occupy all processors, for // example wavefunctions uses 20 processors while density uses 10. - if (GlobalV::MY_STOGROUP == 0) + if (PARAM.globalv.ks_run) { // double drho = this->estate.caldr2(); // EState should be used after it is constructed. @@ -550,7 +550,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i this->pelec->charge->rho, this->pelec->nelec_spin.data()); - if (GlobalV::MY_STOGROUP == 0) + if (PARAM.globalv.ks_run) { // mixing will restart at this->p_chgmix->mixing_restart steps if (drho <= PARAM.inp.mixing_restart && PARAM.inp.mixing_restart > 0.0 @@ -634,9 +634,9 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } #ifdef __MPI - MPI_Bcast(&drho, 1, MPI_DOUBLE, 0, PARAPW_WORLD); - MPI_Bcast(&this->conv_esolver, 1, MPI_DOUBLE, 0, PARAPW_WORLD); - MPI_Bcast(pelec->charge->rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&drho, 1, MPI_DOUBLE, 0, BP_WORLD); + MPI_Bcast(&this->conv_esolver, 1, MPI_DOUBLE, 0, BP_WORLD); + MPI_Bcast(pelec->charge->rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD); #endif // update potential diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 922d3e19bf..7f2c4ca5ff 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -241,16 +241,16 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 12) if kpar is not divisible by nks, print a warning - if (GlobalV::KPAR_LCAO > 1) + if (PARAM.globalv.kpar_lcao > 1) { - if (this->kv.get_nks() % GlobalV::KPAR_LCAO != 0) + if (this->kv.get_nks() % PARAM.globalv.kpar_lcao != 0) { ModuleBase::WARNING("ESolver_KS_LCAO::before_all_runners", "nks is not divisible by kpar."); std::cout << "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" "%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl; - std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << GlobalV::KPAR_LCAO + std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << PARAM.globalv.kpar_lcao << ")." << std::endl; std::cout << " This may lead to poor load balance. It is strongly suggested to" << std::endl; std::cout << " set nks to be divisible by kpar, but if this is really what" << std::endl; diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 2110cd76fc..0961675029 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -212,7 +212,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max); + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(PARAM.inp.pw_seed); this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" @@ -228,7 +228,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p //! 9) setup occupations if (PARAM.inp.ocp) { - this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec); + this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.globalv.nbands_l, PARAM.inp.nelec); } } @@ -563,7 +563,7 @@ void ESolver_KS_PW::update_pot(UnitCell& ucell, const int istep, cons this->pelec->pot->update_from_charge(this->pelec->charge, &ucell); this->pelec->f_en.descf = this->pelec->cal_delta_escf(); #ifdef __MPI - MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, BP_WORLD); #endif } else diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index f5f9292522..a4c20d37d8 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -198,7 +198,7 @@ void ESolver_SDFT_PW::hamilt2density_single(UnitCell& ucell, int iste // set_diagethr need it this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne; - if (GlobalV::MY_STOGROUP == 0) + if (PARAM.globalv.ks_run) { Symmetry_rho srho; for (int is = 0; is < PARAM.inp.nspin; is++) @@ -217,7 +217,7 @@ void ESolver_SDFT_PW::hamilt2density_single(UnitCell& ucell, int iste #endif } #ifdef __MPI - MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, PARAPW_WORLD); + MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, BP_WORLD); #endif ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density"); } diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index 09982d8e06..348ada6633 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -571,7 +571,8 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psinproc_in_pool = new int[GlobalV::KPAR]; int nprocgroup; - if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_STOGROUP; + if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_BNDGROUP; } else { nprocgroup = GlobalV::NPROC; } @@ -293,7 +293,7 @@ void Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh { // case 1: the first part of rho in processor 0. // and send zpeice to to other pools. - if(proc == 0 && GlobalV::RANK_IN_STOGROUP ==0) + if(proc == 0 && GlobalV::RANK_IN_BPGROUP ==0) { for(int ir=0; irwhichpro[ipool][iz], iz, STO_WORLD); + MPI_Send(zpiece, ncxy, MPI_DOUBLE, this->whichpro[ipool][iz], iz, INT_BGROUP); } } @@ -309,7 +309,7 @@ void Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh // and the receive tag is iz. else if(proc == GlobalV::RANK_IN_POOL ) { - MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, STO_WORLD,&ierror); + MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, INT_BGROUP,&ierror); for(int ir=0; irwhichpro[ipool][iz], iz, STO_WORLD); + MPI_Send(zpiece, ncxy, MPI_DOUBLE, this->whichpro[ipool][iz], iz, INT_BGROUP); } } }// MY_POOL == 0 @@ -333,9 +333,9 @@ void Parallel_Grid::zpiece_to_stogroup(double *zpiece, const int &iz, double *rh //ofs_running << "\n Receive charge density iz=" << iz << endl; // the processors in other pools always receive rho from // processor 0. the tag is 'iz' - if(proc == GlobalV::RANK_IN_STOGROUP ) + if(proc == GlobalV::RANK_IN_BPGROUP ) { - MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, STO_WORLD,&ierror); + MPI_Recv(zpiece, ncxy, MPI_DOUBLE, 0, iz, INT_BGROUP,&ierror); for(int ir=0; ir>& kspsi_all, ks_fact->nrecv, ks_fact->displs, MPI_DOUBLE_COMPLEX, - PARAPW_WORLD); + BP_WORLD); } #endif @@ -620,7 +620,7 @@ void Sto_EleCond::sKG(const int& smear_type, } // Parallel for bands int allbands_ks = this->nbands_ks - cutib0; - parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_STOGROUP); + parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_BNDGROUP); int perbands_ks = paraks.num_per; int ib0_ks = paraks.start; ib0_ks += this->nbands_ks - allbands_ks; @@ -629,10 +629,10 @@ void Sto_EleCond::sKG(const int& smear_type, int allbands_sto = perbands_sto; int allbands = perbands; #ifdef __MPI - MPI_Allreduce(&perbands, &allbands, 1, MPI_INT, MPI_SUM, PARAPW_WORLD); + MPI_Allreduce(&perbands, &allbands, 1, MPI_INT, MPI_SUM, BP_WORLD); allbands_sto = allbands - allbands_ks; - info_gatherv ks_fact(perbands_ks, PARAM.inp.bndpar, 1, PARAPW_WORLD); - info_gatherv sto_npwx(perbands_sto, PARAM.inp.bndpar, npwx, PARAPW_WORLD); + info_gatherv ks_fact(perbands_ks, PARAM.inp.bndpar, 1, BP_WORLD); + info_gatherv sto_npwx(perbands_sto, PARAM.inp.bndpar, npwx, BP_WORLD); #endif const int bandsinfo[6]{perbands_ks, perbands_sto, perbands, allbands_ks, allbands_sto, allbands}; double* en_all = nullptr; @@ -653,7 +653,7 @@ void Sto_EleCond::sKG(const int& smear_type, //----------------------------------------------------------- // ks conductivity //----------------------------------------------------------- - if (GlobalV::MY_STOGROUP == 0 && allbands_ks > 0) + if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0) { jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data()); } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp index 6684332781..9b942d7af2 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp @@ -210,7 +210,7 @@ void Sto_Forces::cal_sto_force_nl( const int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in.get_nbands(); - if (GlobalV::MY_STOGROUP != 0) + if (!PARAM.globalv.ks_run) { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index bd029a401d..f3b53081db 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -1,14 +1,16 @@ #include "sto_iter.h" +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/para_gemm.h" #include "module_base/parallel_reduce.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" +#include "module_elecstate/kernels/elecstate_op.h" #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_hsolver/para_linear_transform.h" #include "module_parameter/parameter.h" -#include "module_base/kernels/math_kernel_op.h" -#include "module_elecstate/kernels/elecstate_op.h" template Stochastic_Iter::Stochastic_Iter() @@ -56,8 +58,10 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, { ModuleBase::TITLE("Stochastic_Iter", "orthog"); ModuleBase::timer::tick("Stochastic_Iter", "orthog"); + int nbands_l = psi.get_nbands(); + const int nbands = PARAM.inp.nbands; // orthogonal part - if (PARAM.inp.nbands > 0) + if (nbands > 0) { const int nchipk = stowf.nchip[ik]; const int npw = psi.get_current_ngk(); @@ -66,49 +70,67 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, stowf.chiortho->fix_k(ik); T *wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); cpymem_complex_op()(wfgout, wfgin, npwx * nchipk); - // for (int ig = 0; ig < npwx * nchipk; ++ig) - // { - // wfgout[ig] = wfgin[ig]; - // } // orthogonal part T* sum = nullptr; - resmem_complex_op()(sum, PARAM.inp.nbands * nchipk); - char transC = 'C'; - char transN = 'N'; - - // sum(b - ModuleBase::gemm_op()(ctx, - transC, - transN, - PARAM.inp.nbands, - nchipk, - npw, - &ModuleBase::ONE, - &psi(ik, 0, 0), - npwx, - wfgout, - npwx, - &ModuleBase::ZERO, - sum, - PARAM.inp.nbands); - Parallel_Reduce::reduce_pool(sum, PARAM.inp.nbands * nchipk); - - // psi -= psi * sum - ModuleBase::gemm_op()(ctx, - transN, - transN, - npw, - nchipk, - PARAM.inp.nbands, - &ModuleBase::NEG_ONE, - &psi(ik, 0, 0), - npwx, - sum, - PARAM.inp.nbands, - &ModuleBase::ONE, - wfgout, - npwx); + resmem_complex_op()(sum, nbands * nchipk); + + if(PARAM.globalv.all_ks_run) + { + // sum(b + ModuleBase::PGemmCN pmmcn; +#ifdef __MPI + pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2); +#else + pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2); +#endif + pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum); + + // psi -= psi * sum + hsolver::PLinearTransform pltrans; +#ifdef __MPI + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true); +#else + pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true); +#endif + pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout); + } + else + { + // sum(b + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nbands, + nchipk, + npw, + &ModuleBase::ONE, + &psi(ik, 0, 0), + npwx, + wfgout, + npwx, + &ModuleBase::ZERO, + sum, + nbands); + Parallel_Reduce::reduce_pool(sum, nbands * nchipk); + + // psi -= psi * sum + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + npw, + nchipk, + nbands, + &ModuleBase::NEG_ONE, + &psi(ik, 0, 0), + npwx, + sum, + nbands, + &ModuleBase::ONE, + wfgout, + npwx); + } + delmem_complex_op()(sum); } ModuleBase::timer::tick("Stochastic_Iter", "orthog"); @@ -152,7 +174,7 @@ void Stochastic_Iter::checkemm(const int& ik, for (int ichi = 0; ichi < ntest; ++ichi) { - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { pchi = &stowf.chiortho->operator()(ik, ichi, 0); } @@ -329,12 +351,12 @@ void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pe this->check_precision(targetne, 10 * PARAM.inp.scf_thr, "Ne"); // Set wf.wg - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* en = &pes->ekb(ikk, 0); - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { pes->wg(ikk, iksb) = stofunc.fd(en[iksb]) * this->pkv->wk[ikk]; } @@ -366,7 +388,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& } } T* pchi; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { stowf.chiortho->fix_k(ik); pchi = stowf.chiortho->get_pointer(); @@ -435,12 +457,12 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) p_che->calcoef_real(nroot_fd); sto_ne = vTMv(p_che->coef_real, spolyv, norder); } - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* en = &pes->ekb(ikk, 0); - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { KS_ne += stofunc.fd(en[iksb]) * this->pkv->wk[ikk]; } @@ -448,7 +470,11 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) } KS_ne /= GlobalV::NPROC_IN_POOL; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, INT_BGROUP); + if(PARAM.globalv.all_ks_run) + { + MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif @@ -497,13 +523,13 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, stodemet = -vTMv(p_che->coef_real, spolyv, norder); } - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { double* enb = &pes->ekb(ikk, 0); // number of electrons in KS orbitals - for (int iksb = 0; iksb < PARAM.inp.nbands; ++iksb) + for (int iksb = 0; iksb < PARAM.globalv.nbands_l; ++iksb) { pes->f_en.demet += stofunc.fdlnfd(enb[iksb]) * this->pkv->wk[ikk]; } @@ -511,7 +537,11 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, } pes->f_en.demet /= GlobalV::NPROC_IN_POOL; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, INT_BGROUP); + if(PARAM.globalv.all_ks_run) + { + MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } MPI_Allreduce(MPI_IN_PLACE, &stodemet, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif pes->f_en.demet += stodemet; @@ -576,12 +606,8 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, resmem_complex_op()(porter, nrxx); std::vector sto_rho(nspin); - for(int is = 0; is < nspin; ++is) - { - sto_rho[is] = pes->charge->rho[is]; - } std::vector _tmprho; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { // If there are KS orbitals, we need to allocate another memory for sto_rho _tmprho.resize(nrxx * nspin); @@ -590,6 +616,13 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, sto_rho[is] = _tmprho.data() + is * nrxx; } } + else + { + for (int is = 0; is < nspin; ++is) + { + sto_rho[is] = pes->charge->rho[is]; + } + } // pes->rho is a device memory, and when using cpu and double, we donot need to allocate memory for pes->rho if (PARAM.inp.device != "gpu" && PARAM.inp.precision != "single") { @@ -635,11 +668,15 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, delmem_complex_op()(porter); #ifdef __MPI - if(GlobalV::KPAR > 1) + if(GlobalV::KPAR * PARAM.inp.bndpar > 1) { for (int is = 0; is < nspin; ++is) { pes->charge->reduce_diff_pools(sto_rho[is]); + if (!PARAM.globalv.all_ks_run && PARAM.inp.bndpar > 1) + { + MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, BP_WORLD); + } } } #endif @@ -661,11 +698,6 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); - for(int is = 0; is < nspin; ++is) - { - MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); - } #endif double factor = targetne / (KS_ne + sto_ne); if (std::abs(factor - 1) > 1e-10) @@ -680,7 +712,7 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, for (int is = 0; is < 1; ++is) { - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { #ifdef _OPENMP #pragma omp parallel for @@ -715,7 +747,7 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WFfix_k(ik); T* out = stowf.shchi->get_pointer(); T* pchi; - if (PARAM.inp.nbands > 0) + if (PARAM.globalv.nbands_l > 0) { stowf.chiortho->fix_k(ik); pchi = stowf.chiortho->get_pointer(); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp index 62a4c16779..4d77635139 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp @@ -111,7 +111,7 @@ void Sto_Stress_PW::sto_stress_kin(ModuleBase::matrix& sigma, ModuleBase::timer::tick("Sto_Stress_PW", "stress_kin"); int nksbands = psi.get_nbands(); - if (GlobalV::MY_STOGROUP != 0) + if (!PARAM.globalv.ks_run) { nksbands = 0; } @@ -160,7 +160,7 @@ void Sto_Stress_PW::sto_stress_nl(ModuleBase::matrix& sigma, int* nchip = stowf.nchip; const int npwx = wfc_basis->npwk_max; int nksbands = psi_in.get_nbands(); - if (GlobalV::MY_STOGROUP != 0) + if (!PARAM.globalv.ks_run) { nksbands = 0; } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp index 8b350c7777..98f346f188 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp @@ -138,7 +138,7 @@ psi::Psi>* gatherchi(psi::Psi>& chi, nrecv_sto, displs_sto, MPI_COMPLEX, - PARAPW_WORLD); + BP_WORLD); ModuleBase::timer::tick("sKG", "bands_gather"); } #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 2dd983e23e..b00564f1ef 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -72,7 +72,7 @@ void Stochastic_WF::init_sto_orbitals(const int seed_in) } else { - srand((unsigned)std::abs(seed_in) + (GlobalV::MY_STOGROUP * GlobalV::NPROC_IN_STOGROUP + GlobalV::RANK_IN_STOGROUP) * 10000); + srand((unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000); } this->allocate_chi0(); @@ -88,12 +88,12 @@ void Stochastic_WF::allocate_chi0() // former processor calculate more bands if (firstrankmore) { - igroup = GlobalV::MY_STOGROUP; + igroup = GlobalV::MY_BNDGROUP; } // latter processor calculate more bands else { - igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1; + igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1; } const int nchi = PARAM.inp.nbands_sto; const int npwx = this->npwx; @@ -172,16 +172,16 @@ void Stochastic_WF::init_com_orbitals() // former processor calculate more bands if (firstrankmore) { - igroup = GlobalV::MY_STOGROUP; + igroup = GlobalV::MY_BNDGROUP; } // latter processor calculate more bands else { - igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1; + igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1; } const int ngroup = PARAM.inp.bndpar; const int n_in_pool = GlobalV::NPROC_IN_POOL; - const int i_in_group = GlobalV::RANK_IN_STOGROUP; + const int i_in_group = GlobalV::RANK_IN_BPGROUP; const int i_in_pool = GlobalV::RANK_IN_POOL; int* totnpw = new int[nks]; @@ -315,10 +315,10 @@ void Stochastic_WF::init_sto_orbitals_Ecut(const int seed_in, int* nrecv = new int[PARAM.inp.bndpar]; const int nchiper = this->nchip[0]; #ifdef __MPI - MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD); + MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, BP_WORLD); #endif int ichi_start = 0; - for (int i = 0; i < GlobalV::MY_STOGROUP; ++i) + for (int i = 0; i < GlobalV::MY_BNDGROUP; ++i) { ichi_start += nrecv[i]; } diff --git a/source/module_hsolver/CMakeLists.txt b/source/module_hsolver/CMakeLists.txt index 7f6c8ca4c6..6ccfa42c2e 100644 --- a/source/module_hsolver/CMakeLists.txt +++ b/source/module_hsolver/CMakeLists.txt @@ -4,6 +4,7 @@ list(APPEND objects diago_david.cpp diago_dav_subspace.cpp diago_bpcg.cpp + para_linear_transform.cpp hsolver_pw.cpp hsolver_lcaopw.cpp hsolver_pw_sdft.cpp diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp old mode 100644 new mode 100755 index 36f77d372d..2a03fe13b9 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -11,6 +11,7 @@ #include "module_base/blas_connector.h" #include "module_base/global_function.h" #include "module_base/kernels/math_kernel_op.h" +#include "para_linear_transform.h" namespace hsolver { @@ -34,44 +35,61 @@ DiagoBPCG::~DiagoBPCG() { } template -void DiagoBPCG::init_iter(const int nband, const int nbasis, const int ndim) { +void DiagoBPCG::init_iter(const int nband, const int nband_l, const int nbasis, const int ndim) { // Specify the problem size n_basis, n_band, while lda is n_basis this->n_band = nband; + this->n_band_l = nband_l; this->n_basis = nbasis; this->n_dim = ndim; // All column major tensors - this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band})); + this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band_l})); this->eigen = std::move(ct::Tensor(r_type, device_type, {this->n_band})); - this->err_st = std::move(ct::Tensor(r_type, device_type, {this->n_band})); + this->err_st = std::move(ct::Tensor(r_type, device_type, {this->n_band_l})); this->hsub = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_band})); - this->hpsi = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->work = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->hgrad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); - this->grad_old = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); + this->hpsi = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->work = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->hgrad = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); + this->grad_old = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); this->prec = std::move(ct::Tensor(r_type, device_type, {this->n_basis})); - this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_basis})); + this->grad = std::move(ct::Tensor(t_type, device_type, {this->n_band_l, this->n_basis})); +#ifdef __MPI + this->pmmcn.set_dimension(BP_WORLD, POOL_WORLD, n_band_l, n_basis, n_band_l, n_basis, n_dim, n_band); + this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis, BP_WORLD, false); +#else + this->pmmcn.set_dimension(n_band_l, n_basis, n_band_l, n_basis, n_dim, n_band); + this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis, false); +#endif } template bool DiagoBPCG::test_error(const ct::Tensor& err_in, const std::vector& ethr_band) { - const Real * _err_st = err_in.data(); + Real* _err_st = err_in.data(); + bool not_conv = false; + std::vector tmp_cpu; if (err_in.device_type() == ct::DeviceType::GpuDevice) { - ct::Tensor h_err_in = err_in.to_device(); - _err_st = h_err_in.data(); + // ct::Tensor h_err_in = err_in.to_device(); + // _err_st = h_err_in.data(); + // qianrui change it, because it can not pass the valgrind test + tmp_cpu.resize(this->n_band_l); + _err_st = tmp_cpu.data(); + syncmem_var_d2h_op()(_err_st, err_in.data(), this->n_band_l); } - for (int ii = 0; ii < this->n_band; ii++) { + for (int ii = 0; ii < this->n_band_l; ii++) { if (_err_st[ii] > ethr_band[ii]) { - return true; + not_conv = true; } } - return false; +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, ¬_conv, 1, MPI_C_BOOL, MPI_LOR, BP_WORLD); +#endif + return not_conv; } // Finally, the last one! @@ -82,7 +100,7 @@ void DiagoBPCG::line_minimize( ct::Tensor& psi_out, ct::Tensor& hpsi_out) { - line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_basis, this->n_basis, this->n_band); + line_minimize_with_block_op()(grad_in.data(), hgrad_in.data(), psi_out.data(), hpsi_out.data(), this->n_dim, this->n_basis, this->n_band_l); } @@ -94,28 +112,8 @@ void DiagoBPCG::orth_cholesky( ct::Tensor& hpsi_out, ct::Tensor& hsub_out) { - // hsub_out = psi_out * transc(psi_out) - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option); - // gemm: hsub_out(n_band x n_band) = psi_out^T(n_band x n_basis) * psi_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - psi_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_out.data(), this->n_band * this->n_band); + this->pmmcn.multiply(1.0, psi_out.data(), psi_out.data(), 0.0, hsub_out.data()); // set hsub matrix to lower format; ct::kernels::set_matrix()( @@ -148,9 +146,9 @@ void DiagoBPCG::calc_grad_with_block( hpsi_in.data(), grad_out.data(), grad_old_out.data(), + this->n_dim, this->n_basis, - this->n_basis, - this->n_band); + this->n_band_l); } template @@ -165,51 +163,12 @@ void DiagoBPCG::orth_projection( ct::Tensor& hsub_in, ct::Tensor& grad_out) { - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in); - // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option); - // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - psi_in.data(), - this->n_basis, //lda - grad_out.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_in.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_in.data(), this->n_band * this->n_band); - - // set_matrix_op()('L', hsub_in->data(), this->n_band); - option = ct::EinsumOption( - /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out); - // grad_out = ct::op::einsum("ij,jk->ik", hsub_in, psi_in, option); - - // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_dim, //m - this->n_band, //n - this->n_band, //k - this->neg_one, //-1.0 - psi_in.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->one, //1.0 - grad_out.data(), - this->n_basis); //ldc - - // * This type of non inner product like operation does not need reduce! - + this->pmmcn.multiply(1.0, psi_in.data(), grad_out.data(), 0.0, hsub_in.data()); + + // grad_out(n_basis x n_band) = 1.0 * grad_out(n_basis x n_band) - psi_in(n_basis x n_band) * hsub_in(n_band x + // n_band) + this->plintrans.act(-1.0, psi_in.data(), hsub_in.data(), 1.0, grad_out.data()); return; } @@ -219,29 +178,9 @@ void DiagoBPCG::rotate_wf( ct::Tensor& psi_out, ct::Tensor& workspace_in) { - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in); - // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option); - // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band) - gemm_op()(this->ctx, - 'N', - 'N', - this->n_basis, //m - this->n_band, //n - this->n_band, //k - this->one, //1.0 - psi_out.data(), - this->n_basis, //lda - hsub_in.data(), - this->n_band, //ldb - this->zero, //0.0 - workspace_in.data(), - this->n_basis); //ldc - - // * This type of non inner product like operation does not need reduce! - - syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band * this->n_basis); + this->plintrans.act(1.0, psi_out.data(), hsub_in.data(), 0.0, workspace_in.data()); + syncmem_complex_op()(psi_out.template data(), workspace_in.template data(), this->n_band_l * this->n_basis); return; } @@ -253,7 +192,7 @@ void DiagoBPCG::calc_hpsi_with_block( ct::Tensor& hpsi_out) { // calculate all-band hpsi - hpsi_func(psi_in, hpsi_out.data(), this->n_basis, this->n_band); + hpsi_func(psi_in, hpsi_out.data(), this->n_basis, this->n_band_l); } template @@ -263,30 +202,8 @@ void DiagoBPCG::diag_hsub( ct::Tensor& hsub_out, ct::Tensor& eigenvalue_out) { - // calculate all-band hsub - // Note: ctx is nothing but the devices used in this class (Device * ctx = nullptr;), - // it controls the ops to use the corresponding device to calculate results - ct::EinsumOption option( - /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out); - // hsub_out = ct::op::einsum("ij,kj->ik", psi_in, hpsi_in, option); - // gemm: hsub_out(n_band x n_band) = hpsi_in^T(n_band x n_basis) * psi_in(n_basis x n_band) - gemm_op()(this->ctx, - 'C', - 'N', - this->n_band, //m - this->n_band, //n - this->n_dim, //k - this->one, //1.0 - hpsi_in.data(), - this->n_basis, //lda - psi_in.data(), - this->n_basis, //ldb - this->zero, //0.0 - hsub_out.data(), - this->n_band); //ldc - - Parallel_Reduce::reduce_pool(hsub_out.data(), this->n_band * this->n_band); + this->pmmcn.multiply(1.0, hpsi_in.data(), psi_in.data(), 0.0, hsub_out.data()); ct::kernels::lapack_dnevd()('V', 'U', hsub_out.data(), this->n_band, eigenvalue_out.data()); @@ -344,7 +261,7 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, { const int current_scf_iter = hsolver::DiagoIterAssist::SCF_ITER; // Get the pointer of the input psi - this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis})); + this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band_l, this->n_basis})); // Update the precondition array this->calc_prec(); @@ -352,9 +269,9 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, // Improving the initial guess of the wave function psi through a subspace diagonalization. this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); - setmem_complex_op()(this->grad_old.template data(), 0, this->n_basis * this->n_band); + setmem_complex_op()(this->grad_old.template data(), 0, this->n_basis * this->n_band_l); - setmem_var_op()(this->beta.template data(), std::numeric_limits::infinity(), this->n_band); + setmem_var_op()(this->beta.template data(), std::numeric_limits::infinity(), this->n_band_l); int ntry = 0; int max_iter = current_scf_iter > 1 ? @@ -378,7 +295,7 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, this->orth_projection(this->psi, this->hsub, this->grad); // this->grad_old = this->grad; - syncmem_complex_op()(this->grad_old.template data(), this->grad.template data(), n_basis * n_band); + syncmem_complex_op()(this->grad_old.template data(), this->grad.template data(), n_basis * n_band_l); // Calculate H|grad> matrix this->calc_hpsi_with_block(hpsi_func, this->grad.template data(), /*this->grad_wrapper[0],*/ this->hgrad); @@ -399,7 +316,14 @@ void DiagoBPCG::diag(const HPsiFunc& hpsi_func, this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen); - syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data(), this->n_band); + int start_nband = 0; +#ifdef __MPI + if (this->plintrans.nproc_col > 1) + { + start_nband = this->plintrans.start_colB[GlobalV::MY_BNDGROUP]; + } +#endif + syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data() + start_nband, this->n_band_l); return; } diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index 90907de5e9..243ac8b44f 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -1,18 +1,18 @@ #ifndef DIAGO_BPCG_H_ #define DIAGO_BPCG_H_ +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/module_device/memory_op.h" +#include "module_base/module_device/types.h" +#include "module_base/para_gemm.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" - -#include "module_base/module_device/types.h" -#include "module_base/module_device/memory_op.h" - -#include "module_base/kernels/math_kernel_op.h" #include "module_hsolver/kernels/dngvd_op.h" -#include +#include "module_hsolver/para_linear_transform.h" #include #include +#include namespace hsolver { @@ -50,11 +50,12 @@ class DiagoBPCG * This function allocates all the related variables, such as hpsi, hsub, before the diag call. * It is called by the HsolverPW::initDiagh() function. * - * @param nband The number of bands. + * @param nband The number of bands of all processes. + * @param nband_l The number of bands of current process. * @param nbasis The number of basis functions. Leading dimension of psi. * @param ndim The number of valid dimension of psi. */ - void init_iter(const int nband, const int nbasis, const int ndim); + void init_iter(const int nband, const int nband_l, const int nbasis, const int ndim); using HPsiFunc = std::function; @@ -74,14 +75,20 @@ class DiagoBPCG const std::vector& ethr_band); private: - /// the number of rows of the input psi + /// the number of bands of all processes int n_band = 0; + /// the number of bands of current process + int n_band_l = 0; /// the number of cols of the input psi int n_basis = 0; /// valid dimension of psi int n_dim = 0; /// max iter steps for all-band cg loop int nline = 4; + /// parallel matrix multiplication + ModuleBase::PGemmCN pmmcn; + PLinearTransform plintrans; + ct::DataType r_type = ct::DataType::DT_INVALID; ct::DataType t_type = ct::DataType::DT_INVALID; @@ -185,7 +192,6 @@ class DiagoBPCG * psi_out[dim: n_basis x n_band, column major, lda = n_basis_max], * * @param hsub_in Subspace matrix input, dim [n_basis, n_band] with column major. - * @param workspace_in Workspace matrix, dim [n_basis, n_band] with column major.. * @param psi_out output wavefunction matrix with dim [n_basis, n_band], column major. */ void rotate_wf( diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 44deac1bbd..4c8ff835de 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -47,14 +47,14 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, if (this->method != "pexsi") { - if (GlobalV::KPAR_LCAO > 1 + if (PARAM.globalv.kpar_lcao > 1 && (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx")) { #ifdef __MPI - this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO); + this->parakSolve(pHamilt, psi, pes, PARAM.globalv.kpar_lcao); #endif } - else if (GlobalV::KPAR_LCAO == 1) + else if (PARAM.globalv.kpar_lcao == 1) { /// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors). for (int ik = 0; ik < psi.get_nk(); ++ik) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 05ccc8acd0..02f44f857b 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -477,7 +477,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, } else if (this->method == "bpcg") { - const int nband = psi.get_nbands(); + const int nband_l = psi.get_nbands(); const int nbasis = psi.get_nbasis(); const int ndim = psi.get_current_ngk(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec @@ -496,7 +496,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ModuleBase::timer::tick("DavSubspace", "hpsi_func"); }; DiagoBPCG bpcg(pre_condition.data()); - bpcg.init_iter(nband, nbasis, ndim); + bpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index d03b37b848..b2df935ad6 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -46,7 +46,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, { ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS"); pHamilt->updateHk(ik); - if (nbands > 0 && GlobalV::MY_STOGROUP == 0) + if (nbands > 0 && PARAM.globalv.ks_run) { /// update psi pointer for each k point psi.fix_k(ik); @@ -58,10 +58,10 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, } #ifdef __MPI - if (nbands > 0 && PARAM.inp.bndpar > 1) + if (nbands > 0 && !PARAM.globalv.all_ks_run) { - Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); - MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); + Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, &psi_cpu(ik, 0, 0)); + MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, BP_WORLD); } #endif ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS"); @@ -89,17 +89,10 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, // calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS elecstate::ElecStatePW* pes_pw = static_cast*>(pes); - if (GlobalV::MY_STOGROUP == 0) + pes_pw->calEBand(); + if(!PARAM.globalv.all_ks_run) { - pes_pw->calEBand(); - } - if (nbands > 0) - { -#ifdef __MPI - pes->f_en.eband /= GlobalV::NPROC_IN_POOL; - MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.eband, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); - MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD); -#endif + pes->f_en.eband /= PARAM.inp.bndpar; } stoiter.sum_stoeband(stowf, pes_pw, pHamilt, wfc_basis); diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp new file mode 100644 index 0000000000..17e267101f --- /dev/null +++ b/source/module_hsolver/para_linear_transform.cpp @@ -0,0 +1,181 @@ +#include "para_linear_transform.h" + +#include +#include +namespace hsolver +{ +template +void PLinearTransform::set_dimension(const int nrowA, + const int ncolA, + const int ncolB, + const int LDA, +#ifdef __MPI + MPI_Comm col_world, +#endif + const bool localU) +{ + this->nrowA = nrowA; + this->ncolA = ncolA; + this->ncolB = ncolB; + this->LDA = LDA; +#ifdef __MPI + this->col_world = col_world; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); + if (nproc_col > 1) + { + this->localU = localU; + colA_loc.resize(nproc_col); + MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + start_colA.resize(nproc_col); + start_colA[0] = 0; + for (int ip = 1; ip < nproc_col; ++ip) + { + start_colA[ip] = start_colA[ip - 1] + colA_loc[ip - 1]; + } + this->ncolA_glo = start_colA[nproc_col - 1] + colA_loc[nproc_col - 1]; + this->max_colA = *std::max_element(colA_loc.begin(), colA_loc.end()); + + std::vector colB_loc(nproc_col); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + start_colB.resize(nproc_col); + start_colB[0] = 0; + for (int ip = 1; ip < nproc_col; ++ip) + { + start_colB[ip] = start_colB[ip - 1] + colB_loc[ip - 1]; + } + this->max_colB = *std::max_element(colB_loc.begin(), colB_loc.end()); + } +#else + nproc_col = 1; + rank_col = 0; +#endif +} +template +void PLinearTransform::act(const T alpha, const T* A, const T* U, const T beta, T* B) +{ + const Device* ctx = {}; +#ifdef __MPI + if (nproc_col > 1) + { + std::vector requests(nproc_col); + std::vector A_tmp(max_colA * LDA); + std::vector isend_tmp; + T* A_tmp_device = A_tmp.data(); + if (std::is_same::value) + { + A_tmp_device = nullptr; + isend_tmp.resize(max_colA * LDA); + resmem_dev_op()(A_tmp_device, max_colA * LDA); + } + T* B_tmp = nullptr; + resmem_dev_op()(B_tmp, ncolB * LDA); + syncmem_dev_op()(B_tmp, B, ncolB * LDA); + setmem_dev_op()(B, 0.0, ncolB * LDA); + + T* U_tmp = nullptr; + resmem_dev_op()(U_tmp, max_colA * max_colB); + + // Send + for (int ip = 0; ip < nproc_col; ++ip) + { + if (rank_col != ip) + { + int size = LDA * ncolA; + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data()); + } + } + + // Receive + const int start = this->localU ? 0 : start_colB[rank_col]; + for (int ip = 0; ip < nproc_col; ++ip) + { + T real_beta = ip == 0 ? beta : 0; + const int ncolA_ip = colA_loc[ip]; + // get U_tmp + + const int start_row = start_colA[ip]; + for (int i = 0; i < ncolB; ++i) + { + const T* U_part = U + start_row + (i + start) * ncolA_glo; + syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip); + } + + if (ip == rank_col) + { + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrowA, + ncolB, + ncolA_ip, + &alpha, + A, + LDA, + U_tmp, + ncolA_ip, + &real_beta, + B_tmp, + LDA); + } + else + { + int size = LDA * ncolA_ip; + MPI_Status status; + Parallel_Common::recv_dev(A_tmp_device, size, ip, 0, col_world, &status, A_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrowA, + ncolB, + ncolA_ip, + &alpha, + A_tmp_device, + LDA, + U_tmp, + ncolA_ip, + &real_beta, + B_tmp, + LDA); + } + // sum all the results + T one = 1.0; + ModuleBase::axpy_op()(ctx, ncolB * LDA, &one, B_tmp, 1, B, 1); + } + delmem_dev_op()(U_tmp); + delmem_dev_op()(B_tmp); + if (std::is_same::value) + { + delmem_dev_op()(A_tmp_device); + } + } + else +#endif + { + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrowA, + ncolB, + ncolA, + &alpha, + A, + LDA, + U, + ncolA, + &beta, + B, + LDA); + } +}; + +template struct PLinearTransform; +template struct PLinearTransform, base_device::DEVICE_CPU>; +template struct PLinearTransform, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template struct PLinearTransform; +template struct PLinearTransform, base_device::DEVICE_GPU>; +template struct PLinearTransform, base_device::DEVICE_GPU>; +#endif +} // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/para_linear_transform.h b/source/module_hsolver/para_linear_transform.h new file mode 100644 index 0000000000..42cb02fb47 --- /dev/null +++ b/source/module_hsolver/para_linear_transform.h @@ -0,0 +1,76 @@ +#ifndef __PARA_LINEAR_TRANSFORM_H__ +#define __PARA_LINEAR_TRANSFORM_H__ +#include "module_base/kernels/math_kernel_op.h" +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" +#include "module_base/parallel_device.h" +#include +#ifdef __MPI +#include "mpi.h" +#endif +namespace hsolver +{ + +/** + * @brief B = alpha * A * U + beta * B + * A and B are local matrice + * U can be a local matrix or a global matrix + */ +template +class PLinearTransform +{ + public: + using syncmem_dev_op = base_device::memory::synchronize_memory_op; + using resmem_dev_op = base_device::memory::resize_memory_op; + using setmem_dev_op = base_device::memory::set_memory_op; + using delmem_dev_op = base_device::memory::delete_memory_op; + int nproc_col = 1; + int rank_col = 0; + int nrowA = 0; + int ncolA = 0; + int ncolB = 0; + int LDA = 0; + bool localU = false; +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_NULL; + std::vector colA_loc; + std::vector start_colA; + std::vector start_colB; + int max_colA = 0; + int ncolA_glo = 0; + int max_colB = 0; +#endif + + /** + * @brief set the dimension of A, B, and U + * A: LDA * nrow, U_global: ncolA_global * ncolB_global, U_local: ncolA_global * ncolB + * B: LDA * ncolB + */ + void set_dimension(const int nrowA, + const int ncolA, + const int ncolB, + const int LDA, +#ifdef __MPI + MPI_Comm col_world, +#endif + const bool localU); + + /** + * @brief B = alpha * A * U + beta * B + * A is a local matrix with nrow rows and ncolA_loc columns + * B is a local matrix with nrow rows and ncolB_loc columns + * U can be a local matrix or a global matrix + * @example rotate wave functions: B = A * U + * orthogonalize wave functions: B = - A * U + B + * + * @param alpha : alpha + * @param A : input matrix + * @param U_global : input matrix + * @param beta : beta + * @param B : input/output matrix + * + */ + void act(const T alpha, const T* A, const T* U_global, const T beta, T* B); +}; +} // namespace hsolver +#endif \ No newline at end of file diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index e44171912c..7165e895a7 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -12,7 +12,7 @@ if (ENABLE_MPI) AddTest( TARGET HSolver_bpcg LIBS parameter ${math_libs} base psi device container - SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../diago_iter_assist.cpp + SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp ../../module_basis/module_pw/test/test_tool.cpp ../../module_hamilt_general/operator.cpp ../../module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp @@ -77,13 +77,13 @@ if (ENABLE_MPI) AddTest( TARGET HSolver_pw LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp + SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ) AddTest( TARGET HSolver_sdft LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp + SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ) if(ENABLE_LCAO) @@ -159,6 +159,17 @@ AddTest( SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_elpa.cpp ../diago_scalapack.cpp ) +AddTest( + TARGET hsolver_linear_trans + LIBS parameter ${math_libs} base device MPI::MPI_CXX + SOURCES test_para_linear_trans.cpp ../para_linear_transform.cpp +) + +add_test(NAME hsolver_para_linear_trans + COMMAND mpirun -np 4 ./hsolver_linear_trans + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + find_program(BASH bash) if (ENABLE_MPI) add_test(NAME HSolver_cg_parallel diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index e6af8b5b5e..7f2931c507 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -154,7 +154,7 @@ class DiagoBPCGPrepare hpsi_out, ld_psi); }; const int ndim = psi_local.get_current_ngk(); - bpcg.init_iter(nband, npw, ndim); + bpcg.init_iter(nband, nband, npw, ndim); std::vector ethr_band(nband, 1e-5); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); @@ -283,6 +283,7 @@ int main(int argc, char **argv) int nproc_in_pool, kpar=1, mypool, rank_in_pool; setupmpi(argc,argv,nproc, myrank); divide_pools(nproc, myrank, nproc_in_pool, kpar, mypool, rank_in_pool); + MPI_Comm_split(MPI_COMM_WORLD,myrank,0,&BP_WORLD); GlobalV::NPROC_IN_POOL = nproc; #else MPI_Init(&argc, &argv); diff --git a/source/module_hsolver/test/diago_lapack_test.cpp b/source/module_hsolver/test/diago_lapack_test.cpp index cc181c5d28..cb3eb3f403 100644 --- a/source/module_hsolver/test/diago_lapack_test.cpp +++ b/source/module_hsolver/test/diago_lapack_test.cpp @@ -1,8 +1,8 @@ // Author: Zhang Xiaoyang // A modified version of diago_lcao_test.cpp -// #define private public +#define private public #include "module_parameter/parameter.h" -// #undef private +#undef private // Remove some useless functions and dependencies. Serialized the full code // and refactored some function. @@ -187,10 +187,8 @@ class DiagoLapackPrepare void set_env() { - // PARAM.sys.nlocal = nlocal; - PARAM.set_sys_nlocal(nlocal); - // PARAM.input.nbands = nbands; - PARAM.set_input_nbands(nbands); + PARAM.sys.nlocal = nlocal; + PARAM.input.nbands = nbands; } void diago() diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index 3ed2593d9a..d1f7bbca13 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -252,7 +252,7 @@ class TestHSolverPW_SDFT : public ::testing::Test // stowf.nchip_max = 0; // psi_test_cd.resize(1, 2, 3); // PARAM.input.nelec = 1.0; -// GlobalV::MY_STOGROUP = 0.0; +// GlobalV::MY_BNDGROUP = 0.0; // int istep = 0; // int iter = 0; @@ -291,7 +291,7 @@ class TestHSolverPW_SDFT : public ::testing::Test // psi_test_no.nbands = 0; // psi_test_no.nbasis = 0; // PARAM.input.nelec = 1.0; -// GlobalV::MY_STOGROUP = 0.0; +// GlobalV::MY_BNDGROUP = 0.0; // PARAM.input.nspin = 1; // elecstate_test.charge = new Charge; // elecstate_test.charge->rho = new double*[1]; @@ -339,10 +339,10 @@ int main(int argc, char** argv) MPI_Comm_size(MPI_COMM_WORLD, &GlobalV::NPROC); MPI_Comm_rank(MPI_COMM_WORLD, &GlobalV::MY_RANK); - MPI_Comm_split(MPI_COMM_WORLD, 0, 1, &PARAPW_WORLD); + MPI_Comm_split(MPI_COMM_WORLD, 0, 1, &BP_WORLD); int result = RUN_ALL_TESTS(); - MPI_Comm_free(&PARAPW_WORLD); + MPI_Comm_free(&BP_WORLD); MPI_Finalize(); return result; diff --git a/source/module_hsolver/test/test_para_linear_trans.cpp b/source/module_hsolver/test/test_para_linear_trans.cpp new file mode 100644 index 0000000000..50b533a9fd --- /dev/null +++ b/source/module_hsolver/test/test_para_linear_trans.cpp @@ -0,0 +1,260 @@ +#include "../para_linear_transform.h" + +#include +#ifdef __MPI +#include +#endif + +void random_data(std::vector& A_global, + std::vector& B_global, + std::vector& U_global, + double& alpha, + double& beta) +{ + for (auto& val: A_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: B_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: U_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + alpha = std::rand() / (RAND_MAX + 1.0); + beta = std::rand() / (RAND_MAX + 1.0); +} +void random_data(std::vector>& A_global, + std::vector>& B_global, + std::vector>& U_global, + std::complex& alpha, + std::complex& beta) +{ + for (auto& val: A_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: B_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: U_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + alpha = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + beta = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); +} +double get_double(std::complex& val) +{ + return val.real() + val.imag(); +} +double get_double(double& val) +{ + return val; +} + +template +class ParaLinearTransformTest : public ::testing::Test +{ + protected: + void SetUp() override + { + } + + void TearDown() override + { + } + void prepare(const int nrow, const int ncolA_glo, const int ncolB_glo, const int LDA) + { + int rank = 0; + int nproc = 1; + int colA_start = 0; + int colB_start = 0; + this->ncolA_glo = ncolA_glo; + this->ncolB_glo = ncolB_glo; + this->ncolA_loc = ncolA_glo; + this->ncolB_loc = ncolB_glo; +#ifdef __MPI + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + this->ncolA_loc = ncolA_glo / nproc; + this->ncolB_loc = ncolB_glo / nproc; + if (rank < ncolA_glo % nproc) + { + ncolA_loc++; + ncolB_loc++; + } + std::vector ncolA_ip(nproc); + std::vector ncolB_ip(nproc); + MPI_Allgather(&ncolA_loc, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); + MPI_Allgather(&ncolB_loc, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, MPI_COMM_WORLD); + for (int i = 0; i < rank; ++i) + { + colA_start += ncolA_ip[i]; + colB_start += ncolB_ip[i]; + } +#endif + A_global.resize(LDA * ncolA_glo); + B_global.resize(LDA * ncolB_glo); + B_global_ref.resize(LDA * ncolB_glo); + U_global.resize(ncolA_glo * ncolB_glo); + if (rank == 0) + { + random_data(A_global, B_global, U_global, alpha, beta); + B_global_ref = B_global; + const base_device::DEVICE_CPU* ctx = {}; + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + nrow, + ncolB_glo, + ncolA_glo, + &alpha, + A_global.data(), + LDA, + U_global.data(), + ncolA_glo, + &beta, + B_global_ref.data(), + LDA); + } + if (std::is_same::value) + { +#ifdef __MPI + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global_ref.data(), B_global_ref.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); +#endif + } + else if (std::is_same>::value) + { +#ifdef __MPI + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(U_global.data(), U_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global_ref.data(), B_global_ref.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); +#endif + } + + A.resize(LDA * ncolA_loc); + B.resize(LDA * ncolB_loc); + B_ref.resize(LDA * ncolB_loc); + for (int i = 0; i < LDA * ncolA_loc; ++i) + { + A[i] = A_global[colA_start * LDA + i]; + } + for (int i = 0; i < LDA * ncolB_loc; ++i) + { + B[i] = B_global[colB_start * LDA + i]; + B_ref[i] = B_global_ref[colB_start * LDA + i]; + } + } + std::vector A, B; + std::vector B_ref; + std::vector A_global; + std::vector B_global; + std::vector U_global; + std::vector B_global_ref; + int ncolA_glo = 1; + int ncolB_glo = 1; + int ncolA_loc = 1; + int ncolB_loc = 1; + T alpha; + T beta; + hsolver::PLinearTransform lt; +}; + +typedef ::testing::Types> MyTypes; +TYPED_TEST_SUITE(ParaLinearTransformTest, MyTypes); + +TYPED_TEST(ParaLinearTransformTest, globalU) +{ + const int nrowA = 7; + const int ncolA_glo = 13; + const int ncolB_glo = 11; + const int LDA = 9; + + this->prepare(nrowA, ncolA_glo, ncolB_glo, LDA); + int rank_col = 0, nproc_col = 1; +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_WORLD; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); +#endif + + this->lt.set_dimension(nrowA, + this->ncolA_loc, + this->ncolB_loc, + LDA, +#ifdef __MPI + col_world, +#endif + false); + this->lt.act(this->alpha, this->A.data(), this->U_global.data(), this->beta, this->B.data()); + + for (int i = 0; i < this->ncolB_loc; ++i) + { + for (int j = 0; j < nrowA; ++j) + { + EXPECT_NEAR(get_double(this->B[j + i * LDA]), get_double(this->B_ref[j + i * LDA]), 1e-10); + } + } +} +#ifdef __MPI +TYPED_TEST(ParaLinearTransformTest, localU) +{ + const int nrowA = 7; + const int ncolA_glo = 13; + const int ncolB_glo = 11; + const int LDA = 9; + + this->prepare(nrowA, ncolA_glo, ncolB_glo, LDA); + int rank_col = 0, nproc_col = 1; + + MPI_Comm col_world = MPI_COMM_WORLD; + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_size(col_world, &nproc_col); + std::vector ncolB_ip(nproc_col); + std::vector start_colB(nproc_col); + MPI_Allgather(&this->ncolB_loc, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, col_world); + start_colB[0] = 0; + for (int i = 1; i < nproc_col; ++i) + { + start_colB[i] = start_colB[i - 1] + ncolB_ip[i - 1]; + } + int start = start_colB[rank_col]; + + this->lt.set_dimension(nrowA, this->ncolA_loc, this->ncolB_loc, LDA, col_world, true); + + this->lt.act(this->alpha, this->A.data(), this->U_global.data() + start * ncolA_glo, this->beta, this->B.data()); + + for (int i = 0; i < this->ncolB_loc; ++i) + { + for (int j = 0; j < nrowA; ++j) + { + EXPECT_NEAR(get_double(this->B[j + i * LDA]), get_double(this->B_ref[j + i * LDA]), 1e-10); + } + } +} +#endif + +int main(int argc, char** argv) +{ +#ifdef __MPI + MPI_Init(&argc, &argv); +#endif + ::testing::InitGoogleTest(&argc, argv); + int result = RUN_ALL_TESTS(); +#ifdef __MPI + MPI_Finalize(); +#endif + return result; +} \ No newline at end of file diff --git a/source/module_io/cal_dos.cpp b/source/module_io/cal_dos.cpp index db9c374811..6f18b12af9 100644 --- a/source/module_io/cal_dos.cpp +++ b/source/module_io/cal_dos.cpp @@ -4,6 +4,7 @@ #include "module_base/global_function.h" #include "module_base/global_variable.h" #include "module_base/parallel_reduce.h" +#include "module_parameter/parameter.h" bool ModuleIO::calculate_dos(const int& is, const std::string& fa, // file address for DOS @@ -40,12 +41,12 @@ bool ModuleIO::calculate_dos(const int& is, if (de_ev <= 0) { ModuleBase::WARNING("ModuleIO::calculate_dos", "de <= 0 "); - return 0; + return false; } else if (emax_ev < emin_ev) { ModuleBase::WARNING("ModuleIO::calculate_dos", "emax_ev < emin_ev"); - return 0; + return false; } // mohan fixed bug 2010-1-18 @@ -56,7 +57,7 @@ bool ModuleIO::calculate_dos(const int& is, { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "npoints", npoints); ModuleBase::WARNING("ModuleIO::calculate_dos", "npoints <= 0"); - return 0; + return false; } if (GlobalV::MY_RANK == 0) { @@ -102,7 +103,8 @@ bool ModuleIO::calculate_dos(const int& is, } } #ifdef __MPI - Parallel_Reduce::reduce_double_allpool(GlobalV::KPAR, GlobalV::NPROC_IN_POOL, count); + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, count); #endif count = count / static_cast(nkstot); sum += count; @@ -157,5 +159,5 @@ bool ModuleIO::calculate_dos(const int& is, ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "sum up the states", sum); delete[] e_mod; - return 1; + return true; } diff --git a/source/module_io/input_conv.cpp b/source/module_io/input_conv.cpp index 54e549366d..a2240f8b8a 100644 --- a/source/module_io/input_conv.cpp +++ b/source/module_io/input_conv.cpp @@ -171,37 +171,9 @@ void Input_Conv::Convert() ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "orbital_dir", PARAM.inp.orbital_dir); // GlobalV::global_pseudo_type = PARAM.inp.pseudo_type; - if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax") - { - } - - if (PARAM.inp.device == "gpu" && PARAM.inp.basis_type == "pw") - { - GlobalV::KPAR = base_device::information::get_device_kpar(PARAM.inp.kpar, PARAM.inp.bndpar); - } -#ifdef __LCAO - else if (PARAM.inp.basis_type == "lcao") { - /// GlobalV::KPAR_LCAO is used in LCAO diagonalization only - GlobalV::KPAR_LCAO = PARAM.inp.kpar; - /// all other parts of the code use GlobalV::KPAR = 1 - GlobalV::KPAR = 1; - } -#endif - else - { - GlobalV::KPAR = PARAM.inp.kpar; - } - if (PARAM.inp.device == "cpu" and PARAM.inp.precision == "single") - { -// cpu single precision is not supported while float_fftw lib is not available -#ifndef __ENABLE_FLOAT_FFTW - ModuleBase::WARNING_QUIT( - "Input_Conv", - "Single precision with cpu is not supported while float_fftw lib is not available; \ - \n Please recompile with cmake flag \"-DENABLE_FLOAT_FFTW=ON\".\n"); -#endif // __ENABLE_FLOAT_FFTW - } + GlobalV::KPAR = PARAM.inp.kpar; + #ifdef __LCAO diff --git a/source/module_io/read_input.cpp b/source/module_io/read_input.cpp index 50bf47f4d1..794588ecaa 100644 --- a/source/module_io/read_input.cpp +++ b/source/module_io/read_input.cpp @@ -13,6 +13,7 @@ #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" +#include "module_base/module_device/device.h" namespace ModuleIO { @@ -105,9 +106,6 @@ ReadInput::ReadInput(const int& rank) this->item_exx(); this->item_dftu(); this->item_others(); - - // set globalv functions - this->set_globalv_bcast(); } void ReadInput::read_parameters(Parameter& param, const std::string& filename_in) @@ -115,8 +113,54 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in // 1. only rank 0 read the input file if (this->rank == 0) { + // We can also easily add other input file formats here this->read_txt_input(param, filename_in); } + + // 2. check the number of atom types from STRU file + // set the global directories + this->set_global_dir(param.inp, param.sys); + if (this->check_ntype_flag && this->rank == 0) + { + check_ntype(param.globalv.global_in_stru, param.input.ntype); + } + + // 3. broadcast input parameters + // It must be after the check_ntype, because some parameters need to be filled due to ntype + for (auto& bcastfunc: this->bcastfuncs) + { + bcastfunc(param); + } + + // 4. set the globalv parameters, some parameters in different processes are different. e.g. rank, log_file + this->set_globalv(param.inp, param.sys); + + // 5. check the value of the parameters + // It must be after the check_ntype, because some parameters need to be checked according to ntype + // It must be after the set_globalv, because some parameters need to be checked according to param.sys + if (this->rank == 0) + { + for (auto& input_item: this->input_lists) + { + Input_Item* checkvalue_item = &(input_item.second); + if (checkvalue_item->check_value != nullptr) + { + checkvalue_item->check_value(*checkvalue_item, param); + } + } + } + + // 6. check and reset kpar. + // It must be after bcastfunc, and kpar and bndpar are synchronized + // It must be before wirte_txt_input, because kpar is used in write_txt_input + if (param.inp.device == "gpu" && param.inp.basis_type == "pw") + { + param.input.kpar = base_device::information::get_device_kpar(param.inp.kpar, param.inp.bndpar); + } + + + + if (this->check_mode) { std::cout << "----------------------------------------------------------" << std::endl; @@ -125,12 +169,6 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in exit(0); return; } - - // 2. broadcast the parameters - for (auto& bcastfunc: this->bcastfuncs) - { - bcastfunc(param); - } } void ReadInput::create_directory(const Parameter& param) @@ -283,22 +321,6 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename) resetvalue_item->reset_value(*resetvalue_item, param); } } - this->set_globalv(param); - - // 3) count the number of atom types from STRU file - if (this->check_ntype_flag) - { - check_ntype(param.globalv.global_in_stru, param.input.ntype); - } - - // 4) check the value of the parameters - for (auto& input_item: this->input_lists) - { - Input_Item* checkvalue_item = &(input_item.second); - if (checkvalue_item->check_value != nullptr) { - checkvalue_item->check_value(*checkvalue_item, param); - } - } } void ReadInput::write_txt_input(const Parameter& param, const std::string& filename) @@ -442,9 +464,10 @@ void ReadInput::check_ntype(const std::string& fn, int& param_ntype) { ModuleBase::WARNING_QUIT("ReadInput", "ntype should be greater than 0."); } - else { + else + { GlobalV::ofs_running << " 'ntype' is automatically set to " << param_ntype << std::endl; -} + } } int ReadInput::current_md_step(const std::string& file_dir) diff --git a/source/module_io/read_input.h b/source/module_io/read_input.h index a39d043fe7..09f397e474 100644 --- a/source/module_io/read_input.h +++ b/source/module_io/read_input.h @@ -83,10 +83,12 @@ class ReadInput * @param item input_item */ void add_item(const Input_Item& item); - //set globalv parameters - void set_globalv(Parameter& para); - // add bcast functions for global values - void set_globalv_bcast(); + /// @brief set System_para according to input parameters + /// INPUT and STRU need to refer to each other in ABACUS, + /// so it is necessary to obtain the file paths related to all inputs + void set_global_dir(const Input_para& inp, System_para& sys); + // set System_para according to input parameters + void set_globalv(const Input_para& inp, System_para& sys); // system items void item_system(); // items for electronic structure diff --git a/source/module_io/read_input_item_elec_stru.cpp b/source/module_io/read_input_item_elec_stru.cpp index 089f842de5..79fb24584b 100644 --- a/source/module_io/read_input_item_elec_stru.cpp +++ b/source/module_io/read_input_item_elec_stru.cpp @@ -248,6 +248,7 @@ void ReadInput::item_elec_stru() } }; sync_double(input.nupdown); + add_bool_bcast(sys.two_fermi); this->add_item(item); } { diff --git a/source/module_io/read_input_item_exx_dftu.cpp b/source/module_io/read_input_item_exx_dftu.cpp index dc7c6a6025..b74611bd79 100644 --- a/source/module_io/read_input_item_exx_dftu.cpp +++ b/source/module_io/read_input_item_exx_dftu.cpp @@ -398,6 +398,7 @@ void ReadInput::item_dftu() } }; sync_double(input.uramping_eV); + add_double_bcast(sys.uramping); this->add_item(item); } { diff --git a/source/module_io/read_input_item_output.cpp b/source/module_io/read_input_item_output.cpp index 8e4a59d046..e2c9dfb068 100644 --- a/source/module_io/read_input_item_output.cpp +++ b/source/module_io/read_input_item_output.cpp @@ -187,6 +187,7 @@ void ReadInput::item_output() } }; sync_string(input.out_level); + add_bool_bcast(sys.out_md_control); this->add_item(item); } { diff --git a/source/module_io/read_input_item_postprocess.cpp b/source/module_io/read_input_item_postprocess.cpp index 1087e24d95..46c1e18b4f 100644 --- a/source/module_io/read_input_item_postprocess.cpp +++ b/source/module_io/read_input_item_postprocess.cpp @@ -15,6 +15,7 @@ void ReadInput::item_postprocess() para.sys.dos_setemin = true; }; sync_double(input.dos_emin_ev); + add_bool_bcast(sys.dos_setemin); this->add_item(item); } { @@ -25,6 +26,7 @@ void ReadInput::item_postprocess() para.sys.dos_setemax = true; }; sync_double(input.dos_emax_ev); + add_bool_bcast(sys.dos_setemax); this->add_item(item); } { diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index c95e785885..bf88ee08b6 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -233,6 +233,15 @@ void ReadInput::item_system() item.annotation = "devide all processors into kpar groups and k points " "will be distributed among"; read_sync_int(input.kpar); + item.reset_value = [](const Input_Item& item, Parameter& para) { +#ifdef __LCAO + if (para.inp.basis_type == "lcao") + { + para.sys.kpar_lcao = para.inp.kpar; + para.input.kpar = 1; + } +#endif + }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.basis_type == "lcao" && para.input.kpar > 1) { @@ -240,6 +249,7 @@ void ReadInput::item_system() } }; this->add_item(item); + add_int_bcast(sys.kpar_lcao); } { Input_Item item("bndpar"); @@ -247,7 +257,7 @@ void ReadInput::item_system() "will be distributed among each group"; read_sync_int(input.bndpar); item.reset_value = [](const Input_Item& item, Parameter& para) { - if (para.input.esolver_type != "sdft") + if (para.input.esolver_type != "sdft" && para.input.ks_solver != "bpcg") { para.input.bndpar = 1; } @@ -322,7 +332,6 @@ void ReadInput::item_system() item.annotation = "number of points along x axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.nx = intvalue; - para.sys.ncx = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.nx != 0) @@ -338,7 +347,6 @@ void ReadInput::item_system() item.annotation = "number of points along y axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.ny = intvalue; - para.sys.ncy = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.ny != 0) @@ -354,7 +362,6 @@ void ReadInput::item_system() item.annotation = "number of points along z axis for FFT grid"; item.read_value = [](const Input_Item& item, Parameter& para) { para.input.nz = intvalue; - para.sys.ncz = intvalue; }; item.check_value = [](const Input_Item& item, const Parameter& para) { if (para.input.nx * para.input.ny * para.input.nz == 0 && para.input.nz != 0) @@ -796,6 +803,17 @@ void ReadInput::item_system() const std::string warningstr = nofound_str(avail_list, "precision"); ModuleBase::WARNING_QUIT("ReadInput", warningstr); } + + // cpu single precision is not supported while float_fftw lib is not available + if (para.inp.device == "cpu" && para.inp.precision == "single") + { +#ifndef __ENABLE_FLOAT_FFTW + ModuleBase::WARNING_QUIT( + "ReadInput", + "Single precision with cpu is not supported while float_fftw lib is not available; \ + \n Please recompile with cmake flag \"-DENABLE_FLOAT_FFTW=ON\".\n"); +#endif + } }; this->add_item(item); } diff --git a/source/module_io/read_set_globalv.cpp b/source/module_io/read_set_globalv.cpp old mode 100644 new mode 100755 index 0b3434f170..960d3ad48f --- a/source/module_io/read_set_globalv.cpp +++ b/source/module_io/read_set_globalv.cpp @@ -1,131 +1,134 @@ #include "module_base/global_variable.h" #include "module_base/tool_quit.h" -#include "module_base/module_device/device.h" #include "module_parameter/parameter.h" #include "read_input.h" #include "read_input_tool.h" namespace ModuleIO { -void ReadInput::set_globalv(Parameter& para) +/// @note Here para.inp has been synchronized of all ranks. +/// All para.inp have the same value. +void ReadInput::set_globalv(const Input_para& inp, System_para& sys) { + /// caculate the gamma_only_pw and gamma_only_local + if (inp.gamma_only) { - /// caculate the global output directory - const std::string prefix = "OUT."; - para.sys.global_out_dir = prefix + para.inp.suffix + "/"; - para.sys.global_out_dir = to_dir(para.sys.global_out_dir); - - /// get the global output directory - para.sys.global_stru_dir = para.globalv.global_out_dir + "STRU/"; - para.sys.global_stru_dir = to_dir(para.sys.global_stru_dir); - - /// get the global output directory - para.sys.global_matrix_dir = para.globalv.global_out_dir + "matrix/"; - para.sys.global_matrix_dir = to_dir(para.sys.global_matrix_dir); - - /// get the global readin directory - para.sys.global_readin_dir = para.inp.read_file_dir; - para.sys.global_readin_dir = to_dir(para.sys.global_readin_dir); - - /// get the stru file for md restart case - if (para.inp.calculation == "md" && para.mdp.md_restart) - { - int istep = current_md_step(para.sys.global_readin_dir); - - if (para.inp.read_file_dir == to_dir("OUT." + para.input.suffix)) - { - para.sys.global_in_stru = para.sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); - } - else - { - para.sys.global_in_stru = para.inp.read_file_dir + "STRU_MD_" + std::to_string(istep); - } - } - else - { - para.sys.global_in_stru = para.inp.stru_file; - } - - /// caculate the gamma_only_pw and gamma_only_local - if (para.input.gamma_only) - { - para.sys.gamma_only_local = true; - } - if (para.sys.gamma_only_local) + sys.gamma_only_local = true; + } + if (sys.gamma_only_local) + { + if (inp.esolver_type == "tddft") { - if (para.inp.esolver_type == "tddft") - { - GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; - para.sys.gamma_only_local = false; - } + GlobalV::ofs_running << " WARNING : gamma_only is not applicable for tddft" << std::endl; + sys.gamma_only_local = false; } - /// set deepks_setorb - if (para.input.deepks_scf || para.input.deepks_out_labels) + } + /// set deepks_setorb + if (inp.deepks_scf || inp.deepks_out_labels) + { + sys.deepks_setorb = true; + } + /// set the noncolin and lspinorb from nspin + switch (inp.nspin) + { + case 4: + if (inp.noncolin) { - para.sys.deepks_setorb = true; + sys.domag = true; + sys.domag_z = false; } - /// set the noncolin and lspinorb from nspin - switch (para.input.nspin) + else { - case 4: - if (para.input.noncolin) - { - para.sys.domag = true; - para.sys.domag_z = false; - } - else - { - para.sys.domag = false; - para.sys.domag_z = true; - } - para.sys.npol = 2; - break; - case 2: - case 1: - para.sys.domag = false; - para.sys.domag_z = false; - para.sys.npol = 1; - default: - break; + sys.domag = false; + sys.domag_z = true; } - - para.sys.nqx=static_cast((sqrt(para.inp.ecutwfc) / para.sys.dq + 4.0) * para.inp.cell_factor); - para.sys.nqxq=static_cast((sqrt(para.inp.ecutrho) / para.sys.dq + 4.0) * para.inp.cell_factor); + sys.npol = 2; + break; + case 2: + case 1: + sys.domag = false; + sys.domag_z = false; + sys.npol = 1; + default: + break; + } + sys.nqx = static_cast((sqrt(inp.ecutwfc) / sys.dq + 4.0) * inp.cell_factor); + sys.nqxq = static_cast((sqrt(inp.ecutrho) / sys.dq + 4.0) * inp.cell_factor); + /// set ncx,ncy,ncz + sys.ncx = inp.nx; + sys.ncy = inp.ny; + sys.ncz = inp.nz; +#ifdef __MPI + Parallel_Common::bcast_bool(sys.double_grid); +#endif + /// set ks_run + if (inp.ks_solver != "bpcg" && inp.bndpar > 1) + { + sys.all_ks_run = false; } } -void ReadInput::set_globalv_bcast() +/// @note Here para.inp has been synchronized of all ranks. +/// Only para.inp in rank 0 is right. +/// So we need to broadcast the results to all ranks. +void ReadInput::set_global_dir(const Input_para& inp, System_para& sys) { - // add_int_bcast(sys.myrank); - add_bool_bcast(sys.two_fermi); - add_bool_bcast(sys.use_uspp); - add_bool_bcast(sys.dos_setemin); - add_bool_bcast(sys.dos_setemax); + /// caculate the global output directory + const std::string prefix = "OUT."; + sys.global_out_dir = prefix + inp.suffix + "/"; + sys.global_out_dir = to_dir(sys.global_out_dir); + + /// get the global output directory + sys.global_stru_dir = sys.global_out_dir + "STRU/"; + sys.global_stru_dir = to_dir(sys.global_stru_dir); + + /// get the global output directory + sys.global_matrix_dir = sys.global_out_dir + "matrix/"; + sys.global_matrix_dir = to_dir(sys.global_matrix_dir); - add_int_bcast(sys.ncx); - add_int_bcast(sys.ncy); - add_int_bcast(sys.ncz); - add_bool_bcast(sys.out_md_control); - add_bool_bcast(sys.rpa_setorb); + /// get the global readin directory + sys.global_readin_dir = inp.read_file_dir; + sys.global_readin_dir = to_dir(sys.global_readin_dir); - add_bool_bcast(sys.gamma_only_pw); - add_bool_bcast(sys.gamma_only_local); - - add_string_bcast(sys.global_in_card); - add_string_bcast(sys.global_out_dir); - add_string_bcast(sys.global_readin_dir); - add_string_bcast(sys.global_stru_dir); - add_string_bcast(sys.global_matrix_dir); + /// get the stru file for md restart case + if (inp.calculation == "md" && inp.mdp.md_restart) + { + int istep = current_md_step(sys.global_readin_dir); - add_bool_bcast(sys.deepks_setorb); - - add_bool_bcast(sys.domag); - add_bool_bcast(sys.domag_z); - add_int_bcast(sys.npol); - - add_bool_bcast(sys.double_grid); - add_double_bcast(sys.uramping); - add_double_bcast(sys.dq); - add_int_bcast(sys.nqx); - add_int_bcast(sys.nqxq); + if (inp.read_file_dir == to_dir("OUT." + inp.suffix)) + { + sys.global_in_stru = sys.global_stru_dir + "STRU_MD_" + std::to_string(istep); + } + else + { + sys.global_in_stru = inp.read_file_dir + "STRU_MD_" + std::to_string(istep); + } + } + else + { + sys.global_in_stru = inp.stru_file; + } + + // set the global log file + bool out_alllog = inp.out_alllog; +#ifdef __MPI + // because log_file is different for each rank, so we need to bcast the out_alllog + Parallel_Common::bcast_bool(out_alllog); +#endif + if (out_alllog) + { + PARAM.sys.log_file = "running_" + PARAM.inp.calculation + "_" + std::to_string(PARAM.sys.myrank + 1) + ".log"; + } + else + { + PARAM.sys.log_file = "running_" + PARAM.inp.calculation + ".log"; + } +#ifdef __MPI + Parallel_Common::bcast_string(sys.global_in_card); + Parallel_Common::bcast_string(sys.global_out_dir); + Parallel_Common::bcast_string(sys.global_readin_dir); + Parallel_Common::bcast_string(sys.global_stru_dir); + Parallel_Common::bcast_string(sys.global_matrix_dir); + Parallel_Common::bcast_string(sys.global_in_stru); +#endif } } // namespace ModuleIO diff --git a/source/module_io/test/read_wfc_to_rho_test.cpp b/source/module_io/test/read_wfc_to_rho_test.cpp index 9272a7c0a9..0290ae1d38 100644 --- a/source/module_io/test/read_wfc_to_rho_test.cpp +++ b/source/module_io/test/read_wfc_to_rho_test.cpp @@ -277,9 +277,9 @@ int main(int argc, char** argv) GlobalV::MY_RANK, PARAM.inp.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index 0afdad3608..9e95d17fa4 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -46,6 +46,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) // preconditions GlobalV::KPAR = 1; PARAM.input.nbands = 4; + PARAM.sys.nbands_l = 4; PARAM.input.nspin = 1; PARAM.sys.global_out_dir = "./"; // mpi setting @@ -53,9 +54,9 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); @@ -96,6 +97,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) // preconditions GlobalV::KPAR = 1; PARAM.input.nbands = 4; + PARAM.sys.nbands_l = 4; PARAM.input.nspin = 2; PARAM.sys.global_out_dir = "./"; // mpi setting @@ -103,9 +105,9 @@ TEST_F(IstateInfoTest, OutIstateInfoS2) GlobalV::MY_RANK, PARAM.input.bndpar, GlobalV::KPAR, - GlobalV::NPROC_IN_STOGROUP, - GlobalV::RANK_IN_STOGROUP, - GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, GlobalV::MY_POOL); diff --git a/source/module_io/test_serial/io_system_variable_test.cpp b/source/module_io/test_serial/io_system_variable_test.cpp index b02395afc6..499399a69d 100644 --- a/source/module_io/test_serial/io_system_variable_test.cpp +++ b/source/module_io/test_serial/io_system_variable_test.cpp @@ -44,26 +44,29 @@ TEST_F(InputTest, Item_test) { param.input.suffix = "test"; - readinput.set_globalv(param); + readinput.set_global_dir(param.inp, param.sys); + EXPECT_EQ(param.sys.global_out_dir, "OUT.test/"); EXPECT_EQ(param.sys.global_stru_dir, "OUT.test/STRU/"); EXPECT_EQ(param.sys.global_matrix_dir, "OUT.test/matrix/"); + readinput.set_globalv(param.inp, param.sys); + param.input.basis_type = "lcao"; param.input.gamma_only = true; param.input.esolver_type = "tddft"; param.input.nspin = 2; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.gamma_only_local, 0); param.input.deepks_scf = true; param.input.deepks_out_labels = true; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.deepks_setorb, 1); param.input.nspin = 4; param.input.noncolin = true; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.domag, 1); EXPECT_EQ(param.sys.domag_z, 0); EXPECT_EQ(param.sys.npol, 2); @@ -71,7 +74,7 @@ TEST_F(InputTest, Item_test) param.input.nspin = 1; param.input.lspinorb = true; param.input.noncolin = false; - readinput.set_globalv(param); + readinput.set_globalv(param.inp, param.sys); EXPECT_EQ(param.sys.domag, 0); EXPECT_EQ(param.sys.domag_z, 0); EXPECT_EQ(param.sys.npol, 1); diff --git a/source/module_io/write_cube.cpp b/source/module_io/write_cube.cpp index dbb4e469cf..400f454878 100644 --- a/source/module_io/write_cube.cpp +++ b/source/module_io/write_cube.cpp @@ -35,7 +35,7 @@ void ModuleIO::write_vdata_palgrid( // reduce std::vector data_xyz_full(nxyz); // data to be written #ifdef __MPI // reduce to rank 0 - if (my_pool == 0 && GlobalV::MY_STOGROUP == 0) + if (my_pool == 0 && GlobalV::MY_BNDGROUP == 0) { pgrid.reduce(data_xyz_full.data(), data); } diff --git a/source/module_io/write_istate_info.cpp b/source/module_io/write_istate_info.cpp index 0fc52afad2..34a4dba414 100644 --- a/source/module_io/write_istate_info.cpp +++ b/source/module_io/write_istate_info.cpp @@ -24,7 +24,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase: MPI_Barrier(MPI_COMM_WORLD); if (GlobalV::MY_POOL == ip) { - if (GlobalV::RANK_IN_POOL != 0 || GlobalV::MY_STOGROUP != 0 ) { continue; + if (GlobalV::RANK_IN_POOL != 0 || GlobalV::MY_BNDGROUP != 0 ) { continue; } #endif std::ofstream ofsi2(ss.str().c_str(), std::ios::app); @@ -41,7 +41,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase: << std::setw(25) << "Kpoint = " << ik_global << std::setw(25) << "(" << kv.kvec_d[ik].x << " " << kv.kvec_d[ik].y << " " << kv.kvec_d[ik].z << ")" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++) { ofsi2.precision(16); ofsi2 << std::setw(6) << ib + 1 << std::setw(25) diff --git a/source/module_parameter/parameter.cpp b/source/module_parameter/parameter.cpp index 8d5375b3e5..8688b93a07 100644 --- a/source/module_parameter/parameter.cpp +++ b/source/module_parameter/parameter.cpp @@ -14,13 +14,3 @@ void Parameter::set_start_time(const std::time_t& start_time) { sys.start_time = start_time; } - -void Parameter::set_input_nbands(const int& nbands) -{ - input.nbands = nbands; -} - -void Parameter::set_sys_nlocal(const int& nlocal) -{ - sys.nlocal = nlocal; -} diff --git a/source/module_parameter/parameter.h b/source/module_parameter/parameter.h index bdfe2359ea..9a3d4ea9ff 100644 --- a/source/module_parameter/parameter.h +++ b/source/module_parameter/parameter.h @@ -32,12 +32,6 @@ class Parameter void set_pal_param(const int& myrank, const int& nproc, const int& nthread_per_proc); // Set the start time void set_start_time(const std::time_t& start_time); - - // set input.nbands - void set_input_nbands(const int& nbands); - // set sys.nlocal - void set_sys_nlocal(const int& nlocal); - private: // Only ReadInput and CalAtomInfo can modify the value of Parameter. // Do not add extra friend class here!!! diff --git a/source/module_parameter/system_parameter.h b/source/module_parameter/system_parameter.h old mode 100644 new mode 100755 index 04d2ca870e..e41ed6b823 --- a/source/module_parameter/system_parameter.h +++ b/source/module_parameter/system_parameter.h @@ -33,7 +33,6 @@ struct System_para int ncx = 0, ncy = 0, ncz = 0; ///< three dimension of FFT charge/grid, same as "nx,ny,nz" bool out_md_control = false; ///< true if "out_level" is set - bool rpa_setorb = false; ///< true if "rpa" is set bool gamma_only_pw = false; ///< true if "gamma_only" is true and "basis_type" is "pw" ///< for plane wave basis. bool gamma_only_local = false; ///< true if "gamma_only" is true and "lcao" @@ -44,6 +43,7 @@ struct System_para std::string global_readin_dir = ""; ///< global readin directory std::string global_stru_dir = ""; ///< global structure directory std::string global_matrix_dir = ""; ///< global matrix directory + std::string log_file = "log"; ///< log file name bool deepks_setorb = false; ///< true if "deepks" is set int npol = 1; ///< number of polarization @@ -53,5 +53,9 @@ struct System_para bool double_grid = false; ///< true if "ndx,ndy,ndz" is larger than "nx,ny,nz" double uramping = -10.0 / 13.6; /// U-Ramping method (Ry) std::vector hubbard_u = {}; ///< Hubbard Coulomb interaction parameter U (Ry) + int kpar_lcao = 1; ///< global number of pools for LCAO diagonalization only + int nbands_l = 0; ///< number of bands of each band parallel calculation, same to nbands when bndpar=1 + bool ks_run = false; ///< true if current process runs KS calculation + bool all_ks_run = true; ///< true if only all processes run KS calculation }; #endif \ No newline at end of file diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 78eb202766..07452554b1 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -296,7 +296,7 @@ const int& Psi::get_current_ngk() const } template -const int Psi::get_npol() const +int Psi::get_npol() const { if (PARAM.inp.nspin == 4) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 75e13433ea..9b427c070d 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -134,7 +134,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; - const int get_npol() const; + int get_npol() const; private: T* psi = nullptr; // avoid using C++ STL diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 8ef89dcfdc..d89b8ab477 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -2,6 +2,7 @@ #include "module_base/macros.h" #include "module_base/memory.h" +#include "module_base/parallel_device.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_hsolver/diago_iter_assist.h" @@ -86,7 +87,7 @@ void PSIInit::initialize_psi(Psi>* psi, hamilt::Hamilt* p_hamilt, std::ofstream& ofs_running) { - if (kspw_psi->get_nbands() == 0 || GlobalV::MY_STOGROUP != 0) + if (kspw_psi->get_nbands() == 0 || (!PARAM.globalv.ks_run)) { return; } @@ -97,30 +98,34 @@ void PSIInit::initialize_psi(Psi>* psi, ModuleBase::timer::tick("PSIInit", "initialize_psi"); const int nbands_start = this->psi_initer->nbands_start(); - const int nbands = psi->get_nbands(); + const int nbands_l = psi->get_nbands(); const int nbasis = psi->get_nbasis(); - const bool not_equal = (nbands_start != nbands); + const bool not_equal = (nbands_start != nbands_l); Psi* psi_cpu = reinterpret_cast*>(psi); Psi* psi_device = kspw_psi; - if (not_equal) + bool fill = PARAM.inp.ks_solver != "bpcg" || GlobalV::MY_BNDGROUP == 0; + if (fill) { - psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); - psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) - : reinterpret_cast*>(psi_cpu); - } - else if (PARAM.inp.precision == "single") - { - if (PARAM.inp.device == "cpu") + if (not_equal) { - psi_cpu = reinterpret_cast*>(kspw_psi); - psi_device = kspw_psi; + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); + psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) + : reinterpret_cast*>(psi_cpu); } - else + else if (PARAM.inp.precision == "single") { - psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); - psi_device = kspw_psi; + if (PARAM.inp.device == "cpu") + { + psi_cpu = reinterpret_cast*>(kspw_psi); + psi_device = kspw_psi; + } + else + { + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); + psi_device = kspw_psi; + } } } @@ -134,58 +139,90 @@ void PSIInit::initialize_psi(Psi>* psi, //! Update Hamiltonian from other kpoint to the given one p_hamilt->updateHk(ik); - - //! initialize psi_cpu - this->psi_initer->init_psig(psi_cpu->get_pointer(), ik); - if (psi_device->get_pointer() != psi_cpu->get_pointer()) + if (fill) { - syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); - } - - std::vector::type> etatom(nbands_start, 0.0); + //! initialize psi_cpu + this->psi_initer->init_psig(psi_cpu->get_pointer(), ik); + if (psi_device->get_pointer() != psi_cpu->get_pointer()) + { + syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis); + } - if (this->ks_solver == "cg") - { - if (not_equal) + if (this->ks_solver == "cg") { - // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different - hsolver::DiagoIterAssist::diagH_subspace_init(p_hamilt, - psi_device->get_pointer(), - nbands_start, - nbasis, - *(kspw_psi), - etatom.data()); + std::vector::type> etatom(nbands_start, 0.0); + if (not_equal) + { + // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be + // different + hsolver::DiagoIterAssist::diagH_subspace_init(p_hamilt, + psi_device->get_pointer(), + nbands_start, + nbasis, + *(kspw_psi), + etatom.data()); + } + else + { + // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same + hsolver::DiagoIterAssist::diagH_subspace(p_hamilt, + *psi_device, + *kspw_psi, + etatom.data(), + nbands_start); + } } - else + else // dav, bpcg { - // for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same - hsolver::DiagoIterAssist::diagH_subspace(p_hamilt, - *psi_device, - *kspw_psi, - etatom.data(), - nbands_start); + if (psi_device->get_pointer() != kspw_psi->get_pointer()) + { + syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands_l * nbasis); + } } } - else // dav, bpcg +#ifdef __MPI + if (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1) { - if (psi_device->get_pointer() != kspw_psi->get_pointer()) + std::vector sendcounts(PARAM.inp.bndpar); + std::vector displs(PARAM.inp.bndpar); + MPI_Allgather(&nbands_l, 1, MPI_INT, sendcounts.data(), 1, MPI_INT, BP_WORLD); + displs[0] = 0; + sendcounts[0] *= nbasis; + for (int i = 1; i < PARAM.inp.bndpar; i++) + { + sendcounts[i] *= nbasis; + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + if (GlobalV::MY_BNDGROUP == 0) { - syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis); + for (int ip = 1; ip < PARAM.inp.bndpar; ++ip) + { + Parallel_Common::send_data(psi_cpu->get_pointer() + displs[ip], sendcounts[ip], ip, 0, BP_WORLD); + } } + else + { + MPI_Status status; + Parallel_Common::recv_dev(kspw_psi->get_pointer(), nbands_l * nbasis, 0, 0, BP_WORLD, &status); + } } +#endif } // end k-point loop - if (not_equal) + if (fill) { - delete psi_cpu; - if(PARAM.inp.device == "gpu") + if (not_equal) { - delete psi_device; + delete psi_cpu; + if (PARAM.inp.device == "gpu") + { + delete psi_device; + } + } + else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu") + { + delete psi_cpu; } - } - else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu") - { - delete psi_cpu; } ModuleBase::timer::tick("PSIInit", "initialize_psi"); @@ -203,7 +240,11 @@ void PSIInit::initialize_lcao_in_pw(Psi* psi_local, std::ofstream& } } -void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx) +void allocate_psi(Psi>*& psi, + const int& nks, + const std::vector& ngk, + const int& nbands, + const int& npwx) { assert(npwx > 0); assert(nks > 0); diff --git a/tests/integrate/102_PW_BPCG/result.ref b/tests/integrate/102_PW_BPCG/result.ref index 4815e05cb4..6ad0913957 100644 --- a/tests/integrate/102_PW_BPCG/result.ref +++ b/tests/integrate/102_PW_BPCG/result.ref @@ -1,8 +1,8 @@ -etotref -4869.74705201 -etotperatomref -2434.87352600 -totalforceref 5.19522000 -totalstressref 37241.49490600 +etotref -4869.7470520063843651 +etotperatomref -2434.8735260032 +totalforceref 5.194830 +totalstressref 37241.448435 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 10.37 +totaltimeref 5.53 diff --git a/tests/integrate/102_PW_BPCG_BP/INPUT b/tests/integrate/102_PW_BPCG_BP/INPUT new file mode 100644 index 0000000000..43dd00ab9f --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/INPUT @@ -0,0 +1,33 @@ +INPUT_PARAMETERS +#Parameters (General) +suffix autotest +pseudo_dir ../../PP_ORB +pw_seed 1 + +gamma_only 0 +calculation scf +symmetry 1 +out_level ie +smearing_method gaussian +smearing_sigma 0.02 + +#Parameters (3.PW) +ecutwfc 40 +scf_thr 1e-7 +scf_nmax 20 +bndpar 2 + +#Parameters (LCAO) +basis_type pw +ks_solver bpcg +device cpu +chg_extrap second-order +out_dm 0 +pw_diag_thr 0.00001 + +cal_force 1 +cal_stress 1 + +mixing_type broyden +mixing_beta 0.4 +mixing_gg0 1.5 diff --git a/tests/integrate/102_PW_BPCG_BP/KPT b/tests/integrate/102_PW_BPCG_BP/KPT new file mode 100644 index 0000000000..28006d5e2d --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +2 2 2 0 0 0 diff --git a/tests/integrate/102_PW_BPCG_BP/README b/tests/integrate/102_PW_BPCG_BP/README new file mode 100644 index 0000000000..8b809e0a25 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/README @@ -0,0 +1,10 @@ +This test for: +*GaAs-deformation +*PW +*bndpar 2 +*kpoints 2*2*2 +*sg15 pseudopotential +*smearing_method gauss +*ks_solver bpcg +*mixing_type broyden-kerker +*mixing_beta 0.4 diff --git a/tests/integrate/102_PW_BPCG_BP/STRU b/tests/integrate/102_PW_BPCG_BP/STRU new file mode 100644 index 0000000000..b03baadd25 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/STRU @@ -0,0 +1,23 @@ +ATOMIC_SPECIES +As 1 As_dojo.upf upf201 +Ga 1 Ga_dojo.upf upf201 + +LATTICE_CONSTANT +1 // add lattice constant, 10.58 ang + +LATTICE_VECTORS +5.33 5.33 0.0 +0.0 5.33 5.33 +5.33 0.0 5.33 +ATOMIC_POSITIONS +Direct //Cartesian or Direct coordinate. + +As +0 +1 +0.300000 0.3300000 0.27000000 0 0 0 + +Ga //Element Label +0 +1 //number of atom +0.00000 0.00000 0.000000 0 0 0 diff --git a/tests/integrate/102_PW_BPCG_BP/result.ref b/tests/integrate/102_PW_BPCG_BP/result.ref new file mode 100755 index 0000000000..b97d5ed713 --- /dev/null +++ b/tests/integrate/102_PW_BPCG_BP/result.ref @@ -0,0 +1,8 @@ +etotref -4869.7470520063843651 +etotperatomref -2434.8735260032 +totalforceref 5.194830 +totalstressref 37241.453346 +pointgroupref C_1 +spacegroupref C_1 +nksibzref 8 +totaltimeref 5.42 diff --git a/tests/integrate/102_PW_BPCG_GPU/INPUT b/tests/integrate/102_PW_BPCG_GPU/INPUT index 1f24602487..57da09d770 100644 --- a/tests/integrate/102_PW_BPCG_GPU/INPUT +++ b/tests/integrate/102_PW_BPCG_GPU/INPUT @@ -15,6 +15,7 @@ smearing_sigma 0.02 ecutwfc 40 scf_thr 1e-7 scf_nmax 100 +bndpar 2 #Parameters (LCAO) basis_type pw diff --git a/tests/integrate/102_PW_BPCG_GPU/result.ref b/tests/integrate/102_PW_BPCG_GPU/result.ref index 41371e10fd..a4c2417f26 100644 --- a/tests/integrate/102_PW_BPCG_GPU/result.ref +++ b/tests/integrate/102_PW_BPCG_GPU/result.ref @@ -1,8 +1,8 @@ -etotref -4869.7470518350019120 +etotref -4869.7470518349809936 etotperatomref -2434.8735259175 totalforceref 5.207670 totalstressref 37241.465646 pointgroupref C_1 spacegroupref C_1 nksibzref 8 -totaltimeref 4.25 +totaltimeref 10.28 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/INPUT b/tests/integrate/184_PW_BPCG_SDFT_5D11S/INPUT new file mode 100644 index 0000000000..dc7807efb9 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/INPUT @@ -0,0 +1,38 @@ +INPUT_PARAMETERS +#Parameters (1.General) +suffix autotest +calculation scf +esolver_type sdft +ks_solver bpcg +method_sto 1 + +symmetry 0 +pseudo_dir ../../PP_ORB + +bndpar 2 + +nbands 5 +nbands_sto 11 + +nche_sto 120 +seed_sto 20000 +kpar 1 +cal_force 1 +cal_stress 1 + +#Parameters (2.Iteration) +ecutwfc 20 +scf_thr 1e-4 +scf_nmax 20 + + +#Parameters (3.Basis) +basis_type pw + +#Parameters (4.Smearing) +smearing_method fd +smearing_sigma 0.6 + +#Parameters (5.Mixing) +mixing_type broyden +mixing_beta 0.4 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/KPT b/tests/integrate/184_PW_BPCG_SDFT_5D11S/KPT new file mode 100644 index 0000000000..c289c0158a --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +1 1 1 0 0 0 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/README b/tests/integrate/184_PW_BPCG_SDFT_5D11S/README new file mode 100644 index 0000000000..b150d66930 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/README @@ -0,0 +1,9 @@ +This test for: +*SDFT +*Si +*kpoints 1*1*1 +*10 stochastic orbitals + 5 KS orbitals +*mixing_type broyden +*mixing_beta 0.4 +*seed_sto > 0 +*bndpar 2 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/STRU b/tests/integrate/184_PW_BPCG_SDFT_5D11S/STRU new file mode 100644 index 0000000000..92de2f5eee --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/STRU @@ -0,0 +1,19 @@ +ATOMIC_SPECIES +Si 28 Si.pz-vbc.UPF + +LATTICE_CONSTANT +8 // add lattice constant + +LATTICE_VECTORS +1 0 0 +0 1 0 +0 0 1 + +ATOMIC_POSITIONS +Direct + +Si // Element type +0.0 // magnetism +2 +0.00 0.00 0.00 1 1 1 +0.5 0.5 0.5 1 1 1 diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/jd b/tests/integrate/184_PW_BPCG_SDFT_5D11S/jd new file mode 100644 index 0000000000..e87c7a3930 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/jd @@ -0,0 +1 @@ +test parallel method bndpar with BPCG. Must be run by MPI with 4 cores, otherwise please ignore this test. diff --git a/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref new file mode 100644 index 0000000000..460f04c8f6 --- /dev/null +++ b/tests/integrate/184_PW_BPCG_SDFT_5D11S/result.ref @@ -0,0 +1,5 @@ +etotref -323.28029702 +etotperatomref -161.64014851 +totalforceref 5.605396 +totalstressref 1781.685471 +totaltimeref 2.80 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT new file mode 100644 index 0000000000..4eab8220a6 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT @@ -0,0 +1,39 @@ +INPUT_PARAMETERS +#Parameters (1.General) +suffix autotest +calculation scf +esolver_type sdft +method_sto 2 +device gpu +ks_solver bpcg + +symmetry 0 +pseudo_dir ../../PP_ORB + +kpar 1 +bndpar 2 + +nbands 11 +nbands_sto all + +nche_sto 120 +seed_sto 20000 +cal_force 1 +cal_stress 1 + +#Parameters (2.Iteration) +ecutwfc 30 +scf_thr 1e-6 +scf_nmax 20 + + +#Parameters (3.Basis) +basis_type pw + +#Parameters (4.Smearing) +smearing_method fd +smearing_sigma 0.6 + +#Parameters (5.Mixing) +mixing_type plain +mixing_beta 0.7 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT new file mode 100644 index 0000000000..c289c0158a --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +1 1 1 0 0 0 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README new file mode 100644 index 0000000000..c1f7d02359 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/README @@ -0,0 +1,12 @@ +This test for: +*SDFT +*Si +*kpoints 1*1*1 +*10 KS + complete orbitals +*ks_solver bpcg +*mixing_type plain +*mixing_beta 0.7 +*sto_method 2 +*seed_sto > 0 +*bndpar 2 +*kpar 1 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU new file mode 100644 index 0000000000..28a94e8c0c --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/STRU @@ -0,0 +1,19 @@ +ATOMIC_SPECIES +Si 28 Si.pz-vbc.UPF + +LATTICE_CONSTANT +5 // add lattice constant + +LATTICE_VECTORS +0.5 0 0.5 +0.5 0.5 0 +0 0.5 0.5 + +ATOMIC_POSITIONS +Direct + +Si // Element type +0.0 // magnetism +2 +0.10 0.00 0.20 1 1 1 +0.5 0.5 0.5 1 1 1 diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd new file mode 100644 index 0000000000..44886f5d13 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/jd @@ -0,0 +1 @@ +test parallel method bndpar with BPCG and GPU diff --git a/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref new file mode 100644 index 0000000000..077de5b358 --- /dev/null +++ b/tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/result.ref @@ -0,0 +1,5 @@ +etotref -96.9361190965003630 +etotperatomref -48.4680595483 +totalforceref 248.979444 +totalstressref 230453.604050 +totaltimeref 6.44 diff --git a/tests/integrate/CASES_CPU.txt b/tests/integrate/CASES_CPU.txt index db834ee8e0..09897e4bcd 100644 --- a/tests/integrate/CASES_CPU.txt +++ b/tests/integrate/CASES_CPU.txt @@ -24,6 +24,7 @@ 101_PW_GTH_CF_CS_Si 102_PW_DA_davidson 102_PW_BPCG +102_PW_BPCG_BP 102_PW_CG 102_PW_DS_davsubspace 102_PW_DS_davsubspace_sca @@ -123,6 +124,7 @@ 184_PW_BNDKPAR_SDFT_MALL 184_PW_BNDPAR_SDFT_10S 184_PW_BNDPAR_SDFT_5D10S +184_PW_BPCG_SDFT_5D11S 184_PW_KPAR_SDFT_ALL 185_PW_SDFT_10D10S_METHD2 185_PW_SDFT_10S_METHD2 diff --git a/tests/integrate/CASES_GPU.txt b/tests/integrate/CASES_GPU.txt index 3490ccc9a0..44fbcccdad 100644 --- a/tests/integrate/CASES_GPU.txt +++ b/tests/integrate/CASES_GPU.txt @@ -3,6 +3,7 @@ 102_PW_BPCG_GPU 187_PW_SDFT_ALL_GPU 187_PW_SDFT_MALL_GPU +187_PW_SDFT_MALL_BPCG_GPU 187_PW_MD_SDFT_ALL_GPU 930_NO_BI2SE2CU2O2_GPU 930_NO_BI2SE2CU2O2_k_GPU