Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dpa3): add huber loss #4549

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,43 @@
)


def custom_huber_loss(predictions, targets, delta=1.0):
    """Return the mean Huber loss between *predictions* and *targets*.

    Residuals with magnitude at most ``delta`` contribute quadratically
    (``0.5 * err**2``); larger residuals contribute linearly
    (``delta * (|err| - 0.5 * delta)``), which caps the gradient of
    outliers at ``delta``.
    """
    residual = targets - predictions
    magnitude = residual.abs()
    # Quadratic inside the |err| <= delta region, linear outside it.
    loss = torch.where(
        magnitude <= delta,
        0.5 * residual * residual,
        delta * (magnitude - 0.5 * delta),
    )
    return loss.mean()


def custom_step_huber_loss(predictions, targets, delta=1.0):
    """Return a mean Huber loss whose delta shrinks stepwise with |targets|.

    The quadratic/linear transition point is selected per element from the
    magnitude of the target value:

    - ``|t| < 100``  -> ``delta``
    - ``|t| < 200``  -> ``0.7 * delta``
    - ``|t| < 300``  -> ``0.4 * delta``
    - otherwise      -> ``0.1 * delta``

    so samples with large labels are treated more robustly (more linear).
    NOTE(review): the 100/200/300 breakpoints look tuned to a specific
    unit/scale of the labels — confirm against the training data.
    """
    residual = targets - predictions
    abs_residual = residual.abs()
    target_mag = targets.abs()

    # Per-element delta following the stepwise schedule above.
    eff_delta = torch.where(
        target_mag < 100,
        delta,
        torch.where(
            target_mag < 200,
            0.7 * delta,
            torch.where(target_mag < 300, 0.4 * delta, 0.1 * delta),
        ),
    )

    # Standard Huber form, evaluated with the element-wise delta.
    quadratic = 0.5 * residual * residual
    linear = eff_delta * (abs_residual - 0.5 * eff_delta)
    return torch.where(abs_residual <= eff_delta, quadratic, linear).mean()


class EnergyStdLoss(TaskLoss):
def __init__(
self,
Expand All @@ -41,6 +78,9 @@ def __init__(
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
use_huber=False,
huber_delta=0.01,
torch_huber=False,
**kwargs,
) -> None:
r"""Construct a layer to compute loss on energy, force and virial.
Expand Down Expand Up @@ -118,6 +158,10 @@ def __init__(
)
self.use_l1_all = use_l1_all
self.inference = inference
self.huber = use_huber
self.huber_delta = huber_delta
self.torch_huber = torch_huber
self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta)

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.
Expand Down Expand Up @@ -180,7 +224,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
if not self.huber:
loss += atom_norm * (pref_e * l2_ener_loss)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
Expand Down Expand Up @@ -233,7 +291,20 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if not self.huber:
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
model_pred["force"], label["force"]
)
else:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
Expand Down Expand Up @@ -304,7 +375,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
if not self.huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["virial"],
atom_norm * label["virial"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def is_oom_error(self, e: Exception) -> bool:
"CUDA out of memory." in e.args[0]
or "CUDA driver error: out of memory" in e.args[0]
or "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR" in e.args[0]
or "CUDA error: CUBLAS_STATUS_INTERNAL_ERROR" in e.args[0]
):
# Release all unoccupied cached memory
torch.cuda.empty_cache()
Expand Down
18 changes: 18 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,24 @@ def loss_ener():
default=0,
doc=doc_numb_generalized_coord,
),
Argument(
"use_huber",
bool,
optional=True,
default=False,
),
Argument(
"huber_delta",
float,
optional=True,
default=0.01,
),
Argument(
"torch_huber",
bool,
optional=True,
default=False,
),
]


Expand Down
Loading