diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 9c66c35b2..e30a1e0ab 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -97,7 +97,9 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage): # Returns: # loss_y: loss for the predicted value # loss_neg_y: loss for the predicted negative derivative - loss_y, loss_neg_y = 0.0, 0.0 + loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor( + 0.0, device=self.device + ) loss_name = loss_fn.__name__ if self.hparams.derivative and "neg_dy" in batch: loss_neg_y = loss_fn(neg_y, batch.neg_dy)