Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Make BPCG support band parallelism #5873

Merged
merged 22 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change name
  • Loading branch information
Qianruipku committed Jan 22, 2025
commit c5dc8cea0fdd33a8a0a580f313ed0fc886b49a0f
4 changes: 2 additions & 2 deletions source/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ void Driver::reading()
GlobalV::MY_RANK,
PARAM.inp.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::MY_BNDGROUP,
GlobalV::NPROC_IN_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::MY_POOL);
Expand Down
4 changes: 2 additions & 2 deletions source/module_base/global_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ int NPROC = 1; ///< global number of process
int KPAR = 1; ///< global number of pools
int MY_RANK = 0; ///< global index of process
int MY_POOL = 0; ///< global index of pool (count in pool)
int MY_STOGROUP = 0;
int MY_BNDGROUP = 0;
int NPROC_IN_POOL = 1; ///< local number of process in a pool
int NPROC_IN_STOGROUP = 1;
int NPROC_IN_BNDGROUP = 1;
int RANK_IN_POOL = 0; ///< global index of pool (count in process), my_rank in each pool
int RANK_IN_BPGROUP = 0;
int DRANK = -1; ///< mohan add 2012-01-13, must be -1, so we can recognize who
Expand Down
4 changes: 2 additions & 2 deletions source/module_base/global_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ extern int NPROC;
extern int KPAR;
extern int MY_RANK;
extern int MY_POOL;
extern int MY_STOGROUP;
extern int MY_BNDGROUP;
extern int NPROC_IN_POOL;
extern int NPROC_IN_STOGROUP;
extern int NPROC_IN_BNDGROUP;
extern int RANK_IN_POOL;
extern int RANK_IN_BPGROUP;
extern int DRANK;
Expand Down
24 changes: 12 additions & 12 deletions source/module_base/parallel_global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ void Parallel_Global::init_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL)
Expand All @@ -268,9 +268,9 @@ void Parallel_Global::init_pools(const int& NPROC,
MY_RANK,
BNDPAR,
KPAR,
NPROC_IN_STOGROUP,
NPROC_IN_BNDGROUP,
RANK_IN_BPGROUP,
MY_STOGROUP,
MY_BNDGROUP,
NPROC_IN_POOL,
RANK_IN_POOL,
MY_POOL);
Expand Down Expand Up @@ -316,16 +316,16 @@ void Parallel_Global::divide_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL)
{
// note: the order of k-point parallelization and band parallelization is important
// The order will not change the behavior of INTER_POOL or PARAPW_WORLD, and MY_POOL
// and MY_STOGROUP will be the same as well.
// and MY_BNDGROUP will be the same as well.
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
{
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
Expand Down Expand Up @@ -358,17 +358,17 @@ void Parallel_Global::divide_pools(const int& NPROC,

if(BNDPAR > 1)
{
NPROC_IN_STOGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
NPROC_IN_BNDGROUP = kpar_group.ngroups * bndpar_group.nprocs_in_group;
RANK_IN_BPGROUP = kpar_group.my_group * bndpar_group.nprocs_in_group + bndpar_group.rank_in_group;
MY_STOGROUP = bndpar_group.my_group;
MPI_Comm_split(MPI_COMM_WORLD, MY_STOGROUP, RANK_IN_BPGROUP, &STO_WORLD);
MY_BNDGROUP = bndpar_group.my_group;
MPI_Comm_split(MPI_COMM_WORLD, MY_BNDGROUP, RANK_IN_BPGROUP, &STO_WORLD);
MPI_Comm_dup(bndpar_group.inter_comm, &PARAPW_WORLD);
}
else
{
NPROC_IN_STOGROUP = NPROC;
NPROC_IN_BNDGROUP = NPROC;
RANK_IN_BPGROUP = MY_RANK;
MY_STOGROUP = 0;
MY_BNDGROUP = 0;
MPI_Comm_dup(MPI_COMM_WORLD, &STO_WORLD);
MPI_Comm_split(MPI_COMM_WORLD, MY_RANK, 0, &PARAPW_WORLD);
}
Expand Down
8 changes: 4 additions & 4 deletions source/module_base/parallel_global.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ void init_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL);
Expand All @@ -59,9 +59,9 @@ void divide_pools(const int& NPROC,
const int& MY_RANK,
const int& BNDPAR,
const int& KPAR,
int& NPROC_IN_STOGROUP,
int& NPROC_IN_BNDGROUP,
int& RANK_IN_BPGROUP,
int& MY_STOGROUP,
int& MY_BNDGROUP,
int& NPROC_IN_POOL,
int& RANK_IN_POOL,
int& MY_POOL);
Expand Down
10 changes: 10 additions & 0 deletions source/module_cell/cal_atoms_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ class CalAtomsInfo
nelec_spin[1] = (para.inp.nelec - para.inp.nupdown ) / 2.0;
}
elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands);
// calculate the number of local bands (nbands_l) handled by this process
para.sys.nbands_l = para.inp.nbands;
if (inp.ks_solver == "bpcg") // only bpcg support band parallel
{
para.sys.nbands_l = para.inp.nbands / para.inp.bndpar;
if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar)
{
para.sys.nbands_l++;
}
}
return;
}
};
Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/test/klist_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ TEST_F(KlistParaTest, Set)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
Expand Down Expand Up @@ -286,7 +286,7 @@ TEST_F(KlistParaTest, SetAfterVC)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ void ElecState::print_eigenvalue(std::ofstream& ofs)
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
#endif
bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_STOGROUP == 0);
bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_BNDGROUP == 0);
if (GlobalV::MY_POOL == ip && ip_flag)
{
const int start_ik = nks_np * is;
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
}

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0 || PARAM.inp.ks_solver == "bpcg")
{
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
Expand Down
6 changes: 3 additions & 3 deletions source/module_elecstate/test_mpi/charge_mpi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
Expand Down Expand Up @@ -116,7 +116,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
Expand Down Expand Up @@ -171,7 +171,7 @@ TEST_F(ChargeMpiTest, rho_mpi)
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ void ESolver_KS<T, Device>::hamilt2density(UnitCell& ucell, const int istep, con
// Maybe in the future, density and wavefunctions should use different
// parallel algorithms, in which they do not occupy all processors, for
// example wavefunctions uses 20 processors while density uses 10.
if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
// double drho = this->estate.caldr2();
// EState should be used after it is constructed.
Expand Down Expand Up @@ -550,7 +550,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
this->pelec->charge->rho,
this->pelec->nelec_spin.data());

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
// mixing will restart at this->p_chgmix->mixing_restart steps
if (drho <= PARAM.inp.mixing_restart && PARAM.inp.mixing_restart > 0.0
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(UnitCell& ucell, int iste
// set_diagethr need it
this->esolver_KS_ne = hsolver_pw_sdft_obj.stoiter.KS_ne;

if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
Symmetry_rho srho;
for (int is = 0; is < PARAM.inp.nspin; is++)
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_pwdft/parallel_grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void Parallel_Grid::init(

this->nproc_in_pool = new int[GlobalV::KPAR];
int nprocgroup;
if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_STOGROUP;
if(PARAM.inp.esolver_type == "sdft") { nprocgroup = GlobalV::NPROC_IN_BNDGROUP;
} else { nprocgroup = GlobalV::NPROC;
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ void Sto_EleCond::sKG(const int& smear_type,
}
// Parallel for bands
int allbands_ks = this->nbands_ks - cutib0;
parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_STOGROUP);
parallel_distribution paraks(allbands_ks, PARAM.inp.bndpar, GlobalV::MY_BNDGROUP);
int perbands_ks = paraks.num_per;
int ib0_ks = paraks.start;
ib0_ks += this->nbands_ks - allbands_ks;
Expand Down Expand Up @@ -653,7 +653,7 @@ void Sto_EleCond::sKG(const int& smear_type,
//-----------------------------------------------------------
// ks conductivity
//-----------------------------------------------------------
if (GlobalV::MY_STOGROUP == 0 && allbands_ks > 0)
if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0)
{
jjresponse_ks(ik, nt, dt, dEcut, this->p_elec->wg, velop, ct11.data(), ct12.data(), ct22.data());
}
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_stodft/sto_forces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ void Sto_Forces<FPTYPE, Device>::cal_sto_force_nl(
const int* nchip = stowf.nchip;
const int npwx = wfc_basis->npwk_max;
int nksbands = psi_in.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0)
{
nksbands = 0;
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_stodft/sto_stress_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void Sto_Stress_PW<FPTYPE, Device>::sto_stress_kin(ModuleBase::matrix& sigma,
ModuleBase::timer::tick("Sto_Stress_PW", "stress_kin");

int nksbands = psi.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0)
{
nksbands = 0;
}
Expand Down Expand Up @@ -160,7 +160,7 @@ void Sto_Stress_PW<FPTYPE, Device>::sto_stress_nl(ModuleBase::matrix& sigma,
int* nchip = stowf.nchip;
const int npwx = wfc_basis->npwk_max;
int nksbands = psi_in.get_nbands();
if (GlobalV::MY_STOGROUP != 0)
if (GlobalV::MY_BNDGROUP != 0 && PARAM.inp.ks_solver != "bpcg")
{
nksbands = 0;
}
Expand Down
12 changes: 6 additions & 6 deletions source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void Stochastic_WF<T, Device>::init_sto_orbitals(const int seed_in)
}
else
{
srand((unsigned)std::abs(seed_in) + (GlobalV::MY_STOGROUP * GlobalV::NPROC_IN_STOGROUP + GlobalV::RANK_IN_BPGROUP) * 10000);
srand((unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000);
}

this->allocate_chi0();
Expand All @@ -88,12 +88,12 @@ void Stochastic_WF<T, Device>::allocate_chi0()
// former processor calculate more bands
if (firstrankmore)
{
igroup = GlobalV::MY_STOGROUP;
igroup = GlobalV::MY_BNDGROUP;
}
// latter processor calculate more bands
else
{
igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1;
igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1;
}
const int nchi = PARAM.inp.nbands_sto;
const int npwx = this->npwx;
Expand Down Expand Up @@ -172,12 +172,12 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
// former processor calculate more bands
if (firstrankmore)
{
igroup = GlobalV::MY_STOGROUP;
igroup = GlobalV::MY_BNDGROUP;
}
// latter processor calculate more bands
else
{
igroup = PARAM.inp.bndpar - GlobalV::MY_STOGROUP - 1;
igroup = PARAM.inp.bndpar - GlobalV::MY_BNDGROUP - 1;
}
const int ngroup = PARAM.inp.bndpar;
const int n_in_pool = GlobalV::NPROC_IN_POOL;
Expand Down Expand Up @@ -318,7 +318,7 @@ void Stochastic_WF<T, Device>::init_sto_orbitals_Ecut(const int seed_in,
MPI_Allgather(&nchiper, 1, MPI_INT, nrecv, 1, MPI_INT, PARAPW_WORLD);
#endif
int ichi_start = 0;
for (int i = 0; i < GlobalV::MY_STOGROUP; ++i)
for (int i = 0; i < GlobalV::MY_BNDGROUP; ++i)
{
ichi_start += nrecv[i];
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
{
ModuleBase::timer::tick("HSolverPW_SDFT", "solve_KS");
pHamilt->updateHk(ik);
if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
if (nbands > 0 && GlobalV::MY_BNDGROUP == 0)
{
/// update psi pointer for each k point
psi.fix_k(ik);
Expand Down Expand Up @@ -89,7 +89,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,

// calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS
elecstate::ElecStatePW<T, Device>* pes_pw = static_cast<elecstate::ElecStatePW<T, Device>*>(pes);
if (GlobalV::MY_STOGROUP == 0)
if (GlobalV::MY_BNDGROUP == 0)
{
pes_pw->calEBand();
}
Expand Down
10 changes: 4 additions & 6 deletions source/module_hsolver/test/diago_lapack_test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Author: Zhang Xiaoyang
// A modified version of diago_lcao_test.cpp
// #define private public
#define private public
#include "module_parameter/parameter.h"
// #undef private
#undef private
// Removed some useless functions and dependencies, serialized the full code,
// and refactored some functions.

Expand Down Expand Up @@ -187,10 +187,8 @@ class DiagoLapackPrepare

void set_env()
{
// PARAM.sys.nlocal = nlocal;
PARAM.set_sys_nlocal(nlocal);
// PARAM.input.nbands = nbands;
PARAM.set_input_nbands(nbands);
PARAM.sys.nlocal = nlocal;
PARAM.input.nbands = nbands;
}

void diago()
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/read_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in

// 3. check the number of atom types from STRU file
// set the global directories
this->set_global_dir(param);
this->set_global_dir(param.inp, param.sys);
if (this->check_ntype_flag && this->rank == 0)
{
check_ntype(param.globalv.global_in_stru, param.input.ntype);
Expand All @@ -143,7 +143,7 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in
}

// 5. set the globalv parameters; some of them differ between processes, e.g. rank
this->set_globalv(param);
this->set_globalv(param.inp, param.sys);

if (this->check_mode)
{
Expand Down
Loading