Skip to content

Commit

Permalink
PetscDiffSolver support for restricted solves
Browse files Browse the repository at this point in the history
  • Loading branch information
roystgnr committed Dec 14, 2022
1 parent 4d8dedb commit dcb2c7e
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 32 deletions.
29 changes: 25 additions & 4 deletions include/solvers/petsc_diff_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,9 @@ class PetscDiffSolver : public DiffSolver
* will be restricted to unconstrained dofs, with constraints
* applied separately as necessary.
*/
virtual void restrict_solves_to_unconstrained(bool restricting) override
{ if (restricting) libmesh_not_implemented(); }
virtual void restrict_solves_to_unconstrained(bool restricting) override;

virtual bool get_restrict_solves_to_unconstrained() override
{ return false; }
virtual bool get_restrict_solves_to_unconstrained() override;

/**
* This method performs a solve. What occurs in
Expand All @@ -104,6 +102,24 @@ class PetscDiffSolver : public DiffSolver
*/
virtual unsigned int solve () override;

/*
* Scatter connecting full vectors with unconstrained subvectors
*/
WrappedPetsc<VecScatter> scatter;

/*
* Submatrix of Jacobian with unconstrained-unconstrained coupling
*/
WrappedPetsc<Mat> submat;

bool submat_created;

/**
* PETSc index set containing unconstrained dofs on which to solve
* (\p nullptr means solve on all dofs).
*/
WrappedPetsc<IS> _unconstrained_dofs_is;

protected:

/**
Expand All @@ -120,6 +136,11 @@ class PetscDiffSolver : public DiffSolver
#endif
#endif

/**
* Are we restricting solves to unconstrained DoFs?
*/
bool _restrict_to_unconstrained;

private:

