Skip to content

Commit

Permalink
remove ase dependence from nn modules
Browse files Browse the repository at this point in the history
  • Loading branch information
cw-tan committed Oct 22, 2024
1 parent 70daf0e commit 9ec3289
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 3 deletions.
131 changes: 131 additions & 0 deletions nequip/data/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
Relevant chunk from ASE
https://gitlab.com/ase/ase/-/blob/master/ase/data/__init__.py
to reduce dependencies for nn modules
"""

chemical_symbols = [
"X",
"H",
"He",
"Li",
"Be",
"B",
"C",
"N",
"O",
"F",
"Ne",
"Na",
"Mg",
"Al",
"Si",
"P",
"S",
"Cl",
"Ar",
"K",
"Ca",
"Sc",
"Ti",
"V",
"Cr",
"Mn",
"Fe",
"Co",
"Ni",
"Cu",
"Zn",
"Ga",
"Ge",
"As",
"Se",
"Br",
"Kr",
"Rb",
"Sr",
"Y",
"Zr",
"Nb",
"Mo",
"Tc",
"Ru",
"Rh",
"Pd",
"Ag",
"Cd",
"In",
"Sn",
"Sb",
"Te",
"I",
"Xe",
"Cs",
"Ba",
"La",
"Ce",
"Pr",
"Nd",
"Pm",
"Sm",
"Eu",
"Gd",
"Tb",
"Dy",
"Ho",
"Er",
"Tm",
"Yb",
"Lu",
"Hf",
"Ta",
"W",
"Re",
"Os",
"Ir",
"Pt",
"Au",
"Hg",
"Tl",
"Pb",
"Bi",
"Po",
"At",
"Rn",
"Fr",
"Ra",
"Ac",
"Th",
"Pa",
"U",
"Np",
"Pu",
"Am",
"Cm",
"Bk",
"Cf",
"Es",
"Fm",
"Md",
"No",
"Lr",
"Rf",
"Db",
"Sg",
"Bh",
"Hs",
"Mt",
"Ds",
"Rg",
"Cn",
"Nh",
"Fl",
"Mc",
"Lv",
"Ts",
"Og",
]

chemical_symbols_to_atomic_numbers_dict = {
symbol: Z for Z, symbol in enumerate(chemical_symbols)
}
5 changes: 2 additions & 3 deletions nequip/nn/pair_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from e3nn.util.jit import compile_mode

import ase.data

from nequip.data import AtomicDataDict
from nequip.data.misc import chemical_symbols_to_atomic_numbers_dict
from ._util import scatter
from ._graph_mixin import GraphModuleMixin
from nequip.utils import conditional_torchscript_jit
Expand Down Expand Up @@ -269,7 +268,7 @@ def __init__(
)
assert len(chemical_species) == num_types
atomic_numbers: List[int] = [
ase.data.atomic_numbers[chemical_species[type_i]]
chemical_symbols_to_atomic_numbers_dict[chemical_species[type_i]]
for type_i in range(num_types)
]
if min(atomic_numbers) < 1:
Expand Down

0 comments on commit 9ec3289

Please sign in to comment.