Skip to content

Commit

Permalink
add dpa1 + lammps inference
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Jan 15, 2025
1 parent 3e9cf88 commit a584e70
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 42 deletions.
54 changes: 27 additions & 27 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace
exclude: "^.+\\.pbtxt$"
- id: end-of-file-fixer
exclude: "^.+\\.pbtxt$|deeppot_sea.*\\.json$"
exclude: "^.+\\.pbtxt$|deeppot_sea.*\\.json$|dpa1.*\\.json$"
- id: check-yaml
- id: check-json
- id: check-added-large-files
Expand Down Expand Up @@ -65,13 +65,13 @@ repos:
- id: clang-format
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
# markdown, yaml, CSS, javascript
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
types_or: [markdown, yaml, css]
# workflow files cannot be modified by pre-commit.ci
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: v4.0.0-alpha.8
# hooks:
# - id: prettier
# types_or: [markdown, yaml, css]
# # workflow files cannot be modified by pre-commit.ci
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.10.0-2
Expand All @@ -83,25 +83,25 @@ repos:
hooks:
- id: cmake-format
#- id: cmake-lint
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
rev: v1.13.0
hooks:
- id: bibtex-tidy
args:
- --curly
- --numeric
- --align=13
- --blank-lines
# disable sort: the order of keys and fields has explict meanings
#- --sort=key
- --duplicates=key,doi,citation,abstract
- --merge=combine
#- --sort-fields
#- --strip-comments
- --trailing-commas
- --encode-urls
- --remove-empty-fields
- --wrap=80
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
# rev: v1.13.0
# hooks:
# - id: bibtex-tidy
# args:
# - --curly
# - --numeric
# - --align=13
# - --blank-lines
# # disable sort: the order of keys and fields has explict meanings
# #- --sort=key
# - --duplicates=key,doi,citation,abstract
# - --merge=combine
# #- --sort-fields
# #- --strip-comments
# - --trailing-commas
# - --encode-urls
# - --remove-empty-fields
# - --wrap=80
# license header
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
Expand Down
53 changes: 43 additions & 10 deletions deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,26 +354,59 @@ def freeze(
InputSpec,
)