/**
Expand Down
170 changes: 142 additions & 28 deletions src/solvers/petsc_diff_solver.C
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ extern "C"
// Function to hand to PETSc's SNES,
// which monitors convergence at X
PetscErrorCode
__libmesh_petsc_diff_solver_monitor (SNES snes,
PetscInt its,
PetscReal fnorm,
void * ctx)
_libmesh_petsc_diff_solver_monitor (SNES snes,
PetscInt its,
PetscReal fnorm,
void * ctx)
{
PetscDiffSolver & solver =
*(static_cast<PetscDiffSolver *> (ctx));
Expand Down Expand Up @@ -84,7 +84,7 @@ extern "C"
// Functions to hand to PETSc's SNES,
// which compute the residual or jacobian at X
PetscErrorCode
__libmesh_petsc_diff_solver_residual (SNES, Vec x, Vec r, void * ctx)
_libmesh_petsc_diff_solver_residual (SNES, Vec x, Vec r, void * ctx)
{
libmesh_assert(x);
libmesh_assert(r);
Expand All @@ -106,9 +106,20 @@ extern "C"
// DiffSystem assembles from the solution and into the rhs, so swap
// those with our input vectors before assembling. They'll probably
// already be references to the same vectors, but PETSc might do
// something tricky.
X_input.swap(X_system);
R_input.swap(R_system);
// something tricky. ... or we might do something tricky. If
// we're solving only on constrained DoFs, we'll need to do some
// scatters here.
if (solver.get_restrict_solves_to_unconstrained())
{
VecScatterBeginEnd(solver.comm(), solver.scatter,
X_input.vec(), X_system.vec(),
INSERT_VALUES, SCATTER_REVERSE);
}
else
{
X_input.swap(X_system);
R_input.swap(R_system);
}

// We may need to localize a parallel solution
sys.update();
Expand All @@ -121,20 +132,29 @@ extern "C"
R_system.close();

// Swap back
X_input.swap(X_system);
R_input.swap(R_system);
if (solver.get_restrict_solves_to_unconstrained())
{
VecScatterBeginEnd(solver.comm(), solver.scatter,
R_system.vec(), R_input.vec(),
INSERT_VALUES, SCATTER_FORWARD);
}
else
{
X_input.swap(X_system);
R_input.swap(R_system);
}

// No errors, we hope
return 0;
}


PetscErrorCode
__libmesh_petsc_diff_solver_jacobian (SNES,
Vec x,
Mat libmesh_dbg_var(j),
Mat pc,
void * ctx)
_libmesh_petsc_diff_solver_jacobian (SNES,
Vec x,
Mat libmesh_dbg_var(j),
Mat pc,
void * ctx)
{
libmesh_assert(x);
libmesh_assert(j);
Expand All @@ -159,9 +179,20 @@ extern "C"
// DiffSystem assembles from the solution and into the jacobian, so
// swap those with our input vectors before assembling. They'll
// probably already be references to the same vectors, but PETSc
// might do something tricky.
X_input.swap(X_system);
J_input.swap(J_system);
// might do something tricky. ... or we might do something
// tricky. If we're solving only on constrained DoFs, we'll need
// to do some scatters here.
if (solver.get_restrict_solves_to_unconstrained())
{
VecScatterBeginEnd(solver.comm(), solver.scatter,
X_input.vec(), X_system.vec(),
INSERT_VALUES, SCATTER_REVERSE);
}
else
{
X_input.swap(X_system);
J_input.swap(J_system);
}

// We may need to localize a parallel solution
sys.update();
Expand All @@ -174,8 +205,24 @@ extern "C"
J_system.close();

// Swap back
X_input.swap(X_system);
J_input.swap(J_system);
if (solver.get_restrict_solves_to_unconstrained())
{
PetscInt ierr =
LibMeshCreateSubMatrix(J_system.mat(),
solver._unconstrained_dofs_is,
solver._unconstrained_dofs_is,
solver.submat_created ?
MAT_REUSE_MATRIX :
MAT_INITIAL_MATRIX,
solver.submat.get());
LIBMESH_CHKERR(ierr);
solver.submat_created = true;
}
else
{
X_input.swap(X_system);
J_input.swap(J_system);
}

// No errors, we hope
return 0;
Expand All @@ -185,7 +232,7 @@ extern "C"


PetscDiffSolver::PetscDiffSolver (sys_type & s)
: Parent(s)
: Parent(s), submat_created(false)
{
}

Expand Down Expand Up @@ -280,6 +327,17 @@ DiffSolver::SolveResult convert_solve_result(SNESConvergedReason r)
}


void PetscDiffSolver::restrict_solves_to_unconstrained(bool restricting)
{
_restrict_to_unconstrained = restricting;
}


bool PetscDiffSolver::get_restrict_solves_to_unconstrained()
{
return _restrict_to_unconstrained;
}


unsigned int PetscDiffSolver::solve()
{
Expand All @@ -292,20 +350,76 @@ unsigned int PetscDiffSolver::solve()
PetscVector<Number> & r =
*(cast_ptr<PetscVector<Number> *>(_system.rhs));

WrappedPetsc<Vec> subrhs, subsolution;

PetscErrorCode ierr = 0;

ierr = SNESSetFunction (_snes, r.vec(),
__libmesh_petsc_diff_solver_residual, this);
LIBMESH_CHKERR(ierr);
if (_restrict_to_unconstrained)
{
std::vector<PetscInt> unconstrained_dofs;
const DofMap & dofmap = this->system().get_dof_map();
for (dof_id_type i = dofmap.first_dof(),
end_i = dofmap.end_dof();
i != end_i; ++i)
if (!dofmap.is_constrained_dof(i))
unconstrained_dofs.push_back(cast_int<PetscInt>(i));

const PetscInt is_local_size =
cast_int<PetscInt>(unconstrained_dofs.size());

ierr = ISCreateGeneral(this->comm().get(),
cast_int<PetscInt>(unconstrained_dofs.size()),
unconstrained_dofs.data(), PETSC_COPY_VALUES,
_unconstrained_dofs_is.get());

ierr = VecCreate(this->comm().get(), subrhs.get());
LIBMESH_CHKERR(ierr);
ierr = VecSetSizes(subrhs, is_local_size, PETSC_DECIDE);
LIBMESH_CHKERR(ierr);
ierr = VecSetFromOptions(subrhs);
LIBMESH_CHKERR(ierr);

ierr = SNESSetJacobian (_snes, jac.mat(), jac.mat(),
__libmesh_petsc_diff_solver_jacobian, this);
ierr = VecCreate(this->comm().get(), subsolution.get());
LIBMESH_CHKERR(ierr);
ierr = VecSetSizes(subsolution, is_local_size, PETSC_DECIDE);
LIBMESH_CHKERR(ierr);
ierr = VecSetFromOptions(subsolution);
LIBMESH_CHKERR(ierr);

ierr = VecScatterCreate(r.vec(), _unconstrained_dofs_is,
subrhs, nullptr, scatter.get());
LIBMESH_CHKERR(ierr);

VecScatterBeginEnd(this->comm(), scatter, r.vec(), subrhs, INSERT_VALUES, SCATTER_FORWARD);
VecScatterBeginEnd(this->comm(), scatter, x.vec(), subsolution, INSERT_VALUES, SCATTER_FORWARD);

LIBMESH_CHKERR(ierr);

ierr = SNESSetFunction (_snes, subrhs,
_libmesh_petsc_diff_solver_residual, this);
LIBMESH_CHKERR(ierr);

ierr = SNESSetJacobian (_snes, submat, submat,
_libmesh_petsc_diff_solver_jacobian, this);
}
else
{
ierr = SNESSetFunction (_snes, r.vec(),
_libmesh_petsc_diff_solver_residual, this);
LIBMESH_CHKERR(ierr);

ierr = SNESSetJacobian (_snes, jac.mat(), jac.mat(),
_libmesh_petsc_diff_solver_jacobian, this);
}
LIBMESH_CHKERR(ierr);

ierr = SNESSetFromOptions(_snes);
LIBMESH_CHKERR(ierr);

ierr = SNESSolve (_snes, PETSC_NULL, x.vec());
if (_restrict_to_unconstrained)
ierr = SNESSolve (_snes, PETSC_NULL, x.vec());
else
ierr = SNESSolve (_snes, PETSC_NULL, x.vec());
LIBMESH_CHKERR(ierr);

#ifdef LIBMESH_ENABLE_CONSTRAINTS
Expand Down Expand Up @@ -334,7 +448,7 @@ void PetscDiffSolver::setup_petsc_data()
ierr = SNESCreate(this->comm().get(), _snes.get());
LIBMESH_CHKERR(ierr);

ierr = SNESMonitorSet (_snes, __libmesh_petsc_diff_solver_monitor,
ierr = SNESMonitorSet (_snes, _libmesh_petsc_diff_solver_monitor,
this, PETSC_NULL);
LIBMESH_CHKERR(ierr);

Expand Down

0 comments on commit dcb2c7e

Please sign in to comment.