Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] pd: add CINN compiler for dpa2, dpa1 training #4514

Open
wants to merge 6 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
get_sampler_from_params,
)
from deepmd.pd.utils.env import (
CINN,
DEFAULT_PRECISION,
DEVICE,
JIT,
NUM_WORKERS,
Expand Down Expand Up @@ -397,11 +399,11 @@
self.lr_exp = get_lr(config["learning_rate"])

# JIT
if JIT:
raise NotImplementedError(
"JIT is not supported yet when training with Paddle"
)
self.model = paddle.jit.to_static(self.model)
# if JIT:
# raise NotImplementedError(
# "JIT is not supported yet when training with Paddle"
# )
# self.model = paddle.jit.to_static(self.model)

# Model Wrapper
self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
Expand Down Expand Up @@ -631,6 +633,19 @@
self.profiling_file = training_params.get("profiling_file", "timeline.json")

def run(self):
if JIT:
from paddle import (
jit,
static,
)

build_strategy = static.BuildStrategy()
build_strategy.build_cinn_pass: bool = CINN
self.wrapper.forward = jit.to_static(
full_graph=True, build_strategy=build_strategy
)(self.wrapper.forward)
log.info(f"{'*' * 20} Using Jit {'*' * 20}")

fout = (
open(
self.disp_file,
Expand Down Expand Up @@ -670,9 +685,11 @@
cur_lr = _lr.value(_step_id)
pref_lr = cur_lr
self.optimizer.clear_grad(set_to_zero=False)
input_dict, label_dict, log_dict = self.get_data(
is_train=True, task_key=task_key
)

with nvprof_context(enable_profiling, "Fetching data"):
input_dict, label_dict, log_dict = self.get_data(
is_train=True, task_key=task_key
)
if SAMPLER_RECORD:
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
fout1.write(print_str)
Expand All @@ -686,7 +703,7 @@
with nvprof_context(enable_profiling, "Forward pass"):
model_pred, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=task_key,
)
Expand Down Expand Up @@ -745,7 +762,7 @@
return {}
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_task_key,
)
Expand Down Expand Up @@ -795,7 +812,7 @@
)
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
label=label_dict,
task_key=_key,
)
Expand Down Expand Up @@ -905,8 +922,8 @@
else:
model_key = "Default"
step(step_id, model_key)
if JIT:
break
# if JIT:
# break
Fixed Show fixed Hide fixed

if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
if not self.multi_task:
Expand Down Expand Up @@ -961,10 +978,6 @@
/ (elapsed_batch // self.disp_freq * self.disp_freq),
)

if JIT:
raise NotImplementedError(
"Paddle JIT saving during training is not supported yet."
)
log.info(f"Trained model has been saved to: {self.save_ckpt}")

if fout:
Expand Down
51 changes: 43 additions & 8 deletions deepmd/pd/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,33 @@

paddle.device.set_device(DEVICE)

JIT = False

def to_bool(flag: int | bool | str) -> bool:
if isinstance(flag, int):
if flag not in [0, 1]:
raise ValueError(f"flag must be either 0 or 1, but received {flag}")
return bool(flag)

elif isinstance(flag, str):
flag = flag.lower()
if flag not in ["1", "0", "true", "false"]:
raise ValueError(
"flag must be either '0', '1', 'true', 'false', "
f"but received '{flag}'"
)
return flag in ["1", "true"]

elif isinstance(flag, bool):
return flag

else:
raise ValueError(
f"flag must be either int, bool, or str, but received {type(flag).__name__}"
)


# Feature toggles read from the process environment at import time.
# Each variable defaults to False when unset and is parsed by ``to_bool``.
JIT = to_bool(os.getenv("JIT", False))
CINN = to_bool(os.getenv("CINN", False))
Fixed Show fixed Hide fixed
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True

Expand Down Expand Up @@ -138,14 +164,23 @@
]
EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST))

"""Enable running program in primitive C++ API in eager/static mode."""
from paddle.framework import (
core,
)
"""Enable running program with primitive operators in eager/static mode."""
if JIT:
# jit mode
paddle.framework.core._set_prim_all_enabled(enable)
if enable:
# No need to set a blacklist for now in JIT mode.
pass
else:
# eager mode
paddle.framework.core.set_prim_eager_enabled(enable)
if enable:
# Set a blacklist (i.e., disable several composite operators) in eager mode
# to enhance computational performance.
paddle.framework.core._set_prim_backward_blacklist(
*EAGER_COMP_OP_BLACK_LIST
)

core.set_prim_eager_enabled(enable)
if enable:
paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST)
log = logging.getLogger(__name__)
log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.")

Expand Down
Loading