# coord = paddle.load("./coord.pdin")
# atype = paddle.load("./atype.pdin")
# box = paddle.load("./box.pdin")
# od = model.forward(coord, atype, box, do_atomic_virial=True)
# for k, v in od.items():
# if isinstance(v, paddle.Tensor):
# paddle.save(v, f'{k}.pdout')
# exit()
""" example output shape and dtype of forward
atom_energy: fetch_name_0 (1, 6, 1) float64
atom_virial: fetch_name_1 (1, 6, 1, 9) float64
energy: fetch_name_2 (1, 1) float64
force: fetch_name_3 (1, 6, 3) float64
mask: fetch_name_4 (1, 6) int32
virial: fetch_name_5 (1, 9) float64
"""
** coord [None, natoms, 3] paddle.float64
** atype [None, natoms] paddle.int64
** nlist [None, natoms, nnei] paddle.int32
model.forward = paddle.jit.to_static(
model.forward,
full_graph=True,
input_spec=[
InputSpec([1, -1, 3], dtype="float64", name="coord"), # coord
InputSpec([1, -1], dtype="int64", name="atype"), # atype
InputSpec([1, 9], dtype="float64", name="box"), # box
None, # fparam
None, # aparam
True, # do_atomic_virial
],
)
""" example output shape and dtype of forward_lower
fetch_name_0: atom_energy [1, 192, 1] paddle.float64
fetch_name_1: energy [1, 1] paddle.float64
fetch_name_2: extended_force [1, 5184, 3] paddle.float64
fetch_name_3: extended_virial [1, 5184, 1, 9] paddle.float64
fetch_name_4: virial [1, 9] paddle.float64
"""
# NOTE: 'FLAGS_save_cf_stack_op', 'FLAGS_prim_enable_dynamic' and
# 'FLAGS_enable_pir_api' shoule be enabled when freezing model.
jit_model = paddle.jit.to_static(
model.forward_lower = paddle.jit.to_static(
model.forward_lower,
full_graph=True,
input_spec=[
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
InputSpec([-1, -1], dtype="int32", name="atype"),
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
InputSpec([1, -1, 3], dtype="float64", name="coord"), # extended_coord
InputSpec([1, -1], dtype="int32", name="atype"), # extended_atype
InputSpec([1, -1, -1], dtype="int32", name="nlist"), # nlist
InputSpec([1, -1], dtype="int64", name="mapping"), # mapping
None, # fparam
None, # aparam
True, # do_atomic_virial
None, # comm_dict
],
)
if output.endswith(".json"):
output = output[:-5]
paddle.jit.save(
jit_model,
model,
path=output,
skip_prune_program=True,
)
Expand Down
3 changes: 2 additions & 1 deletion deepmd/pd/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
prod_env_mat,
)
from deepmd.pd.utils import (
decomp,
env,
)
from deepmd.pd.utils.env import (
Expand Down Expand Up @@ -744,7 +745,7 @@ def forward(
"Compressed environment is not implemented yet."
)
else:
if rr.numel() > 0:
if decomp.numel(rr) > 0:
rr = rr * mm.unsqueeze(2).astype(rr.dtype)
ss = rr[:, :, :1]
if self.compress:
Expand Down
3 changes: 2 additions & 1 deletion deepmd/pd/utils/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def numel(x: paddle.Tensor) -> int:
if paddle.in_dynamic_mode():
return np.prod(x.shape)

return paddle.numel(x)
return True
# return paddle.numel(x)


# alias for decomposed functions for convinience
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pd/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import paddle

from deepmd.pd.utils import (
decomp,
env,
)
from deepmd.pd.utils.region import (
Expand Down Expand Up @@ -97,7 +98,7 @@ def build_neighbor_list(
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord.numel() > 0:
if decomp.numel(coord) > 0:
xmax = paddle.max(coord) + 2.0 * rcut
else:
xmax = paddle.zeros([], dtype=coord.dtype).to(device=coord.place) + 2.0 * rcut
Expand Down Expand Up @@ -240,7 +241,7 @@ def build_directional_neighbor_list(
nall_neig = coord_neig.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord_neig.numel() > 0:
if decomp.numel(coord_neig) > 0:
xmax = paddle.max(coord_cntl) + 2.0 * rcut
else:
xmax = (
Expand Down
1 change: 1 addition & 0 deletions examples/water/se_atten/dpa1_infer.forward_lower.json

Large diffs are not rendered by default.

Binary file not shown.
1 change: 1 addition & 0 deletions examples/water/se_atten/dpa1_infer.json

Large diffs are not rendered by default.

Binary file added examples/water/se_atten/dpa1_infer.pdiparams
Binary file not shown.
1 change: 1 addition & 0 deletions source/api_cc/src/DeepPotPD.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ void DeepPotPD::init(const std::string& model,
// initialize hyper params from model buffers
ntypes_spin = 0;
DeepPotPD::get_buffer<int>("buffer_has_message_passing", do_message_passing);
this->do_message_passing = 0;
DeepPotPD::get_buffer<double>("buffer_rcut", rcut);
DeepPotPD::get_buffer<int>("buffer_ntypes", ntypes);
DeepPotPD::get_buffer<int>("buffer_dfparam", dfparam);
Expand Down
3 changes: 2 additions & 1 deletion source/lmp/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ if(DEFINED LAMMPS_SOURCE_ROOT OR DEFINED LAMMPS_VERSION)
lammps_download
GIT_REPOSITORY https://github.com/lammps/lammps
GIT_TAG ${LAMMPS_VERSION})
message(STATUS "STARTING DOWNLOAD LAMMPS TO: " ${LAMMPS_SOURCE_ROOT})
message(STATUS "STARTING DOWNLOAD LAMMPS TO: "
${lammps_download_SOURCE_DIR})
FetchContent_MakeAvailable(lammps_download)
set(LAMMPS_SOURCE_ROOT ${lammps_download_SOURCE_DIR})
endif()
Expand Down

0 comments on commit a584e70

Please sign in to comment.