diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd36fd6e63..f47839650a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$) # markdown, yaml, CSS, javascript - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - types_or: [markdown, yaml, css] - # workflow files cannot be modified by pre-commit.ci - exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + # - repo: https://github.com/pre-commit/mirrors-prettier + # rev: v4.0.0-alpha.8 + # hooks: + # - id: prettier + # types_or: [markdown, yaml, css] + # # workflow files cannot be modified by pre-commit.ci + # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.10.0-2 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - - repo: https://github.com/njzjz/mirrors-bibtex-tidy - rev: v1.13.0 - hooks: - - id: bibtex-tidy - args: - - --curly - - --numeric - - --align=13 - - --blank-lines - # disable sort: the order of keys and fields has explict meanings - #- --sort=key - - --duplicates=key,doi,citation,abstract - - --merge=combine - #- --sort-fields - #- --strip-comments - - --trailing-commas - - --encode-urls - - --remove-empty-fields - - --wrap=80 + # - repo: https://github.com/njzjz/mirrors-bibtex-tidy + # rev: v1.13.0 + # hooks: + # - id: bibtex-tidy + # args: + # - --curly + # - --numeric + # - --align=13 + # - --blank-lines + # # disable sort: the order of keys and fields has explict meanings + # #- --sort=key + # - --duplicates=key,doi,citation,abstract + # - --merge=combine + # #- --sort-fields + # #- --strip-comments + # - --trailing-commas + # - --encode-urls + # - --remove-empty-fields + # - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 diff --git 
a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 0f3c7a9732..a0328942e4 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -53,6 +53,8 @@ get_sampler_from_params, ) from deepmd.pd.utils.env import ( + CINN, + DEFAULT_PRECISION, DEVICE, JIT, NUM_WORKERS, @@ -631,6 +633,22 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): + if CINN: + from paddle import ( + jit, + static, + ) + + build_strategy = static.BuildStrategy() + build_strategy.build_cinn_pass: bool = CINN + self.wrapper.forward = jit.to_static( + full_graph=True, build_strategy=build_strategy + )(self.wrapper.forward) + log.info( + "Enable CINN during training, there may be some additional " + "compilation time in the first training step." + ) + fout = ( open( self.disp_file, @@ -670,9 +688,11 @@ def step(_step_id, task_key="Default") -> None: cur_lr = _lr.value(_step_id) pref_lr = cur_lr self.optimizer.clear_grad(set_to_zero=False) - input_dict, label_dict, log_dict = self.get_data( - is_train=True, task_key=task_key - ) + + with nvprof_context(enable_profiling, "Fetching data"): + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) if SAMPLER_RECORD: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) @@ -686,7 +706,7 @@ def step(_step_id, task_key="Default") -> None: with nvprof_context(enable_profiling, "Forward pass"): model_pred, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=task_key, ) @@ -745,7 +765,7 @@ def log_loss_valid(_task_key="Default"): return {} _, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_task_key, ) @@ -795,7 +815,7 @@ def log_loss_valid(_task_key="Default"): ) _, loss, 
more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_key, ) diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index e2abe9a6e5..a21a1244ff 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + import logging import os @@ -32,7 +36,40 @@ paddle.device.set_device(DEVICE) -JIT = False + +def to_bool(flag: int | bool | str) -> bool: + if isinstance(flag, int): + if flag not in [0, 1]: + raise ValueError(f"flag must be either 0 or 1, but received {flag}") + return bool(flag) + + elif isinstance(flag, str): + flag = flag.lower() + if flag not in ["1", "0", "true", "false"]: + raise ValueError( + "flag must be either '0', '1', 'true', 'false', " + f"but received '{flag}'" + ) + return flag in ["1", "true"] + + elif isinstance(flag, bool): + return flag + + else: + raise ValueError( + f"flag must be either int, bool, or str, but received {type(flag).__name__}" + ) + + +JIT = to_bool(os.environ.get("JIT", False)) +CINN = to_bool(os.environ.get("CINN", False)) +if CINN: + assert paddle.device.is_compiled_with_cinn(), ( + "CINN is set to True, but PaddlePaddle is not compiled with CINN support. " + "Ensure that your PaddlePaddle installation supports CINN by checking your " + "installation or recompiling with CINN enabled." 
+ ) + CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True @@ -138,14 +175,23 @@ def enable_prim(enable: bool = True): ] EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) - """Enable running program in primitive C++ API in eager/static mode.""" - from paddle.framework import ( - core, - ) + """Enable running program with primitive operators in eager/static mode.""" + if JIT or CINN: + # jit mode + paddle.framework.core._set_prim_all_enabled(enable) + if enable: + # No need to set a blacklist for now in JIT mode. + pass + else: + # eager mode + paddle.framework.core.set_prim_eager_enabled(enable) + if enable: + # Set a blacklist (i.e., disable several composite operators) in eager mode + # to enhance computational performance. + paddle.framework.core._set_prim_backward_blacklist( + *EAGER_COMP_OP_BLACK_LIST + ) - core.set_prim_eager_enabled(enable) - if enable: - paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST) log = logging.getLogger(__name__) log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")