diff --git a/README.md b/README.md index 5123ec4..298adf0 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ To run the model of your choice specify `--model_type` and optionally the model --checkpoint_per_residue_label_membrane_mpnn "./model_params/per_residue_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise ``` -## Examples +## Design examples ### 1 default Default settings will run ProteinMPNN. ``` @@ -461,6 +461,77 @@ python run.py \ --parse_atoms_with_zero_occupancy 1 ``` +## Scoring examples +### Output dictionary +``` +out_dict = {} +out_dict["logits"] - raw logits from the model +out_dict["probs"] - softmax(logits) +out_dict["log_probs"] - log_softmax(logits) +out_dict["decoding_order"] - decoding order used (logits will depend on the decoding order) +out_dict["native_sequence"] - parsed input sequence in integers +out_dict["mask"] - mask for missing residues (usually all ones) +out_dict["chain_mask"] - controls which residues are decoded first +out_dict["alphabet"] - amino acid alphabet used +out_dict["residue_names"] - dictionary to map integers to residue_names, e.g. {0: "C10", 1: "C11"} +out_dict["sequence"] - parsed input sequence in alphabet +out_dict["mean_of_probs"] - averaged over batch_size*number_of_batches probabilities, [protein_length, 21] +out_dict["std_of_probs"] - same as above, but std +``` + +### 1 autoregressive with sequence info +Get probabilities/scores for backbone-sequence pairs using autoregressive probabilities: p(AA_1|backbone), p(AA_2|backbone, AA_1) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10. 
+```
+python score.py \
+        --model_type "ligand_mpnn" \
+        --seed 111 \
+        --autoregressive_score 1\
+        --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
+        --out_folder "./outputs/autoregressive_score_w_seq" \
+        --use_sequence 1\
+        --batch_size 1 \
+        --number_of_batches 10
+```
+### 2 autoregressive with backbone info only
+Get probabilities/scores for backbone using probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
+```
+python score.py \
+        --model_type "ligand_mpnn" \
+        --seed 111 \
+        --autoregressive_score 1\
+        --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
+        --out_folder "./outputs/autoregressive_score_wo_seq" \
+        --use_sequence 0\
+        --batch_size 1 \
+        --number_of_batches 10
+```
+### 3 single amino acid score with sequence info
+Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone, AA_{all except AA_1}), p(AA_2|backbone, AA_{all except AA_2}) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
+```
+python score.py \
+        --model_type "ligand_mpnn" \
+        --seed 111 \
+        --single_aa_score 1\
+        --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
+        --out_folder "./outputs/single_aa_score_w_seq" \
+        --use_sequence 1\
+        --batch_size 1 \
+        --number_of_batches 10
+```
+### 4 single amino acid score with backbone info only
+Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
+``` +python score.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --single_aa_score 1\ + --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \ + --out_folder "./outputs/single_aa_score_wo_seq" \ + --use_sequence 0\ + --batch_size 1 \ + --number_of_batches 10 +``` + ### Things to add - Support for ProteinMPNN CA-only model. - Examples for scoring sequences only. diff --git a/model_utils.py b/model_utils.py index e8b0031..7cc9759 100644 --- a/model_utils.py +++ b/model_utils.py @@ -468,72 +468,112 @@ def sample(self, feature_dict): } return output_dict - def unconditional_probs(self, feature_dict): - # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed - # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords - # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords - # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type - # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask - X = feature_dict["X"] # [B,L,4,3] - backbone xyz coordinates for N,CA,C,O + def single_aa_score(self, feature_dict, use_sequence: bool): + """ + feature_dict - input features + use_sequence - False using backbone info only + """ B_decoder = feature_dict["batch_size"] - # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index - mask = feature_dict[ + S_true_enc = feature_dict[ + "S" + ] + mask_enc = feature_dict[ "mask" - ] # [B,L] - mask for missing regions - should be removed! 
all ones most of the time - # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters - device = X.device - - h_V, h_E, E_idx = self.encode(feature_dict) - order_mask_backward = torch.zeros( - [X.shape[0], X.shape[1], X.shape[1]], device=device - ) - mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) - mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1]) - mask_fw = mask_1D * (1.0 - mask_attend) + ] + chain_mask_enc = feature_dict[ + "chain_mask" + ] + randn = feature_dict[ + "randn" + ] + B, L = S_true_enc.shape + device = S_true_enc.device + + h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict) + log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float() + logits_out = torch.zeros([B_decoder, L, 21], device=device).float() + decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float() + + for idx in range(L): + h_V = torch.clone(h_V_enc) + E_idx = torch.clone(E_idx_enc) + mask = torch.clone(mask_enc) + S_true = torch.clone(S_true_enc) + if not use_sequence: + order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float() + order_mask[idx] = 1. + else: + order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float() + order_mask[idx] = 0. 
+ decoding_order = torch.argsort( + (order_mask + 0.0001) * (torch.abs(randn)) + ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0] + E_idx = E_idx.repeat(B_decoder, 1, 1) + permutation_matrix_reverse = torch.nn.functional.one_hot( + decoding_order, num_classes=L + ).float() + order_mask_backward = torch.einsum( + "ij, biq, bjp->bqp", + (1 - torch.triu(torch.ones(L, L, device=device))), + permutation_matrix_reverse, + permutation_matrix_reverse, + ) + mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) + mask_1D = mask.view([B, L, 1, 1]) + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1.0 - mask_attend) + S_true = S_true.repeat(B_decoder, 1) + h_V = h_V.repeat(B_decoder, 1, 1) + h_E = h_E_enc.repeat(B_decoder, 1, 1, 1) + mask = mask.repeat(B_decoder, 1) - h_V = h_V.repeat(B_decoder, 1, 1) - h_E = h_E.repeat(B_decoder, 1, 1, 1) - E_idx = E_idx.repeat(B_decoder, 1, 1) - mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1) - mask = mask.repeat(B_decoder, 1) + h_S = self.W_s(S_true) + h_ES = cat_neighbors_nodes(h_S, h_E, E_idx) - h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx) - h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) + # Build encoder embeddings + h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) + h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) - h_EXV_encoder_fw = mask_fw * h_EXV_encoder - for layer in self.decoder_layers: - h_V = layer(h_V, h_EXV_encoder_fw, mask) + h_EXV_encoder_fw = mask_fw * h_EXV_encoder + for layer in self.decoder_layers: + # Masked positions attend to encoder information, unmasked see. 
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) + h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw + h_V = layer(h_V, h_ESV, mask) + + logits = self.W_out(h_V) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + log_probs_out[:,idx,:] = log_probs[:,idx,:] + logits_out[:,idx,:] = logits[:,idx,:] + decoding_order_out[:,idx,:] = decoding_order - logits = self.W_out(h_V) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - output_dict = {"log_probs": log_probs} + output_dict = { + "S": S_true, + "log_probs": log_probs_out, + "logits": logits_out, + "decoding_order": decoding_order_out, + } return output_dict - def score(self, feature_dict): - # check if score matches - sample log probs - # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed - # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords - # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords - # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type - # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask - # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O + def score(self, feature_dict, use_sequence: bool): B_decoder = feature_dict["batch_size"] S_true = feature_dict[ "S" - ] # [B,L] - integer proitein sequence encoded using "restype_STRtoINT" - # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index + ] mask = feature_dict[ "mask" - ] # [B,L] - mask for missing regions - should be removed! 
all ones most of the time + ] chain_mask = feature_dict[ "chain_mask" - ] # [B,L] - mask for which residues need to be fixed; 0.0 - fixed; 1.0 - will be designed - # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters + ] randn = feature_dict[ "randn" - ] # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry - + ] + symmetry_list_of_lists = feature_dict[ + "symmetry_residues" + ] B, L = S_true.shape device = S_true.device @@ -543,27 +583,57 @@ def score(self, feature_dict): decoding_order = torch.argsort( (chain_mask + 0.0001) * (torch.abs(randn)) ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0] + if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1: + E_idx = E_idx.repeat(B_decoder, 1, 1) + permutation_matrix_reverse = torch.nn.functional.one_hot( + decoding_order, num_classes=L + ).float() + order_mask_backward = torch.einsum( + "ij, biq, bjp->bqp", + (1 - torch.triu(torch.ones(L, L, device=device))), + permutation_matrix_reverse, + permutation_matrix_reverse, + ) + mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) + mask_1D = mask.view([B, L, 1, 1]) + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1.0 - mask_attend) + else: + new_decoding_order = [] + for t_dec in list(decoding_order[0,].cpu().data.numpy()): + if t_dec not in list(itertools.chain(*new_decoding_order)): + list_a = [item for item in symmetry_list_of_lists if t_dec in item] + if list_a: + new_decoding_order.append(list_a[0]) + else: + new_decoding_order.append([t_dec]) - permutation_matrix_reverse = torch.nn.functional.one_hot( - decoding_order, num_classes=L - ).float() - order_mask_backward = torch.einsum( - "ij, biq, bjp->bqp", - (1 - torch.triu(torch.ones(L, L, device=device))), - permutation_matrix_reverse, - permutation_matrix_reverse, - ) - mask_attend = 
torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) - mask_1D = mask.view([B, L, 1, 1]) - mask_bw = mask_1D * mask_attend - mask_fw = mask_1D * (1.0 - mask_attend) + decoding_order = torch.tensor( + list(itertools.chain(*new_decoding_order)), device=device + )[None,].repeat(B, 1) + + permutation_matrix_reverse = torch.nn.functional.one_hot( + decoding_order, num_classes=L + ).float() + order_mask_backward = torch.einsum( + "ij, biq, bjp->bqp", + (1 - torch.triu(torch.ones(L, L, device=device))), + permutation_matrix_reverse, + permutation_matrix_reverse, + ) + mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) + mask_1D = mask.view([B, L, 1, 1]) + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1.0 - mask_attend) + + E_idx = E_idx.repeat(B_decoder, 1, 1) + mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1) + mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1) + decoding_order = decoding_order.repeat(B_decoder, 1) - # repeat for decoding S_true = S_true.repeat(B_decoder, 1) h_V = h_V.repeat(B_decoder, 1, 1) h_E = h_E.repeat(B_decoder, 1, 1, 1) - E_idx = E_idx.repeat(B_decoder, 1, 1) - chain_mask = chain_mask.repeat(B_decoder, 1) mask = mask.repeat(B_decoder, 1) h_S = self.W_s(S_true) @@ -574,11 +644,15 @@ def score(self, feature_dict): h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) h_EXV_encoder_fw = mask_fw * h_EXV_encoder - for layer in self.decoder_layers: - # Masked positions attend to encoder information, unmasked see. - h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) - h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw - h_V = layer(h_V, h_ESV, mask) + if not use_sequence: + for layer in self.decoder_layers: + h_V = layer(h_V, h_EXV_encoder_fw, mask) + else: + for layer in self.decoder_layers: + # Masked positions attend to encoder information, unmasked see. 
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) + h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw + h_V = layer(h_V, h_ESV, mask) logits = self.W_out(h_V) log_probs = torch.nn.functional.log_softmax(logits, dim=-1) @@ -586,7 +660,8 @@ def score(self, feature_dict): output_dict = { "S": S_true, "log_probs": log_probs, - "decoding_order": decoding_order[0], + "logits": logits, + "decoding_order": decoding_order, } return output_dict diff --git a/outputs/autoregressive_score_w_seq/1BC8_1.pt b/outputs/autoregressive_score_w_seq/1BC8_1.pt new file mode 100644 index 0000000..7e77977 Binary files /dev/null and b/outputs/autoregressive_score_w_seq/1BC8_1.pt differ diff --git a/outputs/autoregressive_score_wo_seq/1BC8_1.pt b/outputs/autoregressive_score_wo_seq/1BC8_1.pt new file mode 100644 index 0000000..fc3131c Binary files /dev/null and b/outputs/autoregressive_score_wo_seq/1BC8_1.pt differ diff --git a/outputs/single_aa_score_w_seq/1BC8_1.pt b/outputs/single_aa_score_w_seq/1BC8_1.pt new file mode 100644 index 0000000..e01591e Binary files /dev/null and b/outputs/single_aa_score_w_seq/1BC8_1.pt differ diff --git a/outputs/single_aa_score_wo_seq/1BC8_1.pt b/outputs/single_aa_score_wo_seq/1BC8_1.pt new file mode 100644 index 0000000..6e89da8 Binary files /dev/null and b/outputs/single_aa_score_wo_seq/1BC8_1.pt differ diff --git a/score.py b/score.py new file mode 100644 index 0000000..9ef7449 --- /dev/null +++ b/score.py @@ -0,0 +1,549 @@ +import argparse +import json +import os.path +import random +import sys + +import numpy as np +import torch + +from data_utils import ( + element_dict_rev, + alphabet, + restype_int_to_str, + featurize, + parse_PDB, +) +from model_utils import ProteinMPNN + + +def main(args) -> None: + """ + Inference function + """ + if args.seed: + seed = args.seed + else: + seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0]) + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + device = torch.device("cuda" if 
(torch.cuda.is_available()) else "cpu") + folder_for_outputs = args.out_folder + base_folder = folder_for_outputs + if base_folder[-1] != "/": + base_folder = base_folder + "/" + if not os.path.exists(base_folder): + os.makedirs(base_folder, exist_ok=True) + if args.model_type == "protein_mpnn": + checkpoint_path = args.checkpoint_protein_mpnn + elif args.model_type == "ligand_mpnn": + checkpoint_path = args.checkpoint_ligand_mpnn + elif args.model_type == "per_residue_label_membrane_mpnn": + checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn + elif args.model_type == "global_label_membrane_mpnn": + checkpoint_path = args.checkpoint_global_label_membrane_mpnn + elif args.model_type == "soluble_mpnn": + checkpoint_path = args.checkpoint_soluble_mpnn + else: + print("Choose one of the available models") + sys.exit() + checkpoint = torch.load(checkpoint_path, map_location=device) + if args.model_type == "ligand_mpnn": + atom_context_num = checkpoint["atom_context_num"] + ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context + k_neighbors = checkpoint["num_edges"] + else: + atom_context_num = 1 + ligand_mpnn_use_side_chain_context = 0 + k_neighbors = checkpoint["num_edges"] + + model = ProteinMPNN( + node_features=128, + edge_features=128, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + k_neighbors=k_neighbors, + device=device, + atom_context_num=atom_context_num, + model_type=args.model_type, + ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context, + ) + + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + if args.pdb_path_multi: + with open(args.pdb_path_multi, "r") as fh: + pdb_paths = list(json.load(fh)) + else: + pdb_paths = [args.pdb_path] + + if args.fixed_residues_multi: + with open(args.fixed_residues_multi, "r") as fh: + fixed_residues_multi = json.load(fh) + else: + fixed_residues = [item for item in args.fixed_residues.split()] + 
fixed_residues_multi = {} + for pdb in pdb_paths: + fixed_residues_multi[pdb] = fixed_residues + + if args.redesigned_residues_multi: + with open(args.redesigned_residues_multi, "r") as fh: + redesigned_residues_multi = json.load(fh) + else: + redesigned_residues = [item for item in args.redesigned_residues.split()] + redesigned_residues_multi = {} + for pdb in pdb_paths: + redesigned_residues_multi[pdb] = redesigned_residues + + # loop over PDB paths + for pdb in pdb_paths: + if args.verbose: + print("Designing protein from this path:", pdb) + fixed_residues = fixed_residues_multi[pdb] + redesigned_residues = redesigned_residues_multi[pdb] + protein_dict, backbone, other_atoms, icodes, _ = parse_PDB( + pdb, + device=device, + chains=args.parse_these_chains_only, + parse_all_atoms=args.ligand_mpnn_use_side_chain_context, + parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy + ) + # make chain_letter + residue_idx + insertion_code mapping to integers + R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices + chain_letters_list = list(protein_dict["chain_letters"]) # chain letters + encoded_residues = [] + for i, R_idx_item in enumerate(R_idx_list): + tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i] + encoded_residues.append(tmp) + encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues)))) + encoded_residue_dict_rev = dict( + zip(list(range(len(encoded_residues))), encoded_residues) + ) + + fixed_positions = torch.tensor( + [int(item not in fixed_residues) for item in encoded_residues], + device=device, + ) + redesigned_positions = torch.tensor( + [int(item not in redesigned_residues) for item in encoded_residues], + device=device, + ) + + # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model + if args.transmembrane_buried: + buried_residues = [item for item in args.transmembrane_buried.split()] + buried_positions = torch.tensor( + [int(item in buried_residues) 
for item in encoded_residues], + device=device, + ) + else: + buried_positions = torch.zeros_like(fixed_positions) + + if args.transmembrane_interface: + interface_residues = [item for item in args.transmembrane_interface.split()] + interface_positions = torch.tensor( + [int(item in interface_residues) for item in encoded_residues], + device=device, + ) + else: + interface_positions = torch.zeros_like(fixed_positions) + protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * ( + 1 - interface_positions + ) + 1 * interface_positions * (1 - buried_positions) + + if args.model_type == "global_label_membrane_mpnn": + protein_dict["membrane_per_residue_labels"] = ( + args.global_transmembrane_label + 0 * fixed_positions + ) + if type(args.chains_to_design) == str: + chains_to_design_list = args.chains_to_design.split(",") + else: + chains_to_design_list = protein_dict["chain_letters"] + chain_mask = torch.tensor( + np.array( + [ + item in chains_to_design_list + for item in protein_dict["chain_letters"] + ], + dtype=np.int32, + ), + device=device, + ) + + # create chain_mask to notify which residues are fixed (0) and which need to be designed (1) + if redesigned_residues: + protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions) + elif fixed_residues: + protein_dict["chain_mask"] = chain_mask * fixed_positions + else: + protein_dict["chain_mask"] = chain_mask + + if args.verbose: + PDB_residues_to_be_redesigned = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 1 + ] + PDB_residues_to_be_fixed = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 0 + ] + print("These residues will be redesigned: ", PDB_residues_to_be_redesigned) + print("These residues will be fixed: ", PDB_residues_to_be_fixed) + + # specify which residues are linked + if args.symmetry_residues: + 
symmetry_residues_list_of_lists = [ + x.split(",") for x in args.symmetry_residues.split("|") + ] + remapped_symmetry_residues = [] + for t_list in symmetry_residues_list_of_lists: + tmp_list = [] + for t in t_list: + tmp_list.append(encoded_residue_dict[t]) + remapped_symmetry_residues.append(tmp_list) + else: + remapped_symmetry_residues = [[]] + + if args.homo_oligomer: + if args.verbose: + print("Designing HOMO-OLIGOMER") + chain_letters_set = list(set(chain_letters_list)) + reference_chain = chain_letters_set[0] + lc = len(reference_chain) + residue_indices = [ + item[lc:] for item in encoded_residues if item[:lc] == reference_chain + ] + remapped_symmetry_residues = [] + for res in residue_indices: + tmp_list = [] + tmp_w_list = [] + for chain in chain_letters_set: + name = chain + res + tmp_list.append(encoded_residue_dict[name]) + tmp_w_list.append(1 / len(chain_letters_set)) + remapped_symmetry_residues.append(tmp_list) + + # set other atom bfactors to 0.0 + if other_atoms: + other_bfactors = other_atoms.getBetas() + other_atoms.setBetas(other_bfactors * 0.0) + + # adjust input PDB name by dropping .pdb if it does exist + name = pdb[pdb.rfind("/") + 1 :] + if name[-4:] == ".pdb": + name = name[:-4] + + with torch.no_grad(): + # run featurize to remap R_idx and add batch dimension + if args.verbose: + if "Y" in list(protein_dict): + atom_coords = protein_dict["Y"].cpu().numpy() + atom_types = list(protein_dict["Y_t"].cpu().numpy()) + atom_mask = list(protein_dict["Y_m"].cpu().numpy()) + number_of_atoms_parsed = np.sum(atom_mask) + else: + print("No ligand atoms parsed") + number_of_atoms_parsed = 0 + atom_types = "" + atom_coords = [] + if number_of_atoms_parsed == 0: + print("No ligand atoms parsed") + elif args.model_type == "ligand_mpnn": + print( + f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}" + ) + for i, atom_type in enumerate(atom_types): + print( + f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask 
{atom_mask[i]}" + ) + feature_dict = featurize( + protein_dict, + cutoff_for_score=args.ligand_mpnn_cutoff_for_score, + use_atom_context=args.ligand_mpnn_use_atom_context, + number_of_ligand_atoms=atom_context_num, + model_type=args.model_type, + ) + feature_dict["batch_size"] = args.batch_size + B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now. + # add additional keys to the feature dictionary + feature_dict["symmetry_residues"] = remapped_symmetry_residues + + logits_list = [] + probs_list = [] + log_probs_list = [] + decoding_order_list = [] + for _ in range(args.number_of_batches): + feature_dict["randn"] = torch.randn( + [feature_dict["batch_size"], feature_dict["mask"].shape[1]], + device=device, + ) + if args.autoregressive_score: + score_dict = model.score(feature_dict, use_sequence=args.use_sequence) + elif args.single_aa_score: + score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence) + else: + print("Set either autoregressive_score or single_aa_score to True") + sys.exit() + logits_list.append(score_dict["logits"]) + log_probs_list.append(score_dict["log_probs"]) + probs_list.append(torch.exp(score_dict["log_probs"])) + decoding_order_list.append(score_dict["decoding_order"]) + log_probs_stack = torch.cat(log_probs_list, 0) + logits_stack = torch.cat(logits_list, 0) + probs_stack = torch.cat(probs_list, 0) + decoding_order_stack = torch.cat(decoding_order_list, 0) + + output_stats_path = base_folder + name + args.file_ending + ".pt" + out_dict = {} + out_dict["logits"] = logits_stack.cpu().numpy() + out_dict["probs"] = probs_stack.cpu().numpy() + out_dict["log_probs"] = log_probs_stack.cpu().numpy() + out_dict["decoding_order"] = decoding_order_stack.cpu().numpy() + out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy() + out_dict["mask"] = feature_dict["mask"][0].cpu().numpy() + out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy() #this affects decoding order + out_dict["seed"] = 
seed + out_dict["alphabet"] = alphabet + out_dict["residue_names"] = encoded_residue_dict_rev + + mean_probs = np.mean(out_dict["probs"], 0) + std_probs = np.std(out_dict["probs"], 0) + sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]] + mean_dict = {} + std_dict = {} + for residue in range(L): + mean_dict_ = dict(zip(alphabet, mean_probs[residue])) + mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_ + std_dict_ = dict(zip(alphabet, std_probs[residue])) + std_dict[encoded_residue_dict_rev[residue]] = std_dict_ + + out_dict["sequence"] = sequence + out_dict["mean_of_probs"] = mean_dict + out_dict["std_of_probs"] = std_dict + torch.save(out_dict, output_stats_path) + + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + argparser.add_argument( + "--model_type", + type=str, + default="protein_mpnn", + help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn", + ) + # protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms + # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB + # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed + # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane + # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids + argparser.add_argument( + "--checkpoint_protein_mpnn", + type=str, + default="./model_params/proteinmpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_ligand_mpnn", + type=str, + default="./model_params/ligandmpnn_v_32_010_25.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_per_residue_label_membrane_mpnn", + 
type=str, + default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_global_label_membrane_mpnn", + type=str, + default="./model_params/global_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_soluble_mpnn", + type=str, + default="./model_params/solublempnn_v_48_020.pt", + help="Path to model weights.", + ) + + argparser.add_argument("--verbose", type=int, default=1, help="Print stuff") + + argparser.add_argument( + "--pdb_path", type=str, default="", help="Path to the input PDB." + ) + argparser.add_argument( + "--pdb_path_multi", + type=str, + default="", + help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.", + ) + + argparser.add_argument( + "--fixed_residues", + type=str, + default="", + help="Provide fixed residues, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--fixed_residues_multi", + type=str, + default="", + help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--redesigned_residues", + type=str, + default="", + help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--redesigned_residues_multi", + type=str, + default="", + help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--symmetry_residues", + type=str, + default="", + help="Add list of lists for which residues need to be symmetric, e.g. 
'A12,A13,A14|C2,C3|A5,B6'", + ) + + argparser.add_argument( + "--homo_oligomer", + type=int, + default=0, + help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.", + ) + + argparser.add_argument( + "--out_folder", + type=str, + help="Path to a folder to output scores, e.g. /home/out/", + ) + argparser.add_argument( + "--file_ending", type=str, default="", help="adding_string_to_the_end" + ) + argparser.add_argument( + "--zero_indexed", + type=str, + default=0, + help="1 - to start output PDB numbering with 0", + ) + argparser.add_argument( + "--seed", + type=int, + default=0, + help="Set seed for torch, numpy, and python random.", + ) + argparser.add_argument( + "--batch_size", + type=int, + default=1, + help="Number of sequence to generate per one pass.", + ) + argparser.add_argument( + "--number_of_batches", + type=int, + default=1, + help="Number of times to design sequence using a chosen batch size.", + ) + + argparser.add_argument( + "--ligand_mpnn_use_atom_context", + type=int, + default=1, + help="1 - use atom context, 0 - do not use atom context.", + ) + + argparser.add_argument( + "--ligand_mpnn_use_side_chain_context", + type=int, + default=0, + help="Flag to use side chain atoms as ligand context for the fixed residues", + ) + + argparser.add_argument( + "--ligand_mpnn_cutoff_for_score", + type=float, + default=8.0, + help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.", + ) + + argparser.add_argument( + "--chains_to_design", + type=str, + default=None, + help="Specify which chains to redesign, all others will be kept fixed.", + ) + + argparser.add_argument( + "--parse_these_chains_only", + type=str, + default="", + help="Provide chains letters for parsing backbones, 'ABCF'", + ) + + argparser.add_argument( + "--transmembrane_buried", + type=str, + default="", + help="Provide buried residues when using 
checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--transmembrane_interface", + type=str, + default="", + help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + + argparser.add_argument( + "--global_transmembrane_label", + type=int, + default=0, + help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble", + ) + + argparser.add_argument( + "--parse_atoms_with_zero_occupancy", + type=int, + default=0, + help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy", + ) + + argparser.add_argument( + "--use_sequence", + type=int, + default=1, + help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only", + ) + + argparser.add_argument( + "--autoregressive_score", + type=int, + default=0, + help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False", + ) + + argparser.add_argument( + "--single_aa_score", + type=int, + default=1, + help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False", + ) + + args = argparser.parse_args() + main(args)