-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Atomic virial and stress calculation to support parallel virial computation for Allegro in LAMMPS #281
base: develop
Are you sure you want to change the base?
Changes from all commits
dcab8be
6d29b51
e1c98cd
074bf42
fa60cb7
4f5e655
1b7a103
60a41ba
6c31b19
e96dd7e
756f4c6
bb10cc6
ba3379f
9719567
d4a80dd
6638451
8d283bf
fbd560c
8dd2e03
b502de8
22be5e7
ddba5e4
56aa691
d591b6e
9701437
235bfe0
9c18611
6ca00ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
name: Run Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- para_stress | ||
|
||
pull_request: | ||
branches: | ||
- para_stress | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: [3.9] | ||
torch-version: [1.11.0, 1.13.1] | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
env: | ||
TORCH: "${{ matrix.torch-version }}" | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel | ||
pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html | ||
pip install h5py | ||
pip install --upgrade-strategy only-if-needed . | ||
- name: Install pytest | ||
run: | | ||
pip install pytest | ||
pip install pytest-xdist[psutil] | ||
- name: Download test data | ||
run: | | ||
mkdir benchmark_data | ||
cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. | ||
- name: Test with pytest | ||
run: | | ||
# See https://github.com/pytest-dev/pytest/issues/1075 | ||
PYTHONHASHSEED=0 pytest -n auto tests/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,3 +159,5 @@ dmypy.json | |
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
.history |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||
from typing import List, Union, Optional | ||||
|
||||
import torch | ||||
from torch_runstats.scatter import scatter | ||||
|
||||
from e3nn.o3 import Irreps | ||||
from e3nn.util.jit import compile_mode | ||||
|
@@ -356,3 +357,157 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: | |||
pos.requires_grad_(False) | ||||
|
||||
return data | ||||
|
||||
|
||||
@compile_mode("script") | ||||
class ParaStressOutput(GraphModuleMixin, torch.nn.Module): | ||||
r"""Compute stress, atomic virial, forces using autograd of an energy model. | ||||
Edge-vector based atomic virial. | ||||
Design for Lammps parallism. | ||||
|
||||
See: | ||||
H. Yu, Y. Zhong, J. Ji, X. Gong, and H. Xiang, Time-Reversal Equivariant Neural Network Potential and Hamiltonian for Magnetic Materials, arXiv:2211.11403. | ||||
https://arxiv.org/abs/2211.11403 | ||||
|
||||
Knuth et. al. Comput. Phys. Commun 190, 33-50, 2015 | ||||
https://pure.mpg.de/rest/items/item_2085135_9/component/file_2156800/content | ||||
|
||||
Args: | ||||
func: the energy model to wrap | ||||
do_forces: whether to compute forces as well | ||||
""" | ||||
do_forces: bool | ||||
|
||||
def __init__( | ||||
self, | ||||
func: GraphModuleMixin, | ||||
do_forces: bool = True, | ||||
): | ||||
super().__init__() | ||||
|
||||
if not do_forces: | ||||
raise NotImplementedError | ||||
self.do_forces = do_forces | ||||
|
||||
self.func = func | ||||
|
||||
# check and init irreps | ||||
self._init_irreps( | ||||
irreps_in=self.func.irreps_in.copy(), | ||||
irreps_out=self.func.irreps_out.copy(), | ||||
) | ||||
self.irreps_out[AtomicDataDict.FORCE_KEY] = "1o" | ||||
self.irreps_out[AtomicDataDict.STRESS_KEY] = "1o" | ||||
self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "1o" | ||||
self.irreps_out[AtomicDataDict.ATOM_VIRIAL_KEY] = "1o" | ||||
|
||||
# for torchscript compat | ||||
self.register_buffer("_empty", torch.Tensor()) | ||||
|
||||
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: | ||||
assert AtomicDataDict.EDGE_VECTORS_KEY not in data | ||||
|
||||
if AtomicDataDict.BATCH_KEY in data: | ||||
batch = data[AtomicDataDict.BATCH_KEY] | ||||
num_batch: int = len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1 | ||||
else: | ||||
data = AtomicDataDict.with_batch(data) | ||||
# Special case for efficiency | ||||
batch = data[AtomicDataDict.BATCH_KEY] | ||||
num_batch: int = 1 | ||||
|
||||
pos = data[AtomicDataDict.POSITIONS_KEY] | ||||
|
||||
has_cell: bool = AtomicDataDict.CELL_KEY in data | ||||
|
||||
if has_cell: | ||||
orig_cell = data[AtomicDataDict.CELL_KEY] | ||||
# Make the cell per-batch | ||||
cell = orig_cell.view(-1, 3, 3).expand(num_batch, 3, 3) | ||||
data[AtomicDataDict.CELL_KEY] = cell | ||||
else: | ||||
# torchscript | ||||
orig_cell = self._empty | ||||
cell = self._empty | ||||
|
||||
did_pos_req_grad: bool = pos.requires_grad | ||||
pos.requires_grad_(True) | ||||
data[AtomicDataDict.POSITIONS_KEY] = pos | ||||
data = AtomicDataDict.with_edge_vectors(data, with_lengths=False) | ||||
data[AtomicDataDict.EDGE_VECTORS_KEY].requires_grad_(True) | ||||
Comment on lines
+436
to
+437
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||||
|
||||
# Call model and get gradients | ||||
data = self.func(data) | ||||
|
||||
grads = torch.autograd.grad( | ||||
[data[AtomicDataDict.TOTAL_ENERGY_KEY].sum()], | ||||
[pos, data[AtomicDataDict.EDGE_VECTORS_KEY]], | ||||
create_graph=self.training, # needed to allow gradients of this output during training | ||||
) | ||||
|
||||
# Put negative sign on forces | ||||
forces = grads[0] | ||||
if forces is None: | ||||
# condition needed to unwrap optional for torchscript | ||||
assert False, "failed to compute forces autograd" | ||||
forces = torch.neg(forces) | ||||
data[AtomicDataDict.FORCE_KEY] = forces | ||||
|
||||
# Store virial | ||||
vector_force = grads[1] | ||||
if vector_force is None: | ||||
# condition needed to unwrap optional for torchscript | ||||
assert False, "failed to compute vector_force autograd" | ||||
edge_virial = torch.einsum( | ||||
"zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] | ||||
) | ||||
atom_virial = scatter( | ||||
edge_virial, | ||||
data[AtomicDataDict.EDGE_INDEX_KEY][0], | ||||
dim=0, | ||||
reduce="sum", | ||||
dim_size=len(pos), | ||||
) | ||||
# edge_virial is distributed into two nodes equally. Not sure and need more tests | ||||
atom_virial = (atom_virial + scatter( | ||||
edge_virial, | ||||
data[AtomicDataDict.EDGE_INDEX_KEY][1], | ||||
dim=0, | ||||
reduce="sum", | ||||
dim_size=len(pos), | ||||
))/2 | ||||
|
||||
virial = scatter(atom_virial, batch, dim=0, reduce="sum") | ||||
virial = (virial + virial.transpose(-1, -2)) / 2 # symmetric | ||||
|
||||
# we only compute the stress (1/V * virial) if we have a cell whose volume we can compute | ||||
if has_cell: | ||||
# ^ can only scale by cell volume if we have one...: | ||||
# Rescale stress tensor | ||||
# See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180 | ||||
# First dim is batch, second is vec, third is xyz | ||||
volume = torch.einsum( | ||||
"zi,zi->z", | ||||
cell[:, 0, :], | ||||
torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), | ||||
).unsqueeze(-1) | ||||
stress = virial / volume.view(num_batch, 1, 1) | ||||
data[AtomicDataDict.CELL_KEY] = orig_cell | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
else: | ||||
stress = self._empty # torchscript | ||||
data[AtomicDataDict.STRESS_KEY] = stress | ||||
|
||||
# see discussion in https://github.com/libAtoms/QUIP/issues/227 about sign convention | ||||
# they say the standard convention is virial = -stress x volume | ||||
# looking above this means that we need to pick up another negative sign for the virial | ||||
# to fit this equation with the stress computed above | ||||
virial = torch.neg(virial) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would need to confirm correctness of sign convention here... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. I didn't check it and just inherit it from |
||||
atom_virial = torch.neg(atom_virial) | ||||
data[AtomicDataDict.VIRIAL_KEY] = virial | ||||
data[AtomicDataDict.ATOM_VIRIAL_KEY] = atom_virial | ||||
|
||||
if not did_pos_req_grad: | ||||
# don't give later modules one that does | ||||
pos.requires_grad_(False) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should also reset |
||||
|
||||
return data |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,9 @@ | |
import os | ||
|
||
from ase.atoms import Atoms | ||
from ase.build import molecule, bulk | ||
|
||
from ase.build import molecule, bulk, make_supercell | ||
|
||
from ase.calculators.singlepoint import SinglePointCalculator | ||
from ase.io import write | ||
|
||
|
@@ -162,6 +164,44 @@ def atomic_batch(nequip_dataset): | |
return Batch.from_data_list([nequip_dataset[0], nequip_dataset[1]]) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def bulks() -> List[Atoms]: | ||
atoms_list = [] | ||
for i in range(8): | ||
atoms = bulk("C") | ||
atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]]) | ||
atoms.rattle() | ||
atoms.calc = SinglePointCalculator( | ||
energy=np.random.random(), | ||
forces=np.random.random((len(atoms), 3)), | ||
stress=np.random.random(6), | ||
magmoms=None, | ||
atoms=atoms, | ||
) | ||
Comment on lines
+174
to
+180
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This fake data is never used, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. This part is just inherited from molecules. Can just delete them. |
||
atoms_list.append(atoms) | ||
return atoms_list | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def nequip_bulk_dataset(bulks, temp_data, float_tolerance): | ||
with tempfile.NamedTemporaryFile(suffix=".xyz") as fp: | ||
for atoms in bulks: | ||
write(fp.name, atoms, format="extxyz", append=True) | ||
a = ASEDataset( | ||
file_name=fp.name, | ||
root=temp_data, | ||
AtomicData_options={"r_max": 5.0}, | ||
ase_args=dict(format="extxyz"), | ||
type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), | ||
) | ||
yield a | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def atomic_bulk_batch(nequip_bulk_dataset): | ||
return Batch.from_data_list([nequip_bulk_dataset[0], nequip_bulk_dataset[1]]) | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def per_species_set(): | ||
dtype = torch.get_default_dtype() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This business is all unnecessary now, right? Since there is no derivative w.r.t. the cell, so the cell doesn't need to be made per-batch-entry
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Cell here is only used to get the stress and for the training with stress.