[WIP] pd: add CINN compiler for dpa2, dpa1 training #4514
base: devel
Conversation
📝 Walkthrough
The pull request introduces CINN compiler support for Paddle-backend training: deepmd/pd/utils/env.py gains environment-driven flags (`CINN`, `DEFAULT_PRECISION`), and deepmd/pd/train/training.py applies JIT compilation with a CINN build strategy and adds profiling around the training loop.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Env as Environment
    participant Training as Training Module
    participant Model as Model
    Env->>Env: Configure JIT and CINN settings
    Env->>Training: Set precision and compilation strategy
    Training->>Model: Apply JIT compilation
    Model-->>Training: Optimize forward pass
    Training->>Training: Profile performance
```
The sequence diagram illustrates the flow of configuration and optimization, showing how environment settings influence the training module and model compilation.

📜 Recent review details
Configuration used: CodeRabbit UI
📒 Files selected for processing (2)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)
402-405: Remove or clarify commented-out code.
These lines comment out a previously raised `NotImplementedError` and a potential `paddle.jit.to_static` call. If you no longer need this logic, removing it altogether might reduce confusion and keep the file tidy. Otherwise, add a comment explaining why these lines are kept for future reference.

```diff
-# if JIT:
-#     raise NotImplementedError("JIT is not supported yet when training with Paddle")
-#     self.model = paddle.jit.to_static(self.model)
```

Also applies to: 406-406
925-926: Consider removing the extra commented-out code.
This snippet appears to comment out a JIT debugging break. If it's no longer needed, removing it can avoid potential confusion.

```diff
-# if JIT:
-#     break
```

🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (8 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🧰 Additional context used
🪛 GitHub Check: CodeQL
deepmd/pd/utils/env.py
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
deepmd/pd/train/training.py
[notice] 925-926: Commented-out code
This comment appears to contain commented-out code.
🔇 Additional comments (7)
deepmd/pd/train/training.py (4)
56-57: Check the new imports for consistency.
The addition of `CINN` and `DEFAULT_PRECISION` is consistent with the improvements to JIT compilation and precision handling. Nothing problematic is observed here; just ensure that `CINN` is successfully imported where used and that `DEFAULT_PRECISION` is consistently applied.
636-648: JIT and CINN integration logic looks good.
You're conditionally enabling JIT using `jit.to_static` and setting `build_strategy.build_cinn_pass = CINN`. This is a clean approach, ensuring that CINN is only used if enabled. Just verify upstream usage to avoid unexpected behaviors if `CINN` is disabled at runtime.
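For reference, a minimal self-contained sketch of this pattern (assuming a Paddle build with CINN support; everything except the Paddle APIs is illustrative, not the PR's exact code):

```python
import paddle
from paddle import jit, static

CINN = True  # stand-in for the env-driven flag from deepmd/pd/utils/env.py

def forward(x):
    return paddle.nn.functional.relu(x) * 2.0

# Route the CINN flag through the build strategy, then wrap the callable
# with to_static so the whole graph is captured and compiled.
build_strategy = static.BuildStrategy()
build_strategy.build_cinn_pass = CINN
forward = jit.to_static(full_graph=True, build_strategy=build_strategy)(forward)

out = forward(paddle.randn([4, 8]))
print(out.shape)  # [4, 8]
```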
688-692: Validate data fetching performance within profiling context.
Wrapping the data loading with `nvprof_context` helps profile overhead. Ensure that exceptions thrown within this block are properly handled so that the profiler is closed gracefully.
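One way to guarantee graceful closing is a `try/finally` inside the context manager. A hedged sketch of what such a helper can look like; the actual `nvprof_context` lives in deepmd's Paddle utilities, and the `core.nvprof_nvtx_push`/`pop` internals here are an assumption:

```python
from contextlib import contextmanager

import paddle

@contextmanager
def nvprof_context(enable: bool, name: str):
    if enable:
        paddle.framework.core.nvprof_nvtx_push(name)
    try:
        yield
    finally:
        # pop the NVTX range even if the body raises (e.g. StopIteration
        # from an exhausted data loader), keeping profiler ranges balanced
        if enable:
            paddle.framework.core.nvprof_nvtx_pop()
```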
706-706: Precision usage for learning rate.
Using `paddle.full([], pref_lr, DEFAULT_PRECISION)` enforces consistent floating-point precision for the learning rate. This is beneficial for uniformity, especially in multi-task scenarios. Good practice!
Also applies to: 765-765, 815-815
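A quick illustration of the pattern (here `"float64"` merely stands in for `DEFAULT_PRECISION`):

```python
import paddle

pref_lr = 2.0e-3
# shape [] produces a 0-dimensional (scalar) tensor at a pinned dtype
lr = paddle.full([], pref_lr, dtype="float64")
print(lr.shape)   # []
print(float(lr))  # 0.002
```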
deepmd/pd/utils/env.py (3)
36-50: Robust input validation in `to_bool`.
Excellent job handling integers, booleans, and string values thoroughly, with clear error messages and suitable lowercasing of strings. This ensures minimal confusion for environment-variable parsing.
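A minimal sketch of a parser with the behavior described; the real implementation lives in deepmd/pd/utils/env.py and may differ in detail:

```python
def to_bool(flag) -> bool:
    # bool must be checked before int, since isinstance(True, int) is True
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, int):
        if flag in (0, 1):
            return bool(flag)
        raise ValueError(f"flag must be 0 or 1, got {flag}")
    if isinstance(flag, str):
        lowered = flag.lower()
        if lowered in ("1", "true", "yes", "on"):
            return True
        if lowered in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"unrecognized boolean string: {flag!r}")
    raise TypeError(f"unsupported type for boolean flag: {type(flag)}")

# e.g. CINN = to_bool(os.environ.get("CINN", "0"))
```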
60-61: CINN usage is valid despite static analysis hint.
While static analysis tools suggest that `CINN` is "unused," it is in fact used within training.py (via `build_strategy.build_cinn_pass`). You can safely disregard that hint.
🧰 Tools
🪛 GitHub Check: CodeQL
[notice] 61-61: Unused global variable
The global variable 'CINN' is not used.
167-182: Flexible prim enabling for JIT and eager modes.
Switching between JIT mode and eager mode, while setting blacklists only in eager mode, is a neat approach to leveraging performance improvements. This design offers more control and clarity to developers using either mode.
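A rough sketch of such a toggle; the private Paddle entry points used here (`_set_prim_all_enabled`, `set_prim_eager_enabled`, and the backward-blacklist setter) are assumptions about internals that may differ across Paddle versions:

```python
import paddle
from paddle.framework import core

def enable_prim(enable: bool = True, jit_mode: bool = False) -> None:
    if jit_mode:
        # JIT/to_static path: decompose ops across the whole program
        core._set_prim_all_enabled(enable)
    else:
        # eager path: decompose ops but keep selected composites intact
        core.set_prim_eager_enabled(enable)
        if enable:
            core._set_prim_backward_blacklist("softmax", "silu")
```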
Actionable comments posted: 1
🧹 Nitpick comments (3)
deepmd/pd/train/training.py (3)
709-709: Confirm precision alignment for learning rate.
Replacing a float with `paddle.full([], pref_lr, DEFAULT_PRECISION)` ensures consistent floating-point precision. Verify that using a 0-dimensional tensor is intentional and that the rest of the forward pass expects a scalar.
768-768: Avoid code duplication in validation logic.
This pattern repeats the same approach for creating a scalar tensor from `pref_lr`. Consider encapsulating the logic to reduce repetitive calls and simplify maintenance.
818-818: Maintain consistent approach for multi-task calls.
Reusing `paddle.full([], pref_lr, DEFAULT_PRECISION)` across multiple task keys is valid. However, centralizing this snippet could prevent future errors in multi-task scenarios.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (6 hunks)
- deepmd/pd/utils/env.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pd/utils/env.py
🔇 Additional comments (2)
deepmd/pd/train/training.py (2)
56-57: Use environment-driven feature flags carefully.
Importing `CINN` and `DEFAULT_PRECISION` helps toggle experimental optimizations and ensures matching data types throughout training. This is acceptable, but confirm that these environment variables and configurations are properly documented and tested to avoid accidental misconfiguration.
691-695: Validate data fetching within profiler context.
The usage of `nvprof_context` around `get_data()` is beneficial for profiling. Make sure exceptions (e.g., `StopIteration`) in data fetching are handled gracefully to preserve consistent profiling measurements.
```python
if CINN:
    from paddle import (
        jit,
        static,
    )

    build_strategy = static.BuildStrategy()
    build_strategy.build_cinn_pass: bool = CINN
    self.wrapper.forward = jit.to_static(
        full_graph=True, build_strategy=build_strategy
    )(self.wrapper.forward)
    log.info(
        "Enable CINN during training, there may be some additional "
        "compilation time in the first traning step."
    )
```
Refine the CINN build strategy assignment.

```python
build_strategy.build_cinn_pass: bool = CINN
```

Here the colon syntax is an annotated assignment on an attribute target: the value is still assigned at runtime, but the annotation is evaluated and then discarded, and the line reads like a type hint rather than an assignment. A plain assignment is the clearer, idiomatic form:

```diff
- build_strategy.build_cinn_pass: bool = CINN
+ build_strategy.build_cinn_pass = CINN
```
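A small demonstration of the runtime behavior noted above:

```python
class Strategy:
    pass

s = Strategy()
# annotated attribute target: the annotation is evaluated and discarded,
# but the attribute is still assigned
s.build_cinn_pass: bool = True
assert s.build_cinn_pass is True

# preferred form: plain assignment, no dead annotation
s.build_cinn_pass = True
```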
Force-pushed from 95b201d to 5cdd421 (compare)
Force-pushed from 5cdd421 to 7ca2a9e (compare)
Codecov Report
Attention: some of the changed lines are not covered by tests (19 new misses).

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            devel    #4514      +/-  ##
==========================================
- Coverage   84.58%   84.55%    -0.03%
==========================================
  Files         675      675
  Lines       63694    63722       +28
  Branches     3486     3487        +1
==========================================
+ Hits        53873    53880        +7
- Misses       8697     8716       +19
- Partials     1124     1126        +2
==========================================
```

☔ View full report in Codecov by Sentry.
@njzjz I have a question about the code at deepmd-kit/deepmd/pt/model/network/layernorm.py, lines 97 to 99 (commit 8d4c27b): under what circumstances can the input tensor have `numel() == 0`? I haven't noticed this happening during Python training and testing. Could you please help me with this? Thank you!
When using LAMMPS with MPI.
See also #2668. Some users may encounter the situation where a processor has no atoms.
Thank you for your response. Although 0-size tensors are not very common, we have indeed encountered similar issues with some object detection models. We are going to add support for training and inference with 0-size tensors in Paddle in the near future.
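An illustration of the situation described (the shapes here are assumed for the example): with MPI domain decomposition, a rank that owns no atoms ends up with 0-size tensors.

```python
import paddle

x = paddle.zeros([0, 128])  # 0 atoms on this rank, 128 features each
print(x.shape)              # [0, 128]
print(int(x.numel()))       # 0 -> code paths such as the layernorm above
                            #      must guard against empty inputs
```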
We verified the Paddle CINN compiler on the DPA-2 example (single A100-SXM (40G), CUDA 11.8, Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz x 160).
To enable the CINN compiler in training, add one flag, `CINN=1`, before the training command, e.g.:

```bash
CINN=1 dp --pd train input_torch_medium.json
```

Curves:
dpa2

Performance (tested with torch==2.6.0.dev20241219+cu118): [figures/tables omitted]

se_atten

Performance (tested with torch==2.6.0.dev20241219+cu118): [figures/tables omitted]
Accuracy details:

dpa2
- PyTorch: [table omitted]
- Paddle (eager mode): [table omitted]
- Paddle (CINN compiler): [table omitted]

se_atten
- PyTorch: [table omitted]
- Paddle (eager mode): [table omitted]
- Paddle (CINN compiler): [table omitted]
TODO:
Summary by CodeRabbit

New Features
- Introduced new settings `CINN` and `DEFAULT_PRECISION`.

Improvements

Bug Fixes

Chores
- Updated `prettier` and `bibtex-tidy`.