Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Make BPCG support band parallelism #5873

Open
wants to merge 22 commits into
base: develop
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change name
  • Loading branch information
Qianruipku committed Jan 22, 2025
commit c5dc8cea0fdd33a8a0a580f313ed0fc886b49a0f
4 changes: 2 additions & 2 deletions source/driver.cpp
Original file line number Diff line number Diff line change
@@ -152,9 +152,9 @@ void Driver::reading()
GlobalV::MY_RANK,
PARAM.inp.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::MY_BNDGROUP,
GlobalV::NPROC_IN_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::MY_POOL);
4 changes: 2 additions & 2 deletions source/module_base/global_variable.cpp
Original file line number Diff line number Diff line change
@@ -20,9 +20,9 @@ int NPROC = 1; ///< global number of process
int KPAR = 1; ///< global number of pools
int MY_RANK = 0; ///< global index of process
int MY_POOL = 0; ///< global index of pool (count in pool)
int MY_STOGROUP = 0;
int MY_BNDGROUP = 0;
int NPROC_IN_POOL = 1; ///< local number of process in a pool
int NPROC_IN_STOGROUP = 1;
int NPROC_IN_BNDGROUP = 1;
int RANK_IN_POOL = 0; ///< global index of pool (count in process), my_rank in each pool
int RANK_IN_BPGROUP = 0;
int DRANK = -1; ///< mohan add 2012-01-13, must be -1, so we can recognize who
4 changes: 2 additions & 2 deletions source/module_base/global_variable.h
Original file line number Diff line number Diff line change
@@ -33,9 +33,9 @@ extern int NPROC;
extern int KPAR;
extern int MY_RANK;
extern int MY_POOL;
extern int MY_STOGROUP;
extern int MY_BNDGROUP;
extern int NPROC_IN_POOL;
extern int NPROC_IN_STOGROUP;
extern int NPROC_IN_BNDGROUP;
extern int RANK_IN_POOL;
extern int RANK_IN_BPGROUP;
extern int DRANK;
24 changes: 12 additions & 12 deletions source/module_base/parallel_global.cpp
Original file line number Diff line number Diff line change
@@ -253,9 +253,9 @@ void Parallel_Global::init_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL)
@@ -268,9 +268,9 @@ void Parallel_Global::init_pools(const int& NPROC,
MY_RANK,
BNDPAR,
KPAR,
NPROC_IN_STOGROUP,
NPROC_IN_BNDGROUP,
RANK_IN_BPGROUP,
MY_STOGROUP,
MY_BNDGROUP,
NPROC_IN_POOL,
RANK_IN_POOL,
MY_POOL);
@@ -316,16 +316,16 @@ void Parallel_Global::divide_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL)
{
// note: the order of k-point parallelization and band parallelization is important
// The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL
// and MY_STOGROUP will be the same as well.
// and MY_BNDGROUP will be the same as well.
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
{
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
@@ -358,17 +358,17 @@ void Parallel_Global::divide_pools(const int& NPROC,

if(BNDPAR > 1)
{
NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group;
MY_STOGROUP = bndpar_group.my_group;
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_BPGROUP, &STO_WORLD);
MY_BNDGROUP = bndpar_group.my_group;
MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &STO_WORLD);
MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD);
}
else
{
NPROC_IN_STOGROUP = NPROC;
NPROC_IN_BNDGROUP = NPROC;
RANK_IN_BPGROUP = MY_RANK;
MY_STOGROUP = 0;
MY_BNDGROUP = 0;
MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD);
MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD);
}
8 changes: 4 additions & 4 deletions source/module_base/parallel_global.h
Original file line number Diff line number Diff line change
@@ -48,9 +48,9 @@ void init_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL);
@@ -59,9 +59,9 @@ void divide_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL);
10 changes: 10 additions & 0 deletions source/module_cell/cal_atoms_info.h
Original file line number Diff line number Diff line change
@@ -68,6 +68,16 @@ class CalAtomsInfo
nelec_spin[1] = (para.inp.nelec - para.inp.nupdown ) / 2.0;
}
elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands);
// calculate the number of nbands_local
para.sys.nbands_l = para.inp.nbands;
if (inp.ks_solver == "bpcg") // only bpcg support band parallel
{
para.sys.nbands_l = para.inp.nbands / para.inp.bndpar;
if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar)
{
para.sys.nbands_l++;
}
}
return;
}
};
4 changes: 2 additions & 2 deletions source/module_cell/test/klist_test_para.cpp
Original file line number Diff line number Diff line change
@@ -229,7 +229,7 @@ TEST_F(KlistParaTest, Set)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
@@ -286,7 +286,7 @@ TEST_F(KlistParaTest, SetAfterVC)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
@@ -205,7 +205,7 @@ void ElecState::print_eigenvalue(std::ofstream& ofs)
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_STOGROUP == 0);
bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_BNDGROUP == 0);
if (GlobalV::MY_POOL == ip && ip_flag)
{
const int start_ik = nks_np * is;
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
}

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0 || PARAM.inp.ks_solver == "bpcg")
{
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
6 changes: 3 additions & 3 deletions source/module_elecstate/test_mpi/charge_mpi_test.cpp
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
@@ -116,7 +116,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
@@ -171,7 +171,7 @@ TEST_F(ChargeMpiTest, rho_mpi)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
@@ -361,7 +361,7 @@ void ESolver_KS<T, Device>::hamilt2density(UnitCell& ucell, const int istep, con
// Maybe in the future, density and wavefunctions should use different
// parallel algorithms, in which they do not occupy all processors, for
// example wavefunctions uses 20 processors while density uses 10.
if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
// double drho = this->estate.caldr2();
// EState should be used after it is constructed.
@@ -550,7 +550,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
this->pelec->charge->rho,
this->pelec->nelec_spin.data());

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
// mixing will restart at this->p_chgmix->mixing_restart steps
if (drho <= PARAM.inp.mixing_restart && PARAM.inp.mixing_restart > 0.0
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
@@ -198,7 +198,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(UnitCell& ucell, int iste
// set_diagethr need it
this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne;

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
Symmetry_rho srho;
for (int is = 0; is < PARAM.inp.nspin; is++)
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp
Original file line number Diff line number Diff line change
@@ -84,7 +84,7 @@ void Parallel_Grid::init(

this->nproc_in_pool = new int[GlobalV::KPAR];
int nprocgroup;
if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_STOGROUP;
if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_BNDGROUP;
} else { nprocgroup = GlobalV::NPROC;
}

4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp
Original file line number Diff line number Diff line change
@@ -620,7 +620,7 @@ void Sto_EleCond::sKG(const int& smear_type,
}
// Parallel for bands
int allbands_ks = this->nbands_ks - cutib0;
parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_STOGROUP);
parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_BNDGROUP);
int perbands_ks = paraks.num_per;
int ib0_ks = paraks.start;
ib0_ks += this->nbands_ks - allbands_ks;
@@ -653,7 +653,7 @@ void Sto_EleCond::sKG(const int& smear_type,
//-----------------------------------------------------------
// ks conductivity
//-----------------------------------------------------------
if (GlobalV::MY_STOGROUP == 0 && allbands_ks > 0)
if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0)
{
jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data());
}
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp
Original file line number Diff line number Diff line change
@@ -210,7 +210,7 @@ void Sto_Forces<FPTYPE, Device>::cal_sto_force_nl(
const int* nchip = stowf.nchip;
const int npwx = wfc_basis->npwk_max;
int nksbands = psi_in.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0)
{
nksbands = 0;
}
4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp
Original file line number Diff line number Diff line change
@@ -111,7 +111,7 @@ void Sto_Stress_PW<FPTYPE, Device>::sto_stress_kin(ModuleBase::matrix& sigma,
ModuleBase::timer::tick("Sto_Stress_PW", "stress_kin");

int nksbands = psi.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0)
{
nksbands = 0;
}
@@ -160,7 +160,7 @@ void Sto_Stress_PW<FPTYPE, Device>::sto_stress_nl(ModuleBase::matrix& sigma,
int* nchip = stowf.nchip;
const int npwx = wfc_basis->npwk_max;
int nksbands = psi_in.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg")
{
nksbands = 0;
}
12 changes: 6 additions & 6 deletions source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ void Stochastic_WF<T, Device>::init_sto_orbitals(const int seed_in)
}
else
{
srand((unsigned)std::abs(seed_in) + (GlobalV::MY_STOGROUP * GlobalV::NPROC_IN_STOGROUP + GlobalV::RANK_IN_BPGROUP) * 10000);
srand((unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000);
}

this->allocate_chi0();
@@ -88,12 +88,12 @@ void Stochastic_WF<T, Device>::allocate_chi0()
// former processor calculate more bands
if (firstrankmore)
{
igroup = GlobalV::MY_STOGROUP;
igroup = GlobalV::MY_BNDGROUP;
}
// latter processor calculate more bands
else
{
igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1;
igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1;
}
const int nchi = PARAM.inp.nbands_sto;
const int npwx = this->npwx;
@@ -172,12 +172,12 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
// former processor calculate more bands
if (firstrankmore)
{
igroup = GlobalV::MY_STOGROUP;
igroup = GlobalV::MY_BNDGROUP;
}
// latter processor calculate more bands
else
{
igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1;
igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1;
}
const int ngroup = PARAM.inp.bndpar;
const int n_in_pool = GlobalV::NPROC_IN_POOL;
@@ -318,7 +318,7 @@ void Stochastic_WF<T, Device>::init_sto_orbitals_Ecut(const int seed_in,
MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD);
#endif
int ichi_start = 0;
for (int i = 0; i < GlobalV::MY_STOGROUP; ++i)
for (int i = 0; i < GlobalV::MY_BNDGROUP; ++i)
{
ichi_start += nrecv[i];
}
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
{
ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS");
pHamilt->updateHk(ik);
if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
if (nbands > 0 && GlobalV::MY_BNDGROUP == 0)
{
/// update psi pointer for each k point
psi.fix_k(ik);
@@ -89,7 +89,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,

// calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS
elecstate::ElecStatePW<T, Device>* pes_pw = static_cast<elecstate::ElecStatePW<T, Device>*>(pes);
if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
pes_pw->calEBand();
}
10 changes: 4 additions & 6 deletions source/module_hsolver/test/diago_lapack_test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Author: Zhang Xiaoyang
// A modified version of diago_lcao_test.cpp
// #define private public
#define private public
#include "module_parameter/parameter.h"
// #undef private
#undef private
// Remove some useless functions and dependencies. Serialized the full code
// and refactored some function.

@@ -187,10 +187,8 @@ class DiagoLapackPrepare

void set_env()
{
// PARAM.sys.nlocal = nlocal;
PARAM.set_sys_nlocal(nlocal);
// PARAM.input.nbands = nbands;
PARAM.set_input_nbands(nbands);
PARAM.sys.nlocal = nlocal;
PARAM.input.nbands = nbands;
}

void diago()
4 changes: 2 additions & 2 deletions source/module_io/read_input.cpp
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in

// 3. check the number of atom types from STRU file
// set the global directories
this->set_global_dir(param);
this->set_global_dir(param.inp, param.sys);
if (this->check_ntype_flag && this->rank == 0)
{
check_ntype(param.globalv.global_in_stru, param.input.ntype);
@@ -143,7 +143,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in
}

// 5. set the globalv parameters, some parameters in different processes are different. e.g. rank
this->set_globalv(param);
this->set_globalv(param.inp, param.sys);

if (this->check_mode)
{
Loading