diff --git a/sc_utils.py b/sc_utils.py index 9a688f4..5f30779 100644 --- a/sc_utils.py +++ b/sc_utils.py @@ -208,12 +208,12 @@ def make_torsion_features(feature_dict, repack_everything=True): S_af2, torch.tensor(restype_rigid_group_default_frame, device=device), ) - + xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos( pred_frames, S_af2, torch.tensor(restype_rigid_group_default_frame, device=device), - torch.tensor(restype_atom14_to_rigid_group, device=device), + torch.tensor(restype_atom14_to_rigid_group, device=device).long(), torch.tensor(restype_atom14_mask, device=device), torch.tensor(restype_atom14_rigid_group_positions, device=device), )