Skip to content

Commit

Permalink
Merge branch 'main' into fix_module
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Oct 23, 2023
2 parents 377fb08 + 2cc5395 commit 51f666c
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 20 deletions.
39 changes: 38 additions & 1 deletion tests/test_calculator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
from torch.testing import assert_allclose
import pytest
from pytest import mark
from glob import glob
from os.path import dirname, join
from torchmdnet.calculators import External
from torchmdnet.models.model import load_model
from torchmdnet.models.model import load_model, create_model

from utils import create_example_batch

Expand All @@ -21,6 +22,42 @@ def test_compare_forward():
assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred.unsqueeze(0))

def test_compare_forward_cuda_graph():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
args = {"model": "tensornet",
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
model = create_model(args).to(device="cuda")
z, pos, _ = create_example_batch(multiple_batches=False)
z = z.to("cuda")
pos = pos.to("cuda")
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device="cuda")
calc_graph = External(checkpoint, z.unsqueeze(0), use_cuda_graph=True, device="cuda")
calc.model = model
calc_graph.model = model
for _ in range(10):
e_calc, f_calc = calc.calculate(pos, None)
e_pred, f_pred = calc_graph.calculate(pos, None)
assert_allclose(e_calc, e_pred)
assert_allclose(f_calc, f_pred)


def test_compare_forward_multiple():
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
Expand Down
86 changes: 76 additions & 10 deletions torchmdnet/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,43 @@

class External:
"""
The External class is used to calculate the energy and forces of an external potential, such as a neural network. The class is initialized with the path to the neural network
ckpt, the embeddings, the device on which the neural network should be run and the output_transform argument. The output_transform is used to give a function that transform
the energy and the forces, this could be a preset transform or a custom function. In this way there is no constraint to the units of the neural network, the user can choose
the units of the simulation and the neural network will be automatically converted to the units of the simulation. The function should take two arguments, the energy and the
forces, and return the transformed energy and the transformed forces.
This is an adapter to use TorchMD-Net models in TorchMD.
Parameters
----------
netfile : str or torch.nn.Module
Path to the checkpoint file of the model or the model itself.
embeddings : torch.Tensor
Embeddings of the atoms in the system.
device : str, optional
Device on which the model should be run. Default: "cpu"
output_transform : str or callable, optional
Transform to apply to the energy and forces.
If a string is given, it should be a key in the `transforms` dict.
If a callable is given, it should take two arguments (energy and forces) and return two tensors of the same shape.
Default: None
use_cuda_graph : bool, optional
Whether to use CUDA graphs to speed up the calculation. Default: False
cuda_graph_warmup_steps : int, optional
Number of steps to run as warmup before recording the CUDA graph. Default: 12
"""

def __init__(self, netfile, embeddings, device="cpu", output_transform=None):
self.model = load_model(netfile, device=device, derivative=True)
def __init__(
self,
netfile,
embeddings,
device="cpu",
output_transform=None,
use_cuda_graph=False,
cuda_graph_warmup_steps=12,
):
if isinstance(netfile, str):
self.model = load_model(netfile, device=device, derivative=True)
elif isinstance(netfile, torch.nn.Module):
self.model = netfile
else:
raise ValueError(
f"Expected a path to a checkpoint file or a torch.nn.Module, got {type(netfile)}"
)
self.device = device
self.n_atoms = embeddings.size(1)
self.embeddings = embeddings.reshape(-1).to(device)
Expand All @@ -46,11 +74,49 @@ def __init__(self, netfile, embeddings, device="cpu", output_transform=None):
self.output_transformer = tranforms[output_transform]
else:
self.output_transformer = eval(output_transform)
if not torch.cuda.is_available() and use_cuda_graph:
raise ValueError("CUDA graphs are only available if CUDA is")
self.use_cuda_graph = use_cuda_graph
self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
self.cuda_graph = None
self.energy = None
self.forces = None
self.pos = None

def _init_cuda_graph(self):
stream = torch.cuda.Stream()
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(stream):
for _ in range(self.cuda_graph_warmup_steps):
self.energy, self.forces = self.model(
self.embeddings, self.pos, self.batch
)
with torch.cuda.graph(self.cuda_graph):
self.energy, self.forces = self.model(
self.embeddings, self.pos, self.batch
)

def calculate(self, pos, box):
pos = pos.to(self.device).type(torch.float32).reshape(-1, 3)
energy, forces = self.model(self.embeddings, pos, self.batch)

if self.use_cuda_graph:
if self.pos is None:
self.pos = (
pos.clone()
.to(self.device)
.detach()
.requires_grad_(pos.requires_grad)
)
if self.cuda_graph is None:
self._init_cuda_graph()
assert self.cuda_graph is not None, "CUDA graph is not initialized. This should not had happened."
with torch.no_grad():
self.pos.copy_(pos)
self.cuda_graph.replay()
else:
self.energy, self.forces = self.model(self.embeddings, pos, self.batch)
assert self.forces is not None, "The model is not returning forces"
assert self.energy is not None, "The model is not returning energy"
return self.output_transformer(
energy.detach(), forces.reshape(-1, self.n_atoms, 3).detach()
self.energy.clone().detach(),
self.forces.clone().reshape(-1, self.n_atoms, 3).detach(),
)
4 changes: 2 additions & 2 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def forward(
[y],
[pos],
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
create_graph=self.training,
retain_graph=self.training,
)[0]
if dy is None:
raise RuntimeError("Autograd returned None for the force prediction.")
Expand Down
19 changes: 12 additions & 7 deletions torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,31 @@ def forward(
assert (
edge_vec is not None
), "Distance module did not return directional information"
# Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to the first atom
# Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
zp = z
if self.static_shapes:
mask = (edge_index[0] >= 0).unsqueeze(0).expand_as(edge_index)
# I trick the model into thinking that the masked edges pertain to the first atom
mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
# I trick the model into thinking that the masked edges pertain to the extra atom
# WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
edge_index = edge_index * mask
edge_weight = edge_weight * mask[0]
edge_vec = edge_vec * mask[0].unsqueeze(-1).expand_as(edge_vec)
edge_index = edge_index.masked_fill(mask, z.shape[0])
edge_weight = edge_weight.masked_fill(mask[0], 0)
edge_vec = edge_vec.masked_fill(mask[0].unsqueeze(-1).expand_as(edge_vec), 0)
edge_attr = self.distance_expansion(edge_weight)
mask = edge_index[0] == edge_index[1]
# Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
# I avoid dividing by zero by setting the weight of self edges and self loops to 1
edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
X = self.tensor_embedding(z, edge_index, edge_weight, edge_vec, edge_attr)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
for layer in self.layers:
X = layer(X, edge_index, edge_weight, edge_attr)
I, A, S = decompose_tensor(X)
x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
x = self.out_norm(x)
x = self.act(self.linear((x)))
# # Remove the extra atom
if self.static_shapes:
x = x[:-1]
return x, None, z, pos, batch


Expand Down

0 comments on commit 51f666c

Please sign in to comment.