From dcab8beb2a94a9bf3430117029b0bda018ca2beb Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 01:42:43 +0800 Subject: [PATCH 01/16] ParaStressOutput Implemented for parallelism of virial calculation in Lammps. --- nequip/data/AtomicData.py | 1 + nequip/data/_keys.py | 2 + nequip/model/__init__.py | 3 +- nequip/model/_grads.py | 17 ++++ nequip/nn/__init__.py | 2 +- nequip/nn/_grad_output.py | 133 ++++++++++++++++++++++++++ nequip/utils/test.py | 2 +- nequip/utils/unittests/conftest.py | 40 +++++++- tests/unit/model/test_nequip_model.py | 37 ++++++- 9 files changed, 232 insertions(+), 5 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 728c260b..97d076b1 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -41,6 +41,7 @@ AtomicDataDict.FORCE_KEY, AtomicDataDict.PER_ATOM_ENERGY_KEY, AtomicDataDict.BATCH_KEY, + AtomicDataDict.ATOM_VIRIAL_KEY, } _DEFAULT_EDGE_FIELDS: Set[str] = { AtomicDataDict.EDGE_CELL_SHIFT_KEY, diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 54b66ce3..85e5f344 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -55,6 +55,7 @@ PARTIAL_FORCE_KEY: Final[str] = "partial_forces" STRESS_KEY: Final[str] = "stress" VIRIAL_KEY: Final[str] = "virial" +ATOM_VIRIAL_KEY: Final[str] = "atom_virial" ALL_ENERGY_KEYS: Final[List[str]] = [ PER_ATOM_ENERGY_KEY, @@ -63,6 +64,7 @@ PARTIAL_FORCE_KEY, STRESS_KEY, VIRIAL_KEY, + ATOM_VIRIAL_KEY, ] BATCH_KEY: Final[str] = "batch" diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b79a820c..e076f201 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,5 +1,5 @@ from ._eng import EnergyModel, SimpleIrrepsConfig -from ._grads import ForceOutput, PartialForceOutput, StressForceOutput +from ._grads import ForceOutput, PartialForceOutput, StressForceOutput, ParaStressForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale from ._weight_init import ( uniform_initialize_FCs, @@ -17,6 +17,7 @@ ForceOutput, PartialForceOutput, StressForceOutput, + ParaStressForceOutput, RescaleEnergyEtc, PerSpeciesRescale, uniform_initialize_FCs, diff --git a/nequip/model/_grads.py b/nequip/model/_grads.py index b382d0b4..fa583224 100644 --- a/nequip/model/_grads.py +++ b/nequip/model/_grads.py @@ -1,6 +1,7 @@ from nequip.nn import GraphModuleMixin, GradientOutput from nequip.nn import PartialForceOutput as PartialForceOutputModule from nequip.nn import StressOutput as StressOutputModule +from nequip.nn import ParaStressOutput as ParaStressOutputModule from nequip.data import AtomicDataDict @@ -56,3 +57,19 @@ def StressForceOutput(model: GraphModuleMixin) -> GradientOutput: ): raise ValueError("This model already has force or stress outputs.") return StressOutputModule(func=model) + +def ParaStressForceOutput(model: GraphModuleMixin) -> GradientOutput: + r"""Add forces and stresses to a model that predicts energy. + + Args: + model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output. + + Returns: + A ``StressOutput`` wrapping ``model``. + """ + if ( + AtomicDataDict.FORCE_KEY in model.irreps_out + or AtomicDataDict.STRESS_KEY in model.irreps_out + ): + raise ValueError("This model already has force or stress outputs.") + return ParaStressOutputModule(func=model) \ No newline at end of file diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 10cebee6..a9633544 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,7 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput, PartialForceOutput, StressOutput # noqa: F401 +from ._grad_output import GradientOutput, PartialForceOutput, StressOutput, ParaStressOutput # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 from ._util import SaveForOutput # noqa: F401 diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 673f8ff0..a690ea08 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -2,6 +2,7 @@ import warnings import torch +from torch_runstats.scatter import scatter from e3nn.o3 import Irreps from e3nn.util.jit import compile_mode @@ -334,3 +335,135 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: pos.requires_grad_(False) return data + + +@compile_mode("script") +class ParaStressOutput(GraphModuleMixin, torch.nn.Module): + r"""Compute stress, atomic virial, forces using autograd of an energy model. + Design for Lammps parallism. + + See: + Knuth et. al. Comput. Phys. Commun 190, 33-50, 2015 + https://pure.mpg.de/rest/items/item_2085135_9/component/file_2156800/content + + Args: + func: the energy model to wrap + do_forces: whether to compute forces as well + """ + do_forces: bool + + def __init__( + self, + func: GraphModuleMixin, + do_forces: bool = True, + ): + super().__init__() + + warnings.warn( + "!! Stresses in NequIP are in BETA and UNDER DEVELOPMENT: _please_ carefully check the sanity of your results and report any (potential) issues on the GitHub" + ) + + if not do_forces: + raise NotImplementedError + self.do_forces = do_forces + + self.func = func + + # check and init irreps + self._init_irreps( + irreps_in=self.func.irreps_in.copy(), + irreps_out=self.func.irreps_out.copy(), + ) + self.irreps_out[AtomicDataDict.FORCE_KEY] = "1o" + self.irreps_out[AtomicDataDict.STRESS_KEY] = "3x1o" + self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "3x1o" + self.irreps_out[AtomicDataDict.ATOM_VIRIAL_KEY] = "3x1o" + + # for torchscript compat + self.register_buffer("_empty", torch.Tensor()) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = AtomicDataDict.with_batch(data) + + batch = data[AtomicDataDict.BATCH_KEY] + num_batch: int = int(batch.max().cpu().item()) + 1 + pos = data[AtomicDataDict.POSITIONS_KEY] + + has_cell: bool = AtomicDataDict.CELL_KEY in data + + if has_cell: + orig_cell = data[AtomicDataDict.CELL_KEY] + # Make the cell per-batch + cell = orig_cell.view(-1, 3, 3).expand(num_batch, 3, 3) + data[AtomicDataDict.CELL_KEY] = cell + else: + # torchscript + orig_cell = self._empty + cell = self._empty + + did_pos_req_grad: bool = pos.requires_grad + pos.requires_grad_(True) + data[AtomicDataDict.POSITIONS_KEY] = pos + data = AtomicDataDict.with_edge_vectors(data, with_lengths=False) + data[AtomicDataDict.EDGE_VECTORS_KEY].requires_grad_(True) + + # Call model and get gradients + data = self.func(data) + + grads = torch.autograd.grad( + [data[AtomicDataDict.TOTAL_ENERGY_KEY].sum()], + [pos, data[AtomicDataDict.EDGE_VECTORS_KEY]], + create_graph=self.training, # needed to allow gradients of this output during training + ) + + # Put negative sign on forces + forces = grads[0] + if forces is None: + # condition needed to unwrap optional for torchscript + assert False, "failed to compute forces autograd" + forces = torch.neg(forces) + data[AtomicDataDict.FORCE_KEY] = forces + + # Store virial + vector_force = grads[1] + edge_virial_1 = torch.einsum("zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY]) + edge_virial_2 = torch.einsum("zi,zj->zji", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY]) + edge_virial = (edge_virial_1 + edge_virial_2)/2 + atom_virial = scatter(edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum") + virial = scatter(atom_virial, batch, dim=0, reduce="sum") + + # virial = grads[1] + if virial is None: + # condition needed to unwrap optional for torchscript + assert False, "failed to compute virial autograd" + + # we only compute the stress (1/V * virial) if we have a cell whose volume we can compute + if has_cell: + # ^ can only scale by cell volume if we have one...: + # Rescale stress tensor + # See https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/atomistic/output_modules.py#L180 + # First dim is batch, second is vec, third is xyz + volume = torch.einsum( + "zi,zi->z", + cell[:, 0, :], + torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), + ).unsqueeze(-1) + stress = virial / volume.view(-1, 1, 1) + data[AtomicDataDict.CELL_KEY] = orig_cell + else: + stress = self._empty # torchscript + data[AtomicDataDict.STRESS_KEY] = stress + + # see discussion in https://github.com/libAtoms/QUIP/issues/227 about sign convention + # they say the standard convention is virial = -stress x volume + # looking above this means that we need to pick up another negative sign for the virial + # to fit this equation with the stress computed above + virial = torch.neg(virial) + data[AtomicDataDict.VIRIAL_KEY] = virial + data[AtomicDataDict.ATOM_VIRIAL_KEY] = atom_virial + + if not did_pos_req_grad: + # don't give later modules one that does + pos.requires_grad_(False) + + return data diff --git a/nequip/utils/test.py b/nequip/utils/test.py index edf2c1e8..fff68ca8 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -194,7 +194,7 @@ def assert_AtomicData_equivariant( # must be this to actually rotate it irps[AtomicDataDict.CELL_KEY] = "3x1o" - stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY) + stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY, AtomicDataDict.ATOM_VIRIAL_KEY) for k in stress_keys: irreps_in.pop(k, None) if any(k in irreps_out for k in stress_keys): diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index 77a91930..b3a4036f 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -6,7 +6,7 @@ import os from ase.atoms import Atoms -from ase.build import molecule +from ase.build import molecule, bulk, make_supercell from ase.calculators.singlepoint import SinglePointCalculator from ase.io import write @@ -133,6 +133,44 @@ def atomic_batch(nequip_dataset): return Batch.from_data_list([nequip_dataset[0], nequip_dataset[1]]) +@pytest.fixture(scope="session") +def bulks() -> List[Atoms]: + atoms_list = [] + for i in range(8): + atoms = bulk("C") + atoms = make_supercell(atoms,[[2,0,0],[0,2,0],[0,0,2]]) + atoms.rattle() + atoms.calc = SinglePointCalculator( + energy=np.random.random(), + forces=np.random.random((len(atoms), 3)), + stress=np.random.random(6), + magmoms=None, + atoms=atoms, + ) + atoms_list.append(atoms) + return atoms_list + + +@pytest.fixture(scope="session") +def nequip_bulk_dataset(bulks, temp_data, float_tolerance): + with tempfile.NamedTemporaryFile(suffix=".xyz") as fp: + for atoms in bulks: + write(fp.name, atoms, format="extxyz", append=True) + a = ASEDataset( + file_name=fp.name, + root=temp_data, + extra_fixed_fields={"r_max": 5.0}, + ase_args=dict(format="extxyz"), + type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), + ) + yield a + +@pytest.fixture(scope="session") +def atomic_bulk_batch(nequip_bulk_dataset): + return Batch.from_data_list([nequip_bulk_dataset[0], nequip_bulk_dataset[1]]) + + + @pytest.fixture(scope="function") def per_species_set(): dtype = torch.get_default_dtype() diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 2aa82e15..4a1df22e 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -1,8 +1,9 @@ import pytest +import torch from e3nn import o3 -from nequip.data import AtomicDataDict +from nequip.data import AtomicDataDict, AtomicData from nequip.model import model_from_config from nequip.nn import AtomwiseLinear from nequip.utils.unittests.model_tests import BaseEnergyModelTests @@ -95,6 +96,17 @@ def base_config(self, request): AtomicDataDict.VIRIAL_KEY, ], ), + ( + ["EnergyModel", "ParaStressForceOutput"], + [ + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, + AtomicDataDict.ATOM_VIRIAL_KEY, + ], + ), ], scope="class", ) @@ -120,3 +132,26 @@ def test_submods(self): model.layer0_convnet.irreps_in[model.chemical_embedding.out_field] == true_irreps ) + + def test_stress(self, atomic_bulk_batch, device): + config_og = minimal_config2.copy() + config_og["model_builders"] = ["EnergyModel", "StressForceOutput"] + model_og = model_from_config(config=config_og, initialize=True) + nn_state = model_og.state_dict() + + config_para = minimal_config2.copy() + config_para["model_builders"] = ["EnergyModel", "ParaStressForceOutput"] + model_para = model_from_config(config=config_para, initialize=True) + + model_para.load_state_dict(nn_state, strict = True) + + model_og.to(device) + model_para.to(device) + data = atomic_bulk_batch.to(device) + + output_og = model_og(AtomicData.to_AtomicDataDict(data)) + output_para = model_para(AtomicData.to_AtomicDataDict(data)) + + assert torch.allclose(output_og[AtomicDataDict.STRESS_KEY], output_para[AtomicDataDict.STRESS_KEY], atol=5e-6) + assert torch.allclose(output_og[AtomicDataDict.VIRIAL_KEY], output_para[AtomicDataDict.VIRIAL_KEY], atol=5e-6) + From 6d29b51a4c34559900cb67440400234d25b33119 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 01:54:47 +0800 Subject: [PATCH 02/16] CHANGELOG.md and test tolerance update. --- CHANGELOG.md | 1 + tests/unit/model/test_nequip_model.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a7bb63c..67dabd56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Most recent change on the bottom. ## [Unreleased] - 0.5.6 ### Added +- `ParaStressForceOutput` Support for atom virial and virial parallism in Lammps. - sklearn dependency removed - `nequip-benchmark` and `nequip-train` report number of weights and number of trainable weights - `nequip-benchmark --no-compile` and `--verbose` and `--memory-summary` diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 4a1df22e..036aa386 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -152,6 +152,6 @@ def test_stress(self, atomic_bulk_batch, device): output_og = model_og(AtomicData.to_AtomicDataDict(data)) output_para = model_para(AtomicData.to_AtomicDataDict(data)) - assert torch.allclose(output_og[AtomicDataDict.STRESS_KEY], output_para[AtomicDataDict.STRESS_KEY], atol=5e-6) - assert torch.allclose(output_og[AtomicDataDict.VIRIAL_KEY], output_para[AtomicDataDict.VIRIAL_KEY], atol=5e-6) + assert torch.allclose(output_og[AtomicDataDict.STRESS_KEY], output_para[AtomicDataDict.STRESS_KEY], atol=1e-6) + assert torch.allclose(output_og[AtomicDataDict.VIRIAL_KEY], output_para[AtomicDataDict.VIRIAL_KEY], atol=1e-5) # Little big here caused by the summation over edges. From e1c98cd60ff8b02ff998e55fcf8f99d9ee0dbd5c Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 02:00:42 +0800 Subject: [PATCH 03/16] black format --- nequip/model/__init__.py | 7 ++++++- nequip/model/_grads.py | 3 ++- nequip/nn/__init__.py | 7 ++++++- nequip/nn/_grad_output.py | 18 ++++++++++++------ nequip/utils/test.py | 6 +++++- nequip/utils/unittests/conftest.py | 4 ++-- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index e076f201..f9c95b78 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,5 +1,10 @@ from ._eng import EnergyModel, SimpleIrrepsConfig -from ._grads import ForceOutput, PartialForceOutput, StressForceOutput, ParaStressForceOutput +from ._grads import ( + ForceOutput, + PartialForceOutput, + StressForceOutput, + ParaStressForceOutput, +) from ._scaling import RescaleEnergyEtc, PerSpeciesRescale from ._weight_init import ( uniform_initialize_FCs, diff --git a/nequip/model/_grads.py b/nequip/model/_grads.py index fa583224..5d860bee 100644 --- a/nequip/model/_grads.py +++ b/nequip/model/_grads.py @@ -58,6 +58,7 @@ def StressForceOutput(model: GraphModuleMixin) -> GradientOutput: raise ValueError("This model already has force or stress outputs.") return StressOutputModule(func=model) + def ParaStressForceOutput(model: GraphModuleMixin) -> GradientOutput: r"""Add forces and stresses to a model that predicts energy. @@ -72,4 +73,4 @@ def ParaStressForceOutput(model: GraphModuleMixin) -> GradientOutput: or AtomicDataDict.STRESS_KEY in model.irreps_out ): raise ValueError("This model already has force or stress outputs.") - return ParaStressOutputModule(func=model) \ No newline at end of file + return ParaStressOutputModule(func=model) diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index a9633544..4a2aa43f 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,12 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput, PartialForceOutput, StressOutput, ParaStressOutput # noqa: F401 +from ._grad_output import ( + GradientOutput, + PartialForceOutput, + StressOutput, + ParaStressOutput, +) # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 from ._util import SaveForOutput # noqa: F401 diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index a690ea08..19839525 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -403,7 +403,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: did_pos_req_grad: bool = pos.requires_grad pos.requires_grad_(True) - data[AtomicDataDict.POSITIONS_KEY] = pos + data[AtomicDataDict.POSITIONS_KEY] = pos data = AtomicDataDict.with_edge_vectors(data, with_lengths=False) data[AtomicDataDict.EDGE_VECTORS_KEY].requires_grad_(True) @@ -423,13 +423,19 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: assert False, "failed to compute forces autograd" forces = torch.neg(forces) data[AtomicDataDict.FORCE_KEY] = forces - + # Store virial vector_force = grads[1] - edge_virial_1 = torch.einsum("zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY]) - edge_virial_2 = torch.einsum("zi,zj->zji", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY]) - edge_virial = (edge_virial_1 + edge_virial_2)/2 - atom_virial = scatter(edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum") + edge_virial_1 = torch.einsum( + "zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] + ) + edge_virial_2 = torch.einsum( + "zi,zj->zji", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] + ) + edge_virial = (edge_virial_1 + edge_virial_2) / 2 + atom_virial = scatter( + edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum" + ) virial = scatter(atom_virial, batch, dim=0, reduce="sum") # virial = grads[1] diff --git a/nequip/utils/test.py b/nequip/utils/test.py index fff68ca8..b92526f9 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -194,7 +194,11 @@ def assert_AtomicData_equivariant( # must be this to actually rotate it irps[AtomicDataDict.CELL_KEY] = "3x1o" - stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY, AtomicDataDict.ATOM_VIRIAL_KEY) + stress_keys = ( + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, + AtomicDataDict.ATOM_VIRIAL_KEY, + ) for k in stress_keys: irreps_in.pop(k, None) if any(k in irreps_out for k in stress_keys): diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index b3a4036f..a2928cee 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -138,7 +138,7 @@ def bulks() -> List[Atoms]: atoms_list = [] for i in range(8): atoms = bulk("C") - atoms = make_supercell(atoms,[[2,0,0],[0,2,0],[0,0,2]]) + atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]]) atoms.rattle() atoms.calc = SinglePointCalculator( energy=np.random.random(), @@ -165,12 +165,12 @@ def nequip_bulk_dataset(bulks, temp_data, float_tolerance): ) yield a + @pytest.fixture(scope="session") def atomic_bulk_batch(nequip_bulk_dataset): return Batch.from_data_list([nequip_bulk_dataset[0], nequip_bulk_dataset[1]]) - @pytest.fixture(scope="function") def per_species_set(): dtype = torch.get_default_dtype() From 074bf422625b6700ec3df9cccbf645cfee717c18 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 02:01:05 +0800 Subject: [PATCH 04/16] black format --- tests/unit/model/test_nequip_model.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 036aa386..511fdad4 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -138,20 +138,27 @@ def test_stress(self, atomic_bulk_batch, device): config_og["model_builders"] = ["EnergyModel", "StressForceOutput"] model_og = model_from_config(config=config_og, initialize=True) nn_state = model_og.state_dict() - + config_para = minimal_config2.copy() config_para["model_builders"] = ["EnergyModel", "ParaStressForceOutput"] model_para = model_from_config(config=config_para, initialize=True) - - model_para.load_state_dict(nn_state, strict = True) - + + model_para.load_state_dict(nn_state, strict=True) + model_og.to(device) model_para.to(device) data = atomic_bulk_batch.to(device) - + output_og = model_og(AtomicData.to_AtomicDataDict(data)) output_para = model_para(AtomicData.to_AtomicDataDict(data)) - - assert torch.allclose(output_og[AtomicDataDict.STRESS_KEY], output_para[AtomicDataDict.STRESS_KEY], atol=1e-6) - assert torch.allclose(output_og[AtomicDataDict.VIRIAL_KEY], output_para[AtomicDataDict.VIRIAL_KEY], atol=1e-5) # Little big here caused by the summation over edges. + assert torch.allclose( + output_og[AtomicDataDict.STRESS_KEY], + output_para[AtomicDataDict.STRESS_KEY], + atol=1e-6, + ) + assert torch.allclose( + output_og[AtomicDataDict.VIRIAL_KEY], + output_para[AtomicDataDict.VIRIAL_KEY], + atol=1e-5, + ) # Little big here caused by the summation over edges. From fa60cb7e3f16eb3f6fc046bab7ef0484e3a7746e Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 10:52:54 +0800 Subject: [PATCH 05/16] Update test --- .gitignore | 2 ++ nequip/nn/_grad_output.py | 12 +++++------- nequip/utils/test.py | 3 ++- tests/unit/model/test_nequip_model.py | 1 + 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index c183223e..c64954e1 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.history \ No newline at end of file diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 19839525..84b11040 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -426,19 +426,18 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # Store virial vector_force = grads[1] - edge_virial_1 = torch.einsum( + if vector_force is None: + # condition needed to unwrap optional for torchscript + assert False, "failed to compute vector_force autograd" + edge_virial = torch.einsum( "zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] ) - edge_virial_2 = torch.einsum( - "zi,zj->zji", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] - ) - edge_virial = (edge_virial_1 + edge_virial_2) / 2 + edge_virial = (edge_virial + edge_virial.transpose(-1, -2)) / 2 # symmetric atom_virial = scatter( edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum" ) virial = scatter(atom_virial, batch, dim=0, reduce="sum") - # virial = grads[1] if virial is None: # condition needed to unwrap optional for torchscript assert False, "failed to compute virial autograd" @@ -455,7 +454,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) stress = virial / volume.view(-1, 1, 1) - data[AtomicDataDict.CELL_KEY] = orig_cell else: stress = self._empty # torchscript data[AtomicDataDict.STRESS_KEY] = stress diff --git a/nequip/utils/test.py b/nequip/utils/test.py index b92526f9..bd6507ad 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -208,7 +208,8 @@ def assert_AtomicData_equivariant( stress_rtp = stress_cart_tensor.reduced_tensor_products().to(device, dtype) # symmetric 3x3 cartesian tensor as irreps for k in stress_keys: - irreps_out[k] = stress_cart_tensor + if k in irreps_out: + irreps_out[k] = stress_cart_tensor def wrapper(*args): arg_dict = {k: v for k, v in zip(irreps_in, args)} diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 511fdad4..218e9eca 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -7,6 +7,7 @@ from nequip.model import model_from_config from nequip.nn import AtomwiseLinear from nequip.utils.unittests.model_tests import BaseEnergyModelTests +from nequip.utils.test import assert_AtomicData_equivariant COMMON_CONFIG = { "avg_num_neighbors": None, From 1b7a10300f9f7706558c79e6ac918b0b178be3a4 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 10:57:16 +0800 Subject: [PATCH 06/16] flake8 --- nequip/nn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 4a2aa43f..f13ed060 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,7 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import ( +from ._grad_output import ( # noqa: F401 GradientOutput, PartialForceOutput, StressOutput, From 6c31b19d5a978735d7a9f4b716c10e1bdc7e656f Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Fri, 16 Dec 2022 12:53:58 +0800 Subject: [PATCH 07/16] fix para bug in lammps --- nequip/nn/_grad_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 84b11040..a51cee06 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -434,7 +434,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: ) edge_virial = (edge_virial + edge_virial.transpose(-1, -2)) / 2 # symmetric atom_virial = scatter( - edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum" + edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum", dim_size=len(pos) ) virial = scatter(atom_virial, batch, dim=0, reduce="sum") From bb10cc6f91ca42fda452d7b22cc314d33c0e3d95 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Sat, 24 Dec 2022 01:53:26 +0800 Subject: [PATCH 08/16] Add reference in ParaStressOutput --- nequip/nn/_grad_output.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index a51cee06..5986ef3c 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -340,9 +340,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: @compile_mode("script") class ParaStressOutput(GraphModuleMixin, torch.nn.Module): r"""Compute stress, atomic virial, forces using autograd of an energy model. + Edge-vector based atomic virial. Design for Lammps parallism. See: + H. Yu, Y. Zhong, J. Ji, X. Gong, and H. Xiang, Time-Reversal Equivariant Neural Network Potential and Hamiltonian for Magnetic Materials, arXiv:2211.11403. + https://arxiv.org/abs/2211.11403 + Knuth et. al. Comput. Phys. Commun 190, 33-50, 2015 https://pure.mpg.de/rest/items/item_2085135_9/component/file_2156800/content From 66384516a609fee515d894edeab5d944d2a6ab67 Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Tue, 24 Jan 2023 08:39:44 +0000 Subject: [PATCH 09/16] Format and update tests --- nequip/nn/__init__.py | 2 +- nequip/nn/_grad_output.py | 8 ++++++-- nequip/utils/unittests/conftest.py | 2 +- tests/unit/model/test_nequip_model.py | 1 - 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index f13ed060..5069a117 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,7 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import ( # noqa: F401 +from ._grad_output import ( # noqa: F401 GradientOutput, PartialForceOutput, StressOutput, diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 5986ef3c..ef9b2069 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -436,9 +436,13 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: edge_virial = torch.einsum( "zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] ) - edge_virial = (edge_virial + edge_virial.transpose(-1, -2)) / 2 # symmetric + edge_virial = (edge_virial + edge_virial.transpose(-1, -2)) / 2 # symmetric atom_virial = scatter( - edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], dim=0, reduce="sum", dim_size=len(pos) + edge_virial, + data[AtomicDataDict.EDGE_INDEX_KEY][0], + dim=0, + reduce="sum", + dim_size=len(pos), ) virial = scatter(atom_virial, batch, dim=0, reduce="sum") diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index f589608c..eddc1c44 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -159,7 +159,7 @@ def nequip_bulk_dataset(bulks, temp_data, float_tolerance): a = ASEDataset( file_name=fp.name, root=temp_data, - extra_fixed_fields={"r_max": 5.0}, + AtomicData_options={"r_max": 5.0}, ase_args=dict(format="extxyz"), type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), ) diff --git a/tests/unit/model/test_nequip_model.py b/tests/unit/model/test_nequip_model.py index 218e9eca..511fdad4 100644 --- a/tests/unit/model/test_nequip_model.py +++ b/tests/unit/model/test_nequip_model.py @@ -7,7 +7,6 @@ from nequip.model import model_from_config from nequip.nn import AtomwiseLinear from nequip.utils.unittests.model_tests import BaseEnergyModelTests -from nequip.utils.test import assert_AtomicData_equivariant COMMON_CONFIG = { "avg_num_neighbors": None, From 8d283bfc3beec17b8de67fb0b190416068f88759 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Mon, 6 Feb 2023 13:08:07 +0800 Subject: [PATCH 10/16] fix atom_virial sign --- nequip/nn/_grad_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index ef9b2069..7c47235f 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -471,6 +471,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # looking above this means that we need to pick up another negative sign for the virial # to fit this equation with the stress computed above virial = torch.neg(virial) + atom_virial = torch.neg(atom_virial) data[AtomicDataDict.VIRIAL_KEY] = virial data[AtomicDataDict.ATOM_VIRIAL_KEY] = atom_virial From b502de899cf0e6005a28cca4bf1cc0075570c346 Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Tue, 28 Mar 2023 03:15:26 +0800 Subject: [PATCH 11/16] Update atom_virial in ParaStressForceOutput Mentioned by @bruncefan1983, per-atom virial is not symmetric for many-body potentials. Symmetrize virial at last. --- nequip/nn/_grad_output.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 9c2ac87f..b0b9055f 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -463,7 +463,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: edge_virial = torch.einsum( "zi,zj->zij", vector_force, data[AtomicDataDict.EDGE_VECTORS_KEY] ) - edge_virial = (edge_virial + edge_virial.transpose(-1, -2)) / 2 # symmetric atom_virial = scatter( edge_virial, data[AtomicDataDict.EDGE_INDEX_KEY][0], @@ -471,7 +470,17 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: reduce="sum", dim_size=len(pos), ) + # edge_virial is distributed into two nodes equally. Not sure and need more tests + atom_virial = (atom_virial + scatter( + edge_virial, + data[AtomicDataDict.EDGE_INDEX_KEY][1], + dim=0, + reduce="sum", + dim_size=len(pos), + ))/2 + virial = scatter(atom_virial, batch, dim=0, reduce="sum") + virial = (virial + virial.transpose(-1, -2)) / 2 # symmetric if virial is None: # condition needed to unwrap optional for torchscript From ddba5e4948f77400c739086731d97e0351a8240a Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Wed, 29 Mar 2023 14:13:04 +0800 Subject: [PATCH 12/16] Update _grad_output.py --- nequip/nn/_grad_output.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index bfb1fc79..382bbbbc 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -409,10 +409,16 @@ def __init__( self.register_buffer("_empty", torch.Tensor()) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - data = AtomicDataDict.with_batch(data) - - batch = data[AtomicDataDict.BATCH_KEY] - num_batch: int = int(batch.max().cpu().item()) + 1 + assert AtomicDataDict.EDGE_VECTORS_KEY not in data + + if AtomicDataDict.BATCH_KEY in data: + batch = data[AtomicDataDict.BATCH_KEY] + num_batch: int = len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1 + else: + # Special case for efficiency + batch = self._empty + num_batch: int = 1 + pos = data[AtomicDataDict.POSITIONS_KEY] has_cell: bool = AtomicDataDict.CELL_KEY in data @@ -477,10 +483,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: virial = scatter(atom_virial, batch, dim=0, reduce="sum") virial = (virial + virial.transpose(-1, -2)) / 2 # symmetric - if virial is None: - # condition needed to unwrap optional for torchscript - assert False, "failed to compute virial autograd" - # we only compute the stress (1/V * virial) if we have a cell whose volume we can compute if has_cell: # ^ can only scale by cell volume if we have one...: @@ -492,7 +494,8 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: cell[:, 0, :], torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) - stress = virial / volume.view(-1, 1, 1) + stress = virial / volume.view(num_batch, 1, 1) + data[AtomicDataDict.CELL_KEY] = orig_cell else: stress = self._empty # torchscript data[AtomicDataDict.STRESS_KEY] = stress From 56aa691b4070609c9cd9f8b5d4e8b67259d9389c Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Wed, 29 Mar 2023 17:35:48 +0800 Subject: [PATCH 13/16] remove warning --- nequip/nn/_grad_output.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 382bbbbc..c45248f4 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -385,10 +385,6 @@ def __init__( ): super().__init__() - warnings.warn( - "!! Stresses in NequIP are in BETA and UNDER DEVELOPMENT: _please_ carefully check the sanity of your results and report any (potential) issues on the GitHub" - ) - if not do_forces: raise NotImplementedError self.do_forces = do_forces From d591b6ec0d09c44cd31f22cdcae12baa90a9231a Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Wed, 29 Mar 2023 18:23:48 +0800 Subject: [PATCH 14/16] Update _grad_output.py --- nequip/nn/_grad_output.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index c45248f4..d60cad81 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -411,8 +411,9 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: batch = data[AtomicDataDict.BATCH_KEY] num_batch: int = len(data[AtomicDataDict.BATCH_PTR_KEY]) - 1 else: + data = AtomicDataDict.with_batch(data) # Special case for efficiency - batch = self._empty + batch = data[AtomicDataDict.BATCH_KEY] num_batch: int = 1 pos = data[AtomicDataDict.POSITIONS_KEY] From 970143764cebed3f0a8556fd0ed85da63b4c0707 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Wed, 29 Mar 2023 20:02:19 +0800 Subject: [PATCH 15/16] add para_stress tests workflows --- .github/workflows/tests_stress.yml | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 .github/workflows/tests_stress.yml diff --git a/.github/workflows/tests_stress.yml b/.github/workflows/tests_stress.yml new file mode 100644 index 00000000..4877e7a1 --- /dev/null +++ b/.github/workflows/tests_stress.yml @@ -0,0 +1,48 @@ +name: Run Tests + +on: + push: + branches: + - para_stress + + pull_request: + branches: + - para_stress + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + torch-version: [1.11.0, 1.13.1] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + env: + TORCH: "${{ matrix.torch-version }}" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + python -m pip install --upgrade pip + pip install setuptools wheel + pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install h5py + pip install --upgrade-strategy only-if-needed . + - name: Install pytest + run: | + pip install pytest + pip install pytest-xdist[psutil] + - name: Download test data + run: | + mkdir benchmark_data + cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. + - name: Test with pytest + run: | + # See https://github.com/pytest-dev/pytest/issues/1075 + PYTHONHASHSEED=0 pytest -n auto tests/ From 235bfe0bc9ea0ee3c5871c92c9bebb71785aa355 Mon Sep 17 00:00:00 2001 From: Hongyu-yu Date: Wed, 29 Mar 2023 21:43:27 +0800 Subject: [PATCH 16/16] fix parastress irreps --- nequip/nn/_grad_output.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index d60cad81..00435a15 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -397,9 +397,9 @@ def __init__( irreps_out=self.func.irreps_out.copy(), ) self.irreps_out[AtomicDataDict.FORCE_KEY] = "1o" - self.irreps_out[AtomicDataDict.STRESS_KEY] = "3x1o" - self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "3x1o" - self.irreps_out[AtomicDataDict.ATOM_VIRIAL_KEY] = "3x1o" + self.irreps_out[AtomicDataDict.STRESS_KEY] = "1o" + self.irreps_out[AtomicDataDict.VIRIAL_KEY] = "1o" + self.irreps_out[AtomicDataDict.ATOM_VIRIAL_KEY] = "1o" # for torchscript compat self.register_buffer("_empty", torch.Tensor())