Atomic virial and stress calculation for parallel virial computation of Allegro in LAMMPS #281

Open
Hongyu-yu wants to merge 28 commits into base: develop from Hongyu-yu:para_stress.

Commits (28)
dcab8be
ParaStressOutput Implemented for
Hongyu-yu Dec 15, 2022
6d29b51
CHANGELOG.md and test tolerance update.
Hongyu-yu Dec 15, 2022
e1c98cd
black format
Hongyu-yu Dec 15, 2022
074bf42
black format
Hongyu-yu Dec 15, 2022
fa60cb7
Update test
Hongyu-yu Dec 16, 2022
4f5e655
Merge branch 'develop' into para_stress
Huni-ML Dec 16, 2022
1b7a103
flake8
Hongyu-yu Dec 16, 2022
60a41ba
Merge branch 'para_stress' of github.com:Hongyu-yu/nequip into para_s…
Hongyu-yu Dec 16, 2022
6c31b19
fix para bug in lammps
Hongyu-yu Dec 16, 2022
e96dd7e
Merge branch 'develop' into para_stress
Hongyu-yu Dec 20, 2022
756f4c6
Merge branch 'mir-group:main' into para_stress
Hongyu-yu Dec 23, 2022
bb10cc6
Add reference in ParaStressOutput
Hongyu-yu Dec 23, 2022
ba3379f
Merge branch 'develop' into para_stress
Hongyu-yu Jan 7, 2023
9719567
Merge branch 'develop' into para_stress
Hongyu-yu Jan 10, 2023
d4a80dd
Merge branch 'develop' into para_stress
Hongyu-yu Jan 21, 2023
6638451
Format and update tests
Hongyu-yu Jan 24, 2023
8d283bf
fix atom_virial sign
Hongyu-yu Feb 6, 2023
fbd560c
Merge branch 'develop' into para_stress
Hongyu-yu Feb 21, 2023
8dd2e03
Merge branch 'develop' into para_stress
Hongyu-yu Feb 22, 2023
b502de8
Update atom_virial in ParaStressForceOutput
Hongyu-yu Mar 27, 2023
22be5e7
Merge branch 'develop' into para_stress
Hongyu-yu Mar 29, 2023
ddba5e4
Update _grad_output.py
Hongyu-yu Mar 29, 2023
56aa691
remove warning
Hongyu-yu Mar 29, 2023
d591b6e
Update _grad_output.py
Hongyu-yu Mar 29, 2023
9701437
add para_stress tests workflows
Hongyu-yu Mar 29, 2023
235bfe0
fix parastress irreps
Hongyu-yu Mar 29, 2023
9c18611
Merge branch 'develop' into para_stress
Hongyu-yu Apr 18, 2023
6ca00ac
Merge branch 'develop' into para_stress
Hongyu-yu May 3, 2023
48 changes: 48 additions & 0 deletions .github/workflows/tests_stress.yml
@@ -0,0 +1,48 @@
name: Run Tests

on:
  push:
    branches:
      - para_stress
  pull_request:
    branches:
      - para_stress

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9]
        torch-version: [1.11.0, 1.13.1]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          TORCH: "${{ matrix.torch-version }}"
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel
          pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install h5py
          pip install --upgrade-strategy only-if-needed .
      - name: Install pytest
        run: |
          pip install pytest
          pip install pytest-xdist[psutil]
      - name: Download test data
        run: |
          mkdir benchmark_data
          cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd ..
      - name: Test with pytest
        run: |
          # See https://github.com/pytest-dev/pytest/issues/1075
          PYTHONHASHSEED=0 pytest -n auto tests/
2 changes: 2 additions & 0 deletions .gitignore
@@ -159,3 +159,5 @@ dmypy.json

# Cython debug symbols
cython_debug/

.history
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -41,6 +41,7 @@ Most recent change on the bottom.

## [0.5.6] - 2022-12-19
### Added
- `ParaStressForceOutput`: support for atomic virials and virial parallelism in LAMMPS.
- sklearn dependency removed
- `nequip-benchmark` and `nequip-train` report number of weights and number of trainable weights
- `nequip-benchmark --no-compile` and `--verbose` and `--memory-summary`
1 change: 1 addition & 0 deletions nequip/data/AtomicData.py
@@ -42,6 +42,7 @@
    AtomicDataDict.FORCE_KEY,
    AtomicDataDict.PER_ATOM_ENERGY_KEY,
    AtomicDataDict.BATCH_KEY,
    AtomicDataDict.ATOM_VIRIAL_KEY,
}
_DEFAULT_EDGE_FIELDS: Set[str] = {
    AtomicDataDict.EDGE_CELL_SHIFT_KEY,
2 changes: 2 additions & 0 deletions nequip/data/_keys.py
@@ -57,6 +57,7 @@
PARTIAL_FORCE_KEY: Final[str] = "partial_forces"
STRESS_KEY: Final[str] = "stress"
VIRIAL_KEY: Final[str] = "virial"
ATOM_VIRIAL_KEY: Final[str] = "atom_virial"

ALL_ENERGY_KEYS: Final[List[str]] = [
    PER_ATOM_ENERGY_KEY,
@@ -65,6 +66,7 @@
    PARTIAL_FORCE_KEY,
    STRESS_KEY,
    VIRIAL_KEY,
    ATOM_VIRIAL_KEY,
]

BATCH_KEY: Final[str] = "batch"
8 changes: 7 additions & 1 deletion nequip/model/__init__.py
@@ -1,5 +1,10 @@
from ._eng import EnergyModel, SimpleIrrepsConfig
from ._grads import ForceOutput, PartialForceOutput, StressForceOutput
from ._grads import (
    ForceOutput,
    PartialForceOutput,
    StressForceOutput,
    ParaStressForceOutput,
)
from ._scaling import RescaleEnergyEtc, PerSpeciesRescale
from ._weight_init import (
    uniform_initialize_FCs,
@@ -18,6 +23,7 @@
    ForceOutput,
    PartialForceOutput,
    StressForceOutput,
    ParaStressForceOutput,
    RescaleEnergyEtc,
    PerSpeciesRescale,
    uniform_initialize_FCs,
18 changes: 18 additions & 0 deletions nequip/model/_grads.py
@@ -1,6 +1,7 @@
from nequip.nn import GraphModuleMixin, GradientOutput
from nequip.nn import PartialForceOutput as PartialForceOutputModule
from nequip.nn import StressOutput as StressOutputModule
from nequip.nn import ParaStressOutput as ParaStressOutputModule
from nequip.data import AtomicDataDict


@@ -56,3 +57,20 @@ def StressForceOutput(model: GraphModuleMixin) -> GradientOutput:
    ):
        raise ValueError("This model already has force or stress outputs.")
    return StressOutputModule(func=model)


def ParaStressForceOutput(model: GraphModuleMixin) -> GradientOutput:
    r"""Add forces, stress, and atomic virials to a model that predicts energy.

    Args:
        model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output.

    Returns:
        A ``ParaStressOutput`` wrapping ``model``.
    """
    if (
        AtomicDataDict.FORCE_KEY in model.irreps_out
        or AtomicDataDict.STRESS_KEY in model.irreps_out
    ):
        raise ValueError("This model already has force or stress outputs.")
    return ParaStressOutputModule(func=model)
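For orientation, a hedged usage sketch of the new builder (not part of the diff; the construction of `energy_model` is assumed):

from nequip.model import ParaStressForceOutput

# `energy_model` is assumed to be any GraphModuleMixin whose irreps_out
# contains AtomicDataDict.TOTAL_ENERGY_KEY (e.g. one built by EnergyModel).
model = ParaStressForceOutput(energy_model)
# `model` now also exposes forces, stress, virial, and per-atom virials.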
7 changes: 6 additions & 1 deletion nequip/nn/__init__.py
@@ -7,7 +7,12 @@
    PerSpeciesScaleShift,
)  # noqa: F401
from ._interaction_block import InteractionBlock  # noqa: F401
from ._grad_output import GradientOutput, PartialForceOutput, StressOutput  # noqa: F401
from ._grad_output import (  # noqa: F401
    GradientOutput,
    PartialForceOutput,
    StressOutput,
    ParaStressOutput,
)  # noqa: F401
from ._rescale import RescaleOutput  # noqa: F401
from ._convnetlayer import ConvNetLayer  # noqa: F401
from ._util import SaveForOutput  # noqa: F401
155 changes: 155 additions & 0 deletions nequip/nn/_grad_output.py
@@ -1,6 +1,7 @@
from typing import List, Union, Optional

import torch
from torch_runstats.scatter import scatter

from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
@@ -356,3 +357,157 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
            pos.requires_grad_(False)

        return data


@compile_mode("script")
class ParaStressOutput(GraphModuleMixin, torch.nn.Module):
    r"""Compute stress, atomic virials, and forces using autograd of an energy model.

    Uses an edge-vector-based atomic virial, designed for LAMMPS parallelism.

    See:
        H. Yu, Y. Zhong, J. Ji, X. Gong, and H. Xiang, Time-Reversal Equivariant Neural Network Potential and Hamiltonian for Magnetic Materials, arXiv:2211.11403.
        https://arxiv.org/abs/2211.11403

        Knuth et al., Comput. Phys. Commun. 190, 33-50, 2015.
        https://pure.mpg.de/rest/items/item_2085135_9/component/file_2156800/content

    Args:
        func: the energy model to wrap
        do_forces: whether to compute forces as well
    """

    do_forces: bool

    def __init__(
        self,
        func: GraphModuleMixin,
        do_forces: bool = True,
    ):
        super().__init__()

        if not do_forces:
            raise NotImplementedError
        self.do_forces = do_forces

        self.func = func

        # check and init irreps
        self._init_irreps(
            irreps_in=self.func.irreps_in.copy(),
            irreps_out=self.func.irreps_out.copy(),
        )
        self.irreps_out[AtomicDataDict.FORCE_KEY] = "1o"
        self.irreps_out[AtomicDataDict.STRESS_KEY] = "1o"
        self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "1o"
        self.irreps_out[AtomicDataDict.ATOM_VIRIAL_KEY] = "1o"

        # for torchscript compat
        self.register_buffer("_empty", torch.Tensor())

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        assert AtomicDataDict.EDGE_VECTORS_KEY not in data

        if AtomicDataDict.BATCH_KEY in data:
            batch = data[AtomicDataDict.BATCH_KEY]
            num_batch: int = len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1
        else:
            data = AtomicDataDict.with_batch(data)
            # Special case for efficiency
            batch = data[AtomicDataDict.BATCH_KEY]
            num_batch: int = 1

        pos = data[AtomicDataDict.POSITIONS_KEY]

        has_cell: bool = AtomicDataDict.CELL_KEY in data

        if has_cell:
            orig_cell = data[AtomicDataDict.CELL_KEY]
            # Make the cell per-batch
            cell = orig_cell.view(-1, 3, 3).expand(num_batch, 3, 3)
            data[AtomicDataDict.CELL_KEY] = cell
        else:
            # torchscript
            orig_cell = self._empty
            cell = self._empty
Comment on lines +423 to +431

Collaborator: This business is all unnecessary now, right? Since there is no derivative w.r.t. the cell, the cell doesn't need to be made per-batch-entry.

Author: Yes. The cell here is only used to get the stress and for training with stress.

        did_pos_req_grad: bool = pos.requires_grad
        pos.requires_grad_(True)
        data[AtomicDataDict.POSITIONS_KEY] = pos
        data = AtomicDataDict.with_edge_vectors(data, with_lengths=False)
        data[AtomicDataDict.EDGE_VECTORS_KEY].requires_grad_(True)
Comment on lines +436 to +437

Collaborator: 👍

        # Call model and get gradients
        data = self.func(data)

        grads = torch.autograd.grad(
            [data[AtomicDataDict.TOTAL_ENERGY_KEY].sum()],
            [pos, data[AtomicDataDict.EDGE_VECTORS_KEY]],
            create_graph=self.training,  # needed to allow gradients of this output during training
        )
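        # grads[0] is dE/d(pos) and grads[1] is dE/d(edge_vectors); both are
        # negated below to yield the forces and the per-edge virials.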

        # Put negative sign on forces
        forces = grads[0]
        if forces is None:
            # condition needed to unwrap optional for torchscript
            assert False, "failed to compute forces autograd"
        forces = torch.neg(forces)
        data[AtomicDataDict.FORCE_KEY] = forces

        # Store virial
        vector_force = grads[1]
        if vector_force is None:
            # condition needed to unwrap optional for torchscript
            assert False, "failed to compute vector_force autograd"
        edge_virial = torch.einsum(
            "zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY]
        )
        atom_virial = scatter(
            edge_virial,
            data[AtomicDataDict.EDGE_INDEX_KEY][0],
            dim=0,
            reduce="sum",
            dim_size=len(pos),
        )
        # Each edge's virial is split equally between its two endpoint atoms.
        # This choice is unverified and needs more tests.
        atom_virial = (
            atom_virial
            + scatter(
                edge_virial,
                data[AtomicDataDict.EDGE_INDEX_KEY][1],
                dim=0,
                reduce="sum",
                dim_size=len(pos),
            )
        ) / 2

        virial = scatter(atom_virial, batch, dim=0, reduce="sum")
        virial = (virial + virial.transpose(-1, -2)) / 2  # symmetric
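To illustrate the per-edge to per-atom accumulation above, a minimal self-contained sketch on toy tensors (all names and sizes here are made up for the example):

import torch
from torch_runstats.scatter import scatter

# Toy graph: 3 atoms, 4 directed edges stored as (source, target) rows.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
num_atoms = 3
vector_force = torch.randn(4, 3)  # stand-in for dE/d(edge_vector), one row per edge
edge_vec = torch.randn(4, 3)      # stand-in for the edge vectors

# Per-edge virial contribution: outer product of force and edge vector.
edge_virial = torch.einsum("zi,zj->zij", vector_force, edge_vec)

# Split each edge's contribution equally between its two endpoint atoms,
# exactly as the two scatter-sums divided by 2 do above.
atom_virial = (
    scatter(edge_virial, edge_index[0], dim=0, reduce="sum", dim_size=num_atoms)
    + scatter(edge_virial, edge_index[1], dim=0, reduce="sum", dim_size=num_atoms)
) / 2

# The per-atom virials sum back to the total over all edges.
assert torch.allclose(atom_virial.sum(dim=0), edge_virial.sum(dim=0))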

        # we only compute the stress (1/V * virial) if we have a cell whose volume we can compute
        if has_cell:
            # ^ can only scale by cell volume if we have one...:
            # Rescale stress tensor
            # See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180
            # First dim is batch, second is vec, third is xyz
            volume = torch.einsum(
                "zi,zi->z",
                cell[:, 0, :],
                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
            ).unsqueeze(-1)
            stress = virial / volume.view(num_batch, 1, 1)
            data[AtomicDataDict.CELL_KEY] = orig_cell
Collaborator (suggested change): remove `data[AtomicDataDict.CELL_KEY] = orig_cell`.
        else:
            stress = self._empty  # torchscript
        data[AtomicDataDict.STRESS_KEY] = stress
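        # For intuition: the einsum above is the scalar triple product
        # a . (b x c) of the cell rows, i.e. the cell volume; an
        # orthorhombic cell diag(2, 3, 4) would give volume 24.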

        # see discussion in https://github.com/libAtoms/QUIP/issues/227 about sign convention
        # they say the standard convention is virial = -stress x volume
        # looking above this means that we need to pick up another negative sign for the virial
        # to fit this equation with the stress computed above
        virial = torch.neg(virial)
Collaborator: Would need to confirm correctness of the sign convention here...

Author: Agreed. I didn't check it; it's just inherited from StressForceOutput.
        atom_virial = torch.neg(atom_virial)
        data[AtomicDataDict.VIRIAL_KEY] = virial
        data[AtomicDataDict.ATOM_VIRIAL_KEY] = atom_virial
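        # Consistency note: `stress` was computed as virial / volume *before*
        # the negation above, so the stored tensors satisfy
        # virial == -stress * volume, matching the QUIP convention cited above.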

        if not did_pos_req_grad:
            # don't give later modules one that does
            pos.requires_grad_(False)
Collaborator: Should also reset requires_grad on the edge vectors.

        return data
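A hedged sketch of the reviewer's suggestion above (the exact call and placement are assumptions, not part of this PR): before returning, the freshly created edge vectors could be detached so that later modules don't receive a tensor that still requires grad, mirroring the reset of `pos`:

data[AtomicDataDict.EDGE_VECTORS_KEY] = data[AtomicDataDict.EDGE_VECTORS_KEY].detach()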
9 changes: 7 additions & 2 deletions nequip/utils/test.py
@@ -209,7 +209,11 @@ def assert_AtomicData_equivariant(
        # must be this to actually rotate it when flattened
        irps[AtomicDataDict.CELL_KEY] = "3x1o"

    stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY)
    stress_keys = (
        AtomicDataDict.STRESS_KEY,
        AtomicDataDict.VIRIAL_KEY,
        AtomicDataDict.ATOM_VIRIAL_KEY,
    )
    for k in stress_keys:
        irreps_in.pop(k, None)
    if any(k in irreps_out for k in stress_keys):
@@ -219,7 +223,8 @@
        stress_rtp = stress_cart_tensor.reduced_tensor_products().to(device, dtype)
        # symmetric 3x3 cartesian tensor as irreps
        for k in stress_keys:
            irreps_out[k] = stress_cart_tensor
            if k in irreps_out:
                irreps_out[k] = stress_cart_tensor

    def wrapper(*args):
        arg_dict = {k: v for k, v in zip(irreps_in, args)}
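For reference, a standalone sketch (independent of the test harness) of how a symmetric rank-2 Cartesian tensor such as the stress or virial decomposes into irreps in e3nn:

from e3nn.io import CartesianTensor

# "ij=ji" declares the rank-2 tensor symmetric; it decomposes as 0e + 2e.
stress_cart_tensor = CartesianTensor("ij=ji")
print(stress_cart_tensor)  # 1x0e+1x2e
rtp = stress_cart_tensor.reduced_tensor_products()  # as used in the test above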
42 changes: 41 additions & 1 deletion nequip/utils/unittests/conftest.py
@@ -6,7 +6,9 @@
import os

from ase.atoms import Atoms
from ase.build import molecule, bulk

from ase.build import molecule, bulk, make_supercell

from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import write

@@ -162,6 +164,44 @@ def atomic_batch(nequip_dataset):
    return Batch.from_data_list([nequip_dataset[0], nequip_dataset[1]])


@pytest.fixture(scope="session")
def bulks() -> List[Atoms]:
    atoms_list = []
    for i in range(8):
        atoms = bulk("C")
        atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])
        atoms.rattle()
        atoms.calc = SinglePointCalculator(
            energy=np.random.random(),
            forces=np.random.random((len(atoms), 3)),
            stress=np.random.random(6),
            magmoms=None,
            atoms=atoms,
        )
Comment on lines +174 to +180

Collaborator: This fake data is never used, right?

Author: Right. This part was just inherited from the molecule fixtures; it can simply be deleted.
        atoms_list.append(atoms)
    return atoms_list


@pytest.fixture(scope="session")
def nequip_bulk_dataset(bulks, temp_data, float_tolerance):
    with tempfile.NamedTemporaryFile(suffix=".xyz") as fp:
        for atoms in bulks:
            write(fp.name, atoms, format="extxyz", append=True)
        a = ASEDataset(
            file_name=fp.name,
            root=temp_data,
            AtomicData_options={"r_max": 5.0},
            ase_args=dict(format="extxyz"),
            type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}),
        )
        yield a


@pytest.fixture(scope="session")
def atomic_bulk_batch(nequip_bulk_dataset):
    return Batch.from_data_list([nequip_bulk_dataset[0], nequip_bulk_dataset[1]])


@pytest.fixture(scope="function")
def per_species_set():
    dtype = torch.get_default_dtype()