From 399dddeee191d4150198d787f8b06d0db07f0ea1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 20:11:32 +0100 Subject: [PATCH 01/11] reformat with black --- docs/notebooks/how-to.ipynb | 24 +- .../protein-protein_interactions.ipynb | 19 +- docs/notebooks/quickstart.ipynb | 51 ++- docs/notebooks/visualisation.ipynb | 188 +++++--- prolif/__init__.py | 12 +- prolif/_version.py | 154 ++++--- prolif/datafiles.py | 2 +- prolif/fingerprint.py | 128 ++++-- prolif/interactions.py | 205 +++++---- prolif/molecule.py | 26 +- prolif/parallel.py | 10 +- prolif/plotting/network.py | 426 ++++++++++-------- prolif/rdkitmol.py | 2 + prolif/residue.py | 30 +- prolif/utils.py | 63 +-- pyproject.toml | 11 + setup.py | 2 +- tests/mol2factory.py | 18 + tests/notebooks/viz-pi-stacking.ipynb | 59 ++- tests/plotting/test_network.py | 7 +- tests/test_fingerprint.py | 94 ++-- tests/test_interactions.py | 336 +++++++------- tests/test_molecule.py | 29 +- tests/test_residues.py | 149 +++--- tests/test_utils.py | 166 +++++-- tests/test_wrapper_docs.py | 4 +- 26 files changed, 1316 insertions(+), 899 deletions(-) create mode 100644 pyproject.toml diff --git a/docs/notebooks/how-to.ipynb b/docs/notebooks/how-to.ipynb index fa0d52b..f1ce20b 100644 --- a/docs/notebooks/how-to.ipynb +++ b/docs/notebooks/how-to.ipynb @@ -139,6 +139,7 @@ "class Hydrophobic(plf.interactions.Hydrophobic):\n", " pass\n", "\n", + "\n", "fp = plf.Fingerprint([\"Hydrophobic\"])\n", "fp.hydrophobic(lmol, pmol[\"TYR109.A\"])" ] @@ -173,7 +174,8 @@ "class CustomHydrophobic(plf.interactions.Hydrophobic):\n", " def __init__(self):\n", " super().__init__(distance=4.0)\n", - " \n", + "\n", + "\n", "fp = plf.Fingerprint([\"Hydrophobic\", \"CustomHydrophobic\"])\n", "fp.hydrophobic(lmol, pmol[\"TYR109.A\"])" ] @@ -247,6 +249,7 @@ "source": [ "from scipy.spatial import distance_matrix\n", "\n", + "\n", "class CloseContact(plf.interactions.Interaction):\n", " def __init__(self, threshold=2.0):\n", " self.threshold = threshold\n", @@ -260,6 +263,7 @@ " return True, res1_i[0], res2_i[0]\n", " return False, None, None\n", "\n", + "\n", "fp = plf.Fingerprint([\"CloseContact\"])\n", "fp.closecontact(lmol, pmol[\"ASP129.A\"])" ] @@ -293,8 +297,7 @@ "source": [ "ifp = fp.generate(lmol, pmol, return_atoms=True)\n", "# check the interactino between the ligand and ASP129\n", - "ifp[(plf.ResidueId(\"LIG\", 1, \"G\"),\n", - " plf.ResidueId(\"ASP\", 129, \"A\"))]" + "ifp[(plf.ResidueId(\"LIG\", 1, \"G\"), plf.ResidueId(\"ASP\", 129, \"A\"))]" ] }, { @@ -424,10 +427,13 @@ "df0.columns = df0.columns.droplevel(0)\n", "df.columns = df.columns.droplevel(0)\n", "# concatenate and sort columns\n", - "df = (pd.concat([df0, df])\n", - " .fillna(False)\n", - " .sort_index(axis=1, level=0,\n", - " key=lambda index: [plf.ResidueId.from_string(x) for x in index]))\n", + "df = (\n", + " pd.concat([df0, df])\n", + " .fillna(False)\n", + " .sort_index(\n", + " axis=1, level=0, key=lambda index: [plf.ResidueId.from_string(x) for x in index]\n", + " )\n", + ")\n", "df" ] }, @@ -593,7 +599,9 @@ "source": [ "from rdkit import Chem\n", "\n", - "template = Chem.MolFromSmiles(\"C[NH+]1CC(C(=O)NC2(C)OC3(O)C4CCCN4C(=O)C(Cc4ccccc4)N3C2=O)C=C2c3cccc4[nH]cc(c34)CC21\")\n", + "template = Chem.MolFromSmiles(\n", + " \"C[NH+]1CC(C(=O)NC2(C)OC3(O)C4CCCN4C(=O)C(Cc4ccccc4)N3C2=O)C=C2c3cccc4[nH]cc(c34)CC21\"\n", + ")\n", "template" ] }, diff --git a/docs/notebooks/protein-protein_interactions.ipynb b/docs/notebooks/protein-protein_interactions.ipynb index 3955384..e4b4177 100644 --- a/docs/notebooks/protein-protein_interactions.ipynb +++ b/docs/notebooks/protein-protein_interactions.ipynb @@ -54,7 +54,17 @@ "outputs": [], "source": [ "# prot-prot interactions\n", - "fp = plf.Fingerprint([\"HBDonor\", \"HBAcceptor\", \"PiStacking\", \"PiCation\", \"CationPi\", \"Anionic\", \"Cationic\"])\n", + "fp = plf.Fingerprint(\n", + " [\n", + " \"HBDonor\",\n", + " \"HBAcceptor\",\n", + " \"PiStacking\",\n", + " \"PiCation\",\n", + " \"CationPi\",\n", + " \"Anionic\",\n", + " \"Cationic\",\n", + " ]\n", + ")\n", "fp.run(u.trajectory[::10], tm3, prot)" ] }, @@ -117,10 +127,7 @@ "outputs": [], "source": [ "# regroup all interactions together and do the same\n", - "g = (df.groupby(level=[\"ligand\", \"protein\"], axis=1)\n", - " .sum()\n", - " .astype(bool)\n", - " .mean())\n", + "g = df.groupby(level=[\"ligand\", \"protein\"], axis=1).sum().astype(bool).mean()\n", "g.loc[g > 0.3]" ] }, @@ -148,12 +155,14 @@ "backbone = Chem.MolFromSmarts(\"[C^2](=O)-[C;X4](-[H])-[N;+0]\")\n", "fix_h = Chem.MolFromSmarts(\"[H&D0]\")\n", "\n", + "\n", "def remove_backbone(atomgroup):\n", " mol = plf.Molecule.from_mda(atomgroup)\n", " mol = AllChem.DeleteSubstructs(mol, backbone)\n", " mol = AllChem.DeleteSubstructs(mol, fix_h)\n", " return plf.Molecule(mol)\n", "\n", + "\n", "# generate IFP\n", "ifp = []\n", "for ts in tqdm(u.trajectory[::10]):\n", diff --git a/docs/notebooks/quickstart.ipynb b/docs/notebooks/quickstart.ipynb index bf1be90..a738aa3 100644 --- a/docs/notebooks/quickstart.ipynb +++ b/docs/notebooks/quickstart.ipynb @@ -19,6 +19,7 @@ "source": [ "import MDAnalysis as mda\n", "import prolif as plf\n", + "\n", "# load trajectory\n", "u = mda.Universe(plf.datafiles.TOP, plf.datafiles.TRAJ)\n", "# create selections for the ligand and protein\n", @@ -48,12 +49,13 @@ "source": [ "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", + "\n", "# create a molecule from the MDAnalysis selection\n", "lmol = plf.Molecule.from_mda(lig)\n", "# cleanup before drawing\n", "mol = Chem.RemoveHs(lmol)\n", "mol.RemoveAllConformers()\n", - "Draw.MolToImage(mol, size=(400,200))" + "Draw.MolToImage(mol, size=(400, 200))" ] }, { @@ -77,11 +79,13 @@ " mol = Chem.RemoveHs(res)\n", " mol.RemoveAllConformers()\n", " frags.append(mol)\n", - "Draw.MolsToGridImage(frags,\n", - " legends=[str(res.resid) for res in pmol], \n", - " subImgSize=(200, 140),\n", - " molsPerRow=4,\n", - " maxMols=prot.n_residues)" + "Draw.MolsToGridImage(\n", + " frags,\n", + " legends=[str(res.resid) for res in pmol],\n", + " subImgSize=(200, 140),\n", + " molsPerRow=4,\n", + " maxMols=prot.n_residues,\n", + ")" ] }, { @@ -209,19 +213,36 @@ "\n", "# reorganize data\n", "data = df.reset_index()\n", - "data = pd.melt(data, id_vars=[\"Frame\"], var_name=[\"residue\",\"interaction\"])\n", + "data = pd.melt(data, id_vars=[\"Frame\"], var_name=[\"residue\", \"interaction\"])\n", "data = data[data[\"value\"] != False]\n", "data.reset_index(inplace=True, drop=True)\n", "\n", "# plot\n", - "sns.set_theme(font_scale=.8, style=\"white\", context=\"talk\")\n", + "sns.set_theme(font_scale=0.8, style=\"white\", context=\"talk\")\n", "g = sns.catplot(\n", - " data=data, x=\"interaction\", y=\"Frame\", hue=\"interaction\", col=\"residue\",\n", - " hue_order=[\"Hydrophobic\", \"HBDonor\", \"HBAcceptor\", \"PiStacking\", \"CationPi\", \"Cationic\"],\n", - " height=3, aspect=0.2, jitter=0, sharex=False, marker=\"_\", s=8, linewidth=3.5,\n", + " data=data,\n", + " x=\"interaction\",\n", + " y=\"Frame\",\n", + " hue=\"interaction\",\n", + " col=\"residue\",\n", + " hue_order=[\n", + " \"Hydrophobic\",\n", + " \"HBDonor\",\n", + " \"HBAcceptor\",\n", + " \"PiStacking\",\n", + " \"CationPi\",\n", + " \"Cationic\",\n", + " ],\n", + " height=3,\n", + " aspect=0.2,\n", + " jitter=0,\n", + " sharex=False,\n", + " marker=\"_\",\n", + " s=8,\n", + " linewidth=3.5,\n", ")\n", "g.set_titles(\"{col_name}\")\n", - "g.set(xticks=[], ylim=(-.5, data.Frame.max()+1))\n", + "g.set(xticks=[], ylim=(-0.5, data.Frame.max() + 1))\n", "g.set_xticklabels([])\n", "g.set_xlabels(\"\")\n", "g.fig.subplots_adjust(wspace=0)\n", @@ -251,10 +272,7 @@ "outputs": [], "source": [ "# regroup all interactions together and do the same\n", - "g = (df.groupby(level=[\"protein\"], axis=1)\n", - " .sum()\n", - " .astype(bool)\n", - " .mean())\n", + "g = df.groupby(level=[\"protein\"], axis=1).sum().astype(bool).mean()\n", "g.loc[g > 0.3]" ] }, @@ -272,6 +290,7 @@ "outputs": [], "source": [ "from rdkit import DataStructs\n", + "\n", "bvs = fp.to_bitvectors()\n", "tanimoto_sims = DataStructs.BulkTanimotoSimilarity(bvs[0], bvs)\n", "tanimoto_sims" diff --git a/docs/notebooks/visualisation.ipynb b/docs/notebooks/visualisation.ipynb index 90c0568..9f9ac21 100644 --- a/docs/notebooks/visualisation.ipynb +++ b/docs/notebooks/visualisation.ipynb @@ -18,6 +18,7 @@ "import MDAnalysis as mda\n", "import prolif as plf\n", "import numpy as np\n", + "\n", "# load topology\n", "u = mda.Universe(plf.datafiles.TOP, plf.datafiles.TRAJ)\n", "lig = u.select_atoms(\"resname LIG\")\n", @@ -68,6 +69,7 @@ "from rdkit import Chem\n", "from rdkit import Geometry\n", "\n", + "\n", "def get_ring_centroid(mol, index):\n", " # find ring using the atom index\n", " Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_SETAROMATICITY)\n", @@ -76,7 +78,9 @@ " if index in r:\n", " break\n", " else:\n", - " raise ValueError(\"No ring containing this atom index was found in the given molecule\")\n", + " raise ValueError(\n", + " \"No ring containing this atom index was found in the given molecule\"\n", + " )\n", " # get centroid\n", " coords = mol.xyz[list(r)]\n", " ctd = plf.utils.get_centroid(coords)\n", @@ -140,8 +144,10 @@ " lres = lmol[lresid]\n", " pres = pmol[presid]\n", " # set model ids for reusing later\n", - " for resid, res, style in [(lresid, lres, {\"colorscheme\": \"cyanCarbon\"}),\n", - " (presid, pres, {})]:\n", + " for resid, res, style in [\n", + " (lresid, lres, {\"colorscheme\": \"cyanCarbon\"}),\n", + " (presid, pres, {}),\n", + " ]:\n", " if resid not in models.keys():\n", " mid += 1\n", " v.addModel(Chem.MolToMolBlock(res), \"sdf\")\n", @@ -158,34 +164,35 @@ " if interaction in [\"PiStacking\", \"EdgeToFace\", \"FaceToFace\", \"CationPi\"]:\n", " p2 = get_ring_centroid(pres, pindex)\n", " else:\n", - " p2 = pres.GetConformer().GetAtomPosition(pindex) \n", + " p2 = pres.GetConformer().GetAtomPosition(pindex)\n", " # add interaction line\n", - " v.addCylinder({\"start\": dict(x=p1.x, y=p1.y, z=p1.z),\n", - " \"end\": dict(x=p2.x, y=p2.y, z=p2.z),\n", - " \"color\": colors[interaction],\n", - " \"radius\": .15,\n", - " \"dashed\": True,\n", - " \"fromCap\": 1,\n", - " \"toCap\": 1,\n", - " })\n", + " v.addCylinder(\n", + " {\n", + " \"start\": dict(x=p1.x, y=p1.y, z=p1.z),\n", + " \"end\": dict(x=p2.x, y=p2.y, z=p2.z),\n", + " \"color\": colors[interaction],\n", + " \"radius\": 0.15,\n", + " \"dashed\": True,\n", + " \"fromCap\": 1,\n", + " \"toCap\": 1,\n", + " }\n", + " )\n", " # add label when hovering the middle of the dashed line by adding a dummy atom\n", " c = Geometry.Point3D(*plf.utils.get_centroid([p1, p2]))\n", " modelID = models[lresid]\n", " model = v.getModel(modelID)\n", - " model.addAtoms([{\"elem\": 'Z',\n", - " \"x\": c.x, \"y\": c.y, \"z\": c.z,\n", - " \"interaction\": interaction}])\n", - " model.setStyle({\"interaction\": interaction}, {\"clicksphere\": {\"radius\": .5}})\n", - " model.setHoverable(\n", - " {\"interaction\": interaction}, True,\n", - " hover_func, unhover_func)\n", + " model.addAtoms(\n", + " [{\"elem\": \"Z\", \"x\": c.x, \"y\": c.y, \"z\": c.z, \"interaction\": interaction}]\n", + " )\n", + " model.setStyle({\"interaction\": interaction}, {\"clicksphere\": {\"radius\": 0.5}})\n", + " model.setHoverable({\"interaction\": interaction}, True, hover_func, unhover_func)\n", "\n", "# show protein\n", "mol = Chem.RemoveAllHs(pmol)\n", "pdb = Chem.MolToPDBBlock(mol, flavor=0x20 | 0x10)\n", "v.addModel(pdb, \"pdb\")\n", "model = v.getModel()\n", - "model.setStyle({}, {\"cartoon\": {\"style\":\"edged\"}})\n", + "model.setStyle({}, {\"cartoon\": {\"style\": \"edged\"}})\n", "\n", "v.zoomTo({\"model\": list(models.values())})" ] @@ -218,10 +225,14 @@ "fp.run(u.trajectory[::10], lig, prot)\n", "df = fp.to_dataframe(return_atoms=True)\n", "\n", - "net = LigNetwork.from_ifp(df, lmol,\n", - " # replace with `kind=\"frame\", frame=0` for the other depiction\n", - " kind=\"aggregate\", threshold=.3,\n", - " rotation=270)\n", + "net = LigNetwork.from_ifp(\n", + " df,\n", + " lmol,\n", + " # replace with `kind=\"frame\", frame=0` for the other depiction\n", + " kind=\"aggregate\",\n", + " threshold=0.3,\n", + " rotation=270,\n", + ")\n", "net.display()" ] }, @@ -278,11 +289,16 @@ "metadata": {}, "outputs": [], "source": [ - "def make_graph(values, df=None,\n", - " node_color=[\"#FFB2AC\", \"#ACD0FF\"], node_shape=\"dot\",\n", - " edge_color=\"#a9a9a9\", width_multiplier=1):\n", + "def make_graph(\n", + " values,\n", + " df=None,\n", + " node_color=[\"#FFB2AC\", \"#ACD0FF\"],\n", + " node_shape=\"dot\",\n", + " edge_color=\"#a9a9a9\",\n", + " width_multiplier=1,\n", + "):\n", " \"\"\"Convert a pandas DataFrame to a NetworkX object\n", - " \n", + "\n", " Parameters\n", " ----------\n", " values : pandas.Series\n", @@ -290,51 +306,65 @@ " each lig-prot residue pair that will be used to set the width and weigth\n", " of each edge. For example:\n", "\n", - " ligand protein \n", + " ligand protein\n", " LIG1.G ALA216.A 0.66\n", " ALA343.B 0.10\n", "\n", " df : pandas.DataFrame\n", " DataFrame obtained from the fp.to_dataframe() method\n", " Used to label each edge with the type of interaction\n", - " \n", + "\n", " node_color : list\n", " Colors for the ligand and protein residues, respectively\n", "\n", " node_shape : str\n", " One of ellipse, circle, database, box, text or image, circularImage,\n", " diamond, dot, star, triangle, triangleDown, square, icon.\n", - " \n", + "\n", " edge_color : str\n", " Color of the edge between nodes\n", - " \n", + "\n", " width_multiplier : int or float\n", " Each edge's width is defined as `width_multiplier * value`\n", " \"\"\"\n", " lig_res = values.index.get_level_values(\"ligand\").unique().tolist()\n", " prot_res = values.index.get_level_values(\"protein\").unique().tolist()\n", - " \n", + "\n", " G = nx.Graph()\n", " # add nodes\n", " # https://pyvis.readthedocs.io/en/latest/documentation.html#pyvis.network.Network.add_node\n", " for res in lig_res:\n", - " G.add_node(res, title=res, shape=node_shape,\n", - " color=node_color[0], dtype=\"ligand\")\n", + " G.add_node(\n", + " res, title=res, shape=node_shape, color=node_color[0], dtype=\"ligand\"\n", + " )\n", " for res in prot_res:\n", - " G.add_node(res, title=res, shape=node_shape,\n", - " color=node_color[1], dtype=\"protein\")\n", + " G.add_node(\n", + " res, title=res, shape=node_shape, color=node_color[1], dtype=\"protein\"\n", + " )\n", "\n", " for resids, value in values.items():\n", - " label = \"{} - {}
{}\".format(*resids, \"
\".join([f\"{k}: {v}\"\n", - " for k, v in (df.xs(resids,\n", - " level=[\"ligand\", \"protein\"],\n", - " axis=1)\n", - " .sum()\n", - " .to_dict()\n", - " .items())]))\n", + " label = \"{} - {}
{}\".format(\n", + " *resids,\n", + " \"
\".join(\n", + " [\n", + " f\"{k}: {v}\"\n", + " for k, v in (\n", + " df.xs(resids, level=[\"ligand\", \"protein\"], axis=1)\n", + " .sum()\n", + " .to_dict()\n", + " .items()\n", + " )\n", + " ]\n", + " ),\n", + " )\n", " # https://pyvis.readthedocs.io/en/latest/documentation.html#pyvis.network.Network.add_edge\n", - " G.add_edge(*resids, title=label, color=edge_color,\n", - " weight=value, width=value*width_multiplier)\n", + " G.add_edge(\n", + " *resids,\n", + " title=label,\n", + " color=edge_color,\n", + " weight=value,\n", + " width=value * width_multiplier,\n", + " )\n", "\n", " return G" ] @@ -354,10 +384,7 @@ "metadata": {}, "outputs": [], "source": [ - "data = (df.groupby(level=[\"ligand\", \"protein\"], axis=1)\n", - " .sum()\n", - " .astype(bool)\n", - " .mean())\n", + "data = df.groupby(level=[\"ligand\", \"protein\"], axis=1).sum().astype(bool).mean()\n", "\n", "G = make_graph(data, df, width_multiplier=3)\n", "\n", @@ -383,11 +410,10 @@ "metadata": {}, "outputs": [], "source": [ - "data = (df.xs(\"Hydrophobic\", level=\"interaction\", axis=1)\n", - " .mean())\n", + "data = df.xs(\"Hydrophobic\", level=\"interaction\", axis=1).mean()\n", "\n", "G = make_graph(data, df, width_multiplier=3)\n", - " \n", + "\n", "# display graph\n", "net = Network(width=600, height=500, notebook=True, heading=\"\")\n", "net.from_nx(G)\n", @@ -424,23 +450,25 @@ "metadata": {}, "outputs": [], "source": [ - "data = (df.groupby(level=[\"ligand\", \"protein\"], axis=1, sort=False)\n", - " .sum()\n", - " .astype(bool)\n", - " .mean())\n", + "data = (\n", + " df.groupby(level=[\"ligand\", \"protein\"], axis=1, sort=False)\n", + " .sum()\n", + " .astype(bool)\n", + " .mean()\n", + ")\n", "\n", "G = make_graph(data, df, width_multiplier=8)\n", "\n", "# color each node based on its degree\n", "max_nbr = len(max(G.adj.values(), key=lambda x: len(x)))\n", - "blues = cm.get_cmap('Blues', max_nbr)\n", - "reds = cm.get_cmap('Reds', max_nbr)\n", + "blues = cm.get_cmap(\"Blues\", max_nbr)\n", + "reds = cm.get_cmap(\"Reds\", max_nbr)\n", "for n, d in G.nodes(data=True):\n", " n_neighbors = len(G.adj[n])\n", " # show TM3 in red and the rest of the protein in blue\n", " palette = reds if d[\"dtype\"] == \"ligand\" else blues\n", - " d[\"color\"] = colors.to_hex( palette(n_neighbors / max_nbr) )\n", - " \n", + " d[\"color\"] = colors.to_hex(palette(n_neighbors / max_nbr))\n", + "\n", "# convert to pyvis network\n", "net = Network(width=640, height=500, notebook=True, heading=\"\")\n", "net.from_nx(G)\n", @@ -464,7 +492,17 @@ "outputs": [], "source": [ "prot = u.select_atoms(\"protein\")\n", - "fp = plf.Fingerprint(['HBDonor', 'HBAcceptor', 'PiStacking', 'Anionic', 'Cationic', 'CationPi', 'PiCation'])\n", + "fp = plf.Fingerprint(\n", + " [\n", + " \"HBDonor\",\n", + " \"HBAcceptor\",\n", + " \"PiStacking\",\n", + " \"Anionic\",\n", + " \"Cationic\",\n", + " \"CationPi\",\n", + " \"PiCation\",\n", + " ]\n", + ")\n", "fp.run(u.trajectory[::10], prot, prot)\n", "df = fp.to_dataframe()\n", "df.head()" @@ -483,13 +521,15 @@ "metadata": {}, "outputs": [], "source": [ - "# remove interactions between residues i and i±4 or less \n", + "# remove interactions between residues i and i±4 or less\n", "mask = []\n", "for l, p, interaction in df.columns:\n", " lr = plf.ResidueId.from_string(l)\n", " pr = plf.ResidueId.from_string(p)\n", - " if (pr == lr) or (abs(pr.number - lr.number) <= 4\n", - " and interaction in [\"HBDonor\", \"HBAcceptor\", \"Hydrophobic\"]):\n", + " if (pr == lr) or (\n", + " abs(pr.number - lr.number) <= 4\n", + " and interaction in [\"HBDonor\", \"HBAcceptor\", \"Hydrophobic\"]\n", + " ):\n", " mask.append(False)\n", " else:\n", " mask.append(True)\n", @@ -503,22 +543,24 @@ "metadata": {}, "outputs": [], "source": [ - "data = (df.groupby(level=[\"ligand\", \"protein\"], axis=1, sort=False)\n", - " .sum()\n", - " .astype(bool)\n", - " .mean())\n", + "data = (\n", + " df.groupby(level=[\"ligand\", \"protein\"], axis=1, sort=False)\n", + " .sum()\n", + " .astype(bool)\n", + " .mean()\n", + ")\n", "\n", "G = make_graph(data, df, width_multiplier=5)\n", "\n", "# color each node based on its degree\n", "max_nbr = len(max(G.adj.values(), key=lambda x: len(x)))\n", - "palette = cm.get_cmap('YlGnBu', max_nbr)\n", + "palette = cm.get_cmap(\"YlGnBu\", max_nbr)\n", "for n, d in G.nodes(data=True):\n", " n_neighbors = len(G.adj[n])\n", - " d[\"color\"] = colors.to_hex( palette(n_neighbors / max_nbr) )\n", - " \n", + " d[\"color\"] = colors.to_hex(palette(n_neighbors / max_nbr))\n", + "\n", "# convert to pyvis network\n", - "net = Network(width=640, height=500, notebook=True, heading=\"\")\n", + "net = Network(width=640, height=500, notebook=True, heading=\"\")\n", "net.from_nx(G)\n", "\n", "# use specific layout\n", diff --git a/prolif/__init__.py b/prolif/__init__.py index d23658a..6b650cd 100644 --- a/prolif/__init__.py +++ b/prolif/__init__.py @@ -1,13 +1,9 @@ -from .molecule import (Molecule, - pdbqt_supplier, - mol2_supplier, - sdf_supplier) +from .molecule import Molecule, pdbqt_supplier, mol2_supplier, sdf_supplier from .residue import ResidueId from .fingerprint import Fingerprint -from .utils import (get_residues_near_ligand, - to_dataframe, - to_bitvectors) +from .utils import get_residues_near_ligand, to_dataframe, to_bitvectors from . import datafiles from ._version import get_versions -__version__ = get_versions()['version'] + +__version__ = get_versions()["version"] del get_versions diff --git a/prolif/_version.py b/prolif/_version.py index d1368ea..309a170 100644 --- a/prolif/_version.py +++ b/prolif/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -58,17 +57,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -114,16 +117,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -183,7 +192,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -192,7 +201,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -200,19 +209,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -227,8 +243,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -236,10 +251,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -262,17 +286,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -281,10 +304,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -295,13 +320,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() # Use only the last line. Previous lines may contain GPG signature # information. date = date.splitlines()[-1] @@ -335,8 +360,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -450,11 +474,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -474,9 +500,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -490,8 +520,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -500,13 +529,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for i in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -520,6 +552,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/prolif/datafiles.py b/prolif/datafiles.py index 2afeee1..22c7d60 100644 --- a/prolif/datafiles.py +++ b/prolif/datafiles.py @@ -3,4 +3,4 @@ datapath = Path(__file__).parent / "data" TOP = str(datapath / "top.pdb") -TRAJ = str(datapath / "traj.xtc") \ No newline at end of file +TRAJ = str(datapath / "traj.xtc") diff --git a/prolif/fingerprint.py b/prolif/fingerprint.py index d656674..0e083fa 100644 --- a/prolif/fingerprint.py +++ b/prolif/fingerprint.py @@ -37,15 +37,21 @@ from .interactions import _INTERACTIONS from .molecule import Molecule -from .parallel import (Progress, ProgressCounter, - declare_shared_objs_for_chunk, - declare_shared_objs_for_mol, process_chunk, process_mol) +from .parallel import ( + Progress, + ProgressCounter, + declare_shared_objs_for_chunk, + declare_shared_objs_for_mol, + process_chunk, + process_mol, +) from .utils import get_residues_near_ligand, to_bitvectors, to_dataframe class _Docstring: """Descriptor that replaces the documentation shown when calling ``fp.hydrophobic?`` and other interaction methods""" + def __init__(self): self._docs = {} @@ -100,6 +106,7 @@ class _InteractionWrapper: Changed from a wrapper function to a class for easier pickling support """ + __doc__ = _Docstring() _current_func = "" @@ -199,8 +206,19 @@ class Fingerprint: """ - def __init__(self, interactions=["Hydrophobic", "HBDonor", "HBAcceptor", - "PiStacking", "Anionic", "Cationic", "CationPi", "PiCation"]): + def __init__( + self, + interactions=[ + "Hydrophobic", + "HBDonor", + "HBAcceptor", + "PiStacking", + "Anionic", + "Cationic", + "CationPi", + "PiCation", + ], + ): self._set_interactions(interactions) def _set_interactions(self, interactions): @@ -248,9 +266,11 @@ def list_available(show_hidden=False): if show_hidden: interactions = [name for name in _INTERACTIONS.keys()] else: - interactions = [name for name in _INTERACTIONS.keys() - if not (name.startswith("_") - or name == "Interaction")] + interactions = [ + name + for name in _INTERACTIONS.keys() + if not (name.startswith("_") or name == "Interaction") + ] return sorted(interactions) @property @@ -379,7 +399,16 @@ def generate(self, lig, prot, residues=None, return_atoms=False): ifp[key] = self.bitvector(lres, pres) return ifp - def run(self, traj, lig, prot, residues=None, converter_kwargs=None, progress=True, n_jobs=None): + def run( + self, + traj, + lig, + prot, + residues=None, + converter_kwargs=None, + progress=True, + n_jobs=None, + ): """Generates the fingerprint on a trajectory for a ligand and a protein Parameters @@ -441,7 +470,7 @@ def run(self, traj, lig, prot, residues=None, converter_kwargs=None, progress=Tr .. versionchanged:: 1.0.0 Added support for multiprocessing - + .. versionchanged:: 1.1.0 Added support for passing kwargs to the RDKitConverter through the ``converter_kwargs`` parameter @@ -452,9 +481,15 @@ def run(self, traj, lig, prot, residues=None, converter_kwargs=None, progress=Tr if converter_kwargs is not None and len(converter_kwargs) != 2: raise ValueError("converter_kwargs must be a list of 2 dicts") if n_jobs != 1: - return self._run_parallel(traj, lig, prot, residues=residues, - converter_kwargs=converter_kwargs, - progress=progress, n_jobs=n_jobs) + return self._run_parallel( + traj, + lig, + prot, + residues=residues, + converter_kwargs=converter_kwargs, + progress=progress, + n_jobs=n_jobs, + ) lig_kwargs, prot_kwargs = converter_kwargs or ({}, {}) iterator = tqdm(traj) if progress else traj @@ -464,15 +499,24 @@ def run(self, traj, lig, prot, residues=None, converter_kwargs=None, progress=Tr for ts in iterator: prot_mol = Molecule.from_mda(prot, **prot_kwargs) lig_mol = Molecule.from_mda(lig, **lig_kwargs) - data = self.generate(lig_mol, prot_mol, residues=residues, - return_atoms=True) + data = self.generate( + lig_mol, prot_mol, residues=residues, return_atoms=True + ) data["Frame"] = ts.frame ifp.append(data) self.ifp = ifp return self - def _run_parallel(self, traj, lig, prot, residues=None, converter_kwargs=None, - progress=True, n_jobs=None): + def _run_parallel( + self, + traj, + lig, + prot, + residues=None, + converter_kwargs=None, + progress=True, + n_jobs=None, + ): """Parallel implementation of :meth:`~Fingerprint.run`""" n_chunks = n_jobs if n_jobs else mp.cpu_count() try: @@ -497,8 +541,11 @@ def _run_parallel(self, traj, lig, prot, residues=None, converter_kwargs=None, pbar_thread = Thread(target=pbar, daemon=True) # run pool of workers - with mp.Pool(n_jobs, initializer=declare_shared_objs_for_chunk, - initargs=(self, residues, progress, pcount, (lig_kwargs, prot_kwargs))) as pool: + with mp.Pool( + n_jobs, + initializer=declare_shared_objs_for_chunk, + initargs=(self, residues, progress, pcount, (lig_kwargs, prot_kwargs)), + ) as pool: pbar_thread.start() args = ((traj, lig, prot, chunk) for chunk in chunks) results = [] @@ -508,8 +555,9 @@ def _run_parallel(self, traj, lig, prot, residues=None, converter_kwargs=None, self.ifp = results return self - def run_from_iterable(self, lig_iterable, prot_mol, residues=None, - progress=True, n_jobs=None): + def run_from_iterable( + self, lig_iterable, prot_mol, residues=None, progress=True, n_jobs=None + ): """Generates the fingerprint between a list of ligands and a protein Parameters @@ -571,28 +619,32 @@ def run_from_iterable(self, lig_iterable, prot_mol, residues=None, raise ValueError("n_jobs must be > 0 or None") if n_jobs != 1: return self._run_iter_parallel( - lig_iterable=lig_iterable, prot_mol=prot_mol, - residues=residues, progress=progress, n_jobs=n_jobs) + lig_iterable=lig_iterable, + prot_mol=prot_mol, + residues=residues, + progress=progress, + n_jobs=n_jobs, + ) iterator = tqdm(lig_iterable) if progress else lig_iterable if residues == "all": residues = prot_mol.residues.keys() ifp = [] for i, lig_mol in enumerate(iterator): - data = self.generate(lig_mol, prot_mol, residues=residues, - return_atoms=True) + data = self.generate( + lig_mol, prot_mol, residues=residues, return_atoms=True + ) data["Frame"] = i ifp.append(data) self.ifp = ifp return self - def _run_iter_parallel(self, lig_iterable, prot_mol, residues=None, - progress=True, n_jobs=None): + def _run_iter_parallel( + self, lig_iterable, prot_mol, residues=None, progress=True, n_jobs=None + ): """Parallel implementation of :meth:`~Fingerprint.run_from_iterable`""" - if ( - isinstance(lig_iterable, Chem.SDMolSupplier) - or (isinstance(lig_iterable, Iterable) - and not isgenerator(lig_iterable)) + if isinstance(lig_iterable, Chem.SDMolSupplier) or ( + isinstance(lig_iterable, Iterable) and not isgenerator(lig_iterable) ): total = len(lig_iterable) else: @@ -601,11 +653,17 @@ def _run_iter_parallel(self, lig_iterable, prot_mol, residues=None, if residues == "all": residues = prot_mol.residues.keys() - with mp.Pool(n_jobs, initializer=declare_shared_objs_for_mol, - initargs=(self, prot_mol, residues)) as pool: + with mp.Pool( + n_jobs, + initializer=declare_shared_objs_for_mol, + initargs=(self, prot_mol, residues), + ) as pool: results = [] - for data in tqdm(pool.imap_unordered(process_mol, suppl), - total=total, disable=not progress): + for data in tqdm( + pool.imap_unordered(process_mol, suppl), + total=total, + disable=not progress, + ): results.append(data) results.sort(key=lambda ifp: ifp["Frame"]) self.ifp = results diff --git a/prolif/interactions.py b/prolif/interactions.py index d6c002f..9eba6fe 100644 --- a/prolif/interactions.py +++ b/prolif/interactions.py @@ -57,11 +57,14 @@ def detect(self, res1, res2, threshold=2.0): class _InteractionMeta(ABCMeta): """Metaclass to register interactions automatically""" + def __init__(cls, name, bases, classdict): type.__init__(cls, name, bases, classdict) if name in _INTERACTIONS.keys(): - warnings.warn(f"The {name!r} interaction has been superseded by a " - f"new class with id {id(cls):#x}") + warnings.warn( + f"The {name!r} interaction has been superseded by a " + f"new class with id {id(cls):#x}" + ) _INTERACTIONS[name] = cls @@ -71,6 +74,7 @@ class Interaction(ABC, metaclass=_InteractionMeta): All interaction classes must inherit this class and define a :meth:`~detect` method """ + @abstractmethod def detect(self, **kwargs): pass @@ -106,6 +110,7 @@ class _Distance(Interaction): distance : float Cutoff distance, measured between the first atom of each pattern """ + def __init__(self, lig_pattern, prot_pattern, distance): self.lig_pattern = MolFromSmarts(lig_pattern) self.prot_pattern = MolFromSmarts(prot_pattern) @@ -115,8 +120,7 @@ def detect(self, lig_res, prot_res): lig_matches = lig_res.GetSubstructMatches(self.lig_pattern) prot_matches = prot_res.GetSubstructMatches(self.prot_pattern) if lig_matches and prot_matches: - for lig_match, prot_match in product(lig_matches, - prot_matches): + for lig_match, prot_match in product(lig_matches, prot_matches): alig = Geometry.Point3D(*lig_res.xyz[lig_match[0]]) aprot = Geometry.Point3D(*prot_res.xyz[prot_match[0]]) if alig.Distance(aprot) <= self.distance: @@ -139,13 +143,13 @@ class Hydrophobic(_Distance): The initial SMARTS pattern was too broad. """ + def __init__( self, hydrophobic=( - "[c,s,Br,I,S&H0&v2," - "$([D3,D4;#6])&!$([#6]~[#7,#8,#9])&!$([#6X4H0]);+0]" + "[c,s,Br,I,S&H0&v2," "$([D3,D4;#6])&!$([#6]~[#7,#8,#9])&!$([#6X4H0]);+0]" ), - distance=4.5 + distance=4.5, ): super().__init__(hydrophobic, hydrophobic, distance) @@ -169,8 +173,9 @@ class _BaseHBond(Interaction): The initial SMARTS pattern was too broad. """ + def __init__( - self, + self, donor="[$([O,S;+0]),$([N;v3,v4&+1]),n+0]-[H]", acceptor=( "[#7&!$([nX3])&!$([NX3]-*=[O,N,P,S])&!$([NX3]-[a])&!$([Nv4&+1])," @@ -178,7 +183,7 @@ def __init__( "F&$(F-[#6])&!$(F-[#6][F,Cl,Br,I])]" ), distance=3.5, - angles=(130, 180) + angles=(130, 180), ): self.donor = MolFromSmarts(donor) self.acceptor = MolFromSmarts(acceptor) @@ -189,8 +194,7 @@ def detect(self, acceptor, donor): acceptor_matches = acceptor.GetSubstructMatches(self.acceptor) donor_matches = donor.GetSubstructMatches(self.donor) if acceptor_matches and donor_matches: - for donor_match, acceptor_match in product(donor_matches, - acceptor_matches): + for donor_match, acceptor_match in product(donor_matches, acceptor_matches): # D-H ... A d = Geometry.Point3D(*donor.xyz[donor_match[0]]) h = Geometry.Point3D(*donor.xyz[donor_match[1]]) @@ -207,6 +211,7 @@ def detect(self, acceptor, donor): class HBDonor(_BaseHBond): """Hbond interaction between a ligand (donor) and a residue (acceptor)""" + def detect(self, ligand, residue): bit, ires, ilig = super().detect(residue, ligand) return bit, ilig, ires @@ -214,6 +219,7 @@ def detect(self, ligand, residue): class HBAcceptor(_BaseHBond): """Hbond interaction between a ligand (acceptor) and a residue (donor)""" + def detect(self, ligand, residue): return super().detect(ligand, residue) @@ -238,12 +244,15 @@ class _BaseXBond(Interaction): ----- Distance and angle adapted from Auffinger et al. PNAS 2004 """ - def __init__(self, - donor="[#6,#7,Si,F,Cl,Br,I]-[Cl,Br,I,At]", - acceptor="[#7,#8,P,S,Se,Te,a;!+{1-}][*]", - distance=3.5, - axd_angles=(130, 180), - xar_angles=(80, 140)): + + def __init__( + self, + donor="[#6,#7,Si,F,Cl,Br,I]-[Cl,Br,I,At]", + acceptor="[#7,#8,P,S,Se,Te,a;!+{1-}][*]", + distance=3.5, + axd_angles=(130, 180), + xar_angles=(80, 140), + ): self.donor = MolFromSmarts(donor) self.acceptor = MolFromSmarts(acceptor) self.distance = distance @@ -254,8 +263,7 @@ def detect(self, acceptor, donor): acceptor_matches = acceptor.GetSubstructMatches(self.acceptor) donor_matches = donor.GetSubstructMatches(self.donor) if acceptor_matches and donor_matches: - for donor_match, acceptor_match in product(donor_matches, - acceptor_matches): + for donor_match, acceptor_match in product(donor_matches, acceptor_matches): # D-X ... A distance d = Geometry.Point3D(*donor.xyz[donor_match[0]]) x = Geometry.Point3D(*donor.xyz[donor_match[1]]) @@ -278,12 +286,14 @@ def detect(self, acceptor, donor): class XBAcceptor(_BaseXBond): """Halogen bonding between a ligand (acceptor) and a residue (donor)""" + def detect(self, ligand, residue): return super().detect(ligand, residue) class XBDonor(_BaseXBond): """Halogen bonding between a ligand (donor) and a residue (acceptor)""" + def detect(self, ligand, residue): bit, ires, ilig = super().detect(residue, ligand) return bit, ilig, ires @@ -296,21 +306,26 @@ class _BaseIonic(_Distance): Handles resonance forms for common acids, amidine and guanidine. """ - def __init__(self, - cation="[+{1-},$([NX3&!$([NX3]-O)]-[C]=[NX3+])]", - anion="[-{1-},$(O=[C,S,P]-[O-])]", - distance=4.5): + + def __init__( + self, + cation="[+{1-},$([NX3&!$([NX3]-O)]-[C]=[NX3+])]", + anion="[-{1-},$(O=[C,S,P]-[O-])]", + distance=4.5, + ): super().__init__(cation, anion, distance) class Cationic(_BaseIonic): """Ionic interaction between a ligand (cation) and a residue (anion)""" + def detect(self, ligand, residue): return super().detect(ligand, residue) class Anionic(_BaseIonic): """Ionic interaction between a ligand (anion) and a residue (cation)""" + def detect(self, ligand, residue): bit, ires, ilig = super().detect(residue, ligand) return bit, ilig, ires @@ -336,11 +351,17 @@ class _BaseCationPi(Interaction): Handles resonance forms for amidine and guanidine as cations. """ - def __init__(self, - cation="[+{1-},$([NX3&!$([NX3]-O)]-[C]=[NX3+])]", - pi_ring=("[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1"), - distance=4.5, - angles=(0, 30)): + + def __init__( + self, + cation="[+{1-},$([NX3&!$([NX3]-O)]-[C]=[NX3+])]", + pi_ring=( + "[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", + "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1", + ), + distance=4.5, + angles=(0, 30), + ): self.cation = MolFromSmarts(cation) self.pi_ring = [MolFromSmarts(s) for s in pi_ring] self.distance = distance @@ -376,6 +397,7 @@ def detect(self, cation, pi): class PiCation(_BaseCationPi): """Cation-Pi interaction between a ligand (aromatic ring) and a residue (cation)""" + def detect(self, ligand, residue): bit, ires, ilig = super().detect(residue, ligand) return bit, ilig, ires @@ -384,13 +406,14 @@ def detect(self, ligand, residue): class CationPi(_BaseCationPi): """Cation-Pi interaction between a ligand (cation) and a residue (aromatic ring)""" + def detect(self, ligand, residue): return super().detect(ligand, residue) class _BasePiStacking(Interaction): """Base class for Pi-Stacking interactions - + Parameters ---------- centroid_distance : float @@ -410,15 +433,23 @@ class _BasePiStacking(Interaction): the ``shortest_distance`` parameter. """ - def __init__(self, - centroid_distance=5.5, - plane_angle=(0, 35), - normal_to_centroid_angle=(0, 30), - pi_ring=("[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1")): + + def __init__( + self, + centroid_distance=5.5, + plane_angle=(0, 35), + normal_to_centroid_angle=(0, 30), + pi_ring=( + "[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", + "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1", + ), + ): self.pi_ring = [MolFromSmarts(s) for s in pi_ring] self.centroid_distance = centroid_distance - self.plane_angle = tuple(radians(i) for i in plane_angle) - self.normal_to_centroid_angle = tuple(radians(i) for i in normal_to_centroid_angle) + self.plane_angle = tuple(radians(i) for i in plane_angle) + self.normal_to_centroid_angle = tuple( + radians(i) for i in normal_to_centroid_angle + ) self.edge = False self.ring_radius = 1.7 @@ -437,15 +468,11 @@ def detect(self, ligand, residue): if cdist > self.centroid_distance: continue # ligand - lig_normal = get_ring_normal_vector(lig_centroid, - lig_pi_coords) + lig_normal = get_ring_normal_vector(lig_centroid, lig_pi_coords) # residue - res_normal = get_ring_normal_vector(res_centroid, - res_pi_coords) + res_normal = get_ring_normal_vector(res_centroid, res_pi_coords) plane_angle = lig_normal.AngleTo(res_normal) - if not angle_between_limits( - plane_angle, *self.plane_angle, ring=True - ): + if not angle_between_limits(plane_angle, *self.plane_angle, ring=True): continue c1c2 = lig_centroid.DirectionVector(res_centroid) c2c1 = res_centroid.DirectionVector(lig_centroid) @@ -454,7 +481,8 @@ def detect(self, ligand, residue): if not ( angle_between_limits( n1c1c2, *self.normal_to_centroid_angle, ring=True - ) or angle_between_limits( + ) + or angle_between_limits( n2c2c1, *self.normal_to_centroid_angle, ring=True ) ): @@ -469,7 +497,7 @@ def detect(self, ligand, residue): # check if intersection point falls ~within plane ring intersect_dist = min( lig_centroid.Distance(intersect), - res_centroid.Distance(intersect) + res_centroid.Distance(intersect), ) if intersect_dist > self.ring_radius: continue @@ -478,21 +506,22 @@ def detect(self, ligand, residue): @staticmethod def _get_intersect_point( - plane_normal, plane_centroid, tilted_normal, tilted_centroid, + plane_normal, + plane_centroid, + tilted_normal, + tilted_centroid, ): - # intersect line is orthogonal to both planes normal vectors + # intersect line is orthogonal to both planes normal vectors intersect_direction = plane_normal.CrossProduct(tilted_normal) # setup system of linear equations to solve - A = np.array([list(plane_normal), list(tilted_normal), list(intersect_direction)]) + A = np.array( + [list(plane_normal), list(tilted_normal), list(intersect_direction)] + ) if np.linalg.det(A) == 0: return None - tilted_offset = tilted_normal.DotProduct( - Geometry.Point3D(*tilted_centroid) - ) - plane_offset = plane_normal.DotProduct( - Geometry.Point3D(*plane_centroid) - ) - d = np.array([[plane_offset], [tilted_offset], [0.]]) + tilted_offset = tilted_normal.DotProduct(Geometry.Point3D(*tilted_centroid)) + plane_offset = plane_normal.DotProduct(Geometry.Point3D(*plane_centroid)) + d = np.array([[plane_offset], [tilted_offset], [0.0]]) # point on intersect line point = np.linalg.solve(A, d).T[0] point = Geometry.Point3D(*point) @@ -505,43 +534,55 @@ def _get_intersect_point( class FaceToFace(_BasePiStacking): """Face-to-face Pi-Stacking interaction between a ligand and a residue""" - def __init__(self, - centroid_distance=5.5, - plane_angle=(0, 35), - normal_to_centroid_angle=(0, 33), - pi_ring=("[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1")): + + def __init__( + self, + centroid_distance=5.5, + plane_angle=(0, 35), + normal_to_centroid_angle=(0, 33), + pi_ring=( + "[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", + "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1", + ), + ): super().__init__( centroid_distance=centroid_distance, plane_angle=plane_angle, normal_to_centroid_angle=normal_to_centroid_angle, - pi_ring=pi_ring + pi_ring=pi_ring, ) class EdgeToFace(_BasePiStacking): """Edge-to-face Pi-Stacking interaction between a ligand and a residue - + .. versionchanged:: 1.1.0 In addition to the changes made to the base pi-stacking interaction, this implementation makes sure that the intersection between the perpendicular ring's plane and the other's plane falls inside the ring. """ - def __init__(self, - centroid_distance=6.5, - plane_angle=(50, 90), - normal_to_centroid_angle=(0, 30), - ring_radius=1.5, - pi_ring=("[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1")): + + def __init__( + self, + centroid_distance=6.5, + plane_angle=(50, 90), + normal_to_centroid_angle=(0, 30), + ring_radius=1.5, + pi_ring=( + "[a;r6]1:[a;r6]:[a;r6]:[a;r6]:[a;r6]:[a;r6]:1", + "[a;r5]1:[a;r5]:[a;r5]:[a;r5]:[a;r5]:1", + ), + ): super().__init__( centroid_distance=centroid_distance, plane_angle=plane_angle, normal_to_centroid_angle=normal_to_centroid_angle, - pi_ring=pi_ring + pi_ring=pi_ring, ) self.edge = True self.ring_radius = ring_radius - + def detect(self, ligand, residue): return super().detect(ligand, residue) @@ -559,12 +600,13 @@ class PiStacking(Interaction): .. versionchanged:: 0.3.4 `shortest_distance` has been replaced by `angle_normal_centroid` - + .. versionchanged:: 1.1.0 The implementation now directly calls :class:`EdgeToFace` and :class:`FaceToFace` instead of overwriting the default parameters with more generic ones. """ + def __init__(self, ftf_kwargs=None, etf_kwargs=None): self.ftf = FaceToFace(**ftf_kwargs or {}) self.etf = EdgeToFace(**etf_kwargs or {}) @@ -593,21 +635,26 @@ class _BaseMetallic(_Distance): The initial SMARTS pattern was too broad. """ - def __init__(self, - metal="[Ca,Cd,Co,Cu,Fe,Mg,Mn,Ni,Zn]", - ligand="[O,#7&!$([nX3])&!$([NX3]-*=[!#6])&!$([NX3]-[a])&!$([NX4]),-{1-};!+{1-}]", - distance=2.8): + + def __init__( + self, + metal="[Ca,Cd,Co,Cu,Fe,Mg,Mn,Ni,Zn]", + ligand="[O,#7&!$([nX3])&!$([NX3]-*=[!#6])&!$([NX3]-[a])&!$([NX4]),-{1-};!+{1-}]", + distance=2.8, + ): super().__init__(metal, ligand, distance) class MetalDonor(_BaseMetallic): """Metallic interaction between a metal and a residue (chelated)""" + def detect(self, ligand, residue): return super().detect(ligand, residue) class MetalAcceptor(_BaseMetallic): """Metallic interaction between a ligand (chelated) and a metal residue""" + def detect(self, ligand, residue): bit, ires, ilig = super().detect(residue, ligand) return bit, ilig, ires @@ -627,7 +674,8 @@ class VdWContact(Interaction): ------ ValueError : ``tolerance`` parameter cannot be negative """ - def __init__(self, tolerance=.5): + + def __init__(self, tolerance=0.5): if tolerance >= 0: self.tolerance = tolerance else: @@ -645,8 +693,9 @@ def detect(self, ligand, residue): except KeyError: vdw = vdwradii[lig] + vdwradii[res] + self.tolerance self._vdw_cache[(lig, res)] = vdw - dist = (lxyz.GetAtomPosition(la.GetIdx()) - .Distance(rxyz.GetAtomPosition(ra.GetIdx()))) + dist = lxyz.GetAtomPosition(la.GetIdx()).Distance( + rxyz.GetAtomPosition(ra.GetIdx()) + ) if dist <= vdw: return True, la.GetIdx(), ra.GetIdx() return False, None, None diff --git a/prolif/molecule.py b/prolif/molecule.py index 9ce8764..c506349 100644 --- a/prolif/molecule.py +++ b/prolif/molecule.py @@ -69,6 +69,7 @@ class Molecule(BaseRDKitMol): See :mod:`prolif.residue` for more information on residues """ + def __init__(self, mol): super().__init__(mol) # set mapping of atoms @@ -115,8 +116,7 @@ def from_mda(cls, obj, selection=None, **kwargs): """ ag = obj.select_atoms(selection) if selection else obj.atoms if ag.n_atoms == 0: - raise mda.SelectionError( - f"AtomGroup is empty, please check your selection") + raise mda.SelectionError(f"AtomGroup is empty, please check your selection") mol = ag.convert_to.rdkit(**kwargs) return cls(mol) @@ -149,10 +149,12 @@ def from_rdkit(cls, mol, resname="UNL", resnumber=1, chain=""): return cls(mol) mol = copy.deepcopy(mol) for atom in mol.GetAtoms(): - mi = Chem.AtomPDBResidueInfo(f" {atom.GetSymbol():<3.3}", - residueName=resname, - residueNumber=resnumber, - chainId=chain) + mi = Chem.AtomPDBResidueInfo( + f" {atom.GetSymbol():<3.3}", + residueName=resname, + residueNumber=resnumber, + chainId=chain, + ) atom.SetMonomerInfo(mi) return cls(mol) @@ -163,7 +165,7 @@ def __iter__(self): def __getitem__(self, key): return self.residues[key] - def __repr__(self): # pragma: no cover + def __repr__(self): # pragma: no cover name = ".".join([self.__class__.__module__, self.__class__.__name__]) params = f"{self.n_residues} residues and {self.GetNumAtoms()} atoms" return f"<{name} with {params} at {id(self):#x}>" @@ -211,7 +213,7 @@ class pdbqt_supplier(Sequence): .. versionchanged:: 1.0.0 Molecule suppliers are now sequences that can be reused, indexed, and can return their length, instead of single-use generators. - + .. versionchanged:: 1.1.0 Because the PDBQT supplier needs to strip hydrogen atoms before assigning bond orders from the template, it used to replace them @@ -222,6 +224,7 @@ class pdbqt_supplier(Sequence): A lot of irrelevant warnings and logs have been disabled as well. """ + def __init__(self, paths, template, converter_kwargs=None, **kwargs): self.paths = list(paths) self.template = template @@ -242,8 +245,9 @@ def pdbqt_to_mol(self, pdbqt_path): with catch_warning(message=r"^Failed to guess the mass"): pdbqt = mda.Universe(pdbqt_path) # set attributes needed by the converter - elements = [mda.topology.guessers.guess_atom_element(x) - for x in pdbqt.atoms.names] + elements = [ + mda.topology.guessers.guess_atom_element(x) for x in pdbqt.atoms.names + ] pdbqt.add_TopologyAttr("elements", elements) pdbqt.add_TopologyAttr("chainIDs", pdbqt.atoms.segids) pdbqt.atoms.types = pdbqt.atoms.elements @@ -330,6 +334,7 @@ class sdf_supplier(Sequence): and can return their length, instead of single-use generators. """ + def __init__(self, path, **kwargs): self.path = path self._suppl = Chem.SDMolSupplier(path, removeHs=False) @@ -379,6 +384,7 @@ class mol2_supplier(Sequence): and can return their length, instead of single-use generators. """ + def __init__(self, path, **kwargs): self.path = path self._kwargs = kwargs diff --git a/prolif/parallel.py b/prolif/parallel.py index 1977c69..c8b86c5 100644 --- a/prolif/parallel.py +++ b/prolif/parallel.py @@ -15,8 +15,7 @@ def process_chunk(args): for ts in traj[chunk]: lig_mol = Molecule.from_mda(lig, **lig_kwargs) prot_mol = Molecule.from_mda(prot, **prot_kwargs) - data = fp.generate(lig_mol, prot_mol, residues=residues, - return_atoms=True) + data = fp.generate(lig_mol, prot_mol, residues=residues, return_atoms=True) data["Frame"] = ts.frame ifp.append(data) if display_progress: @@ -25,8 +24,9 @@ def process_chunk(args): return ifp -def declare_shared_objs_for_chunk(fingerprint, resid_list, show_progressbar, - progress_counter, rdkitconverter_kwargs): +def declare_shared_objs_for_chunk( + fingerprint, resid_list, show_progressbar, progress_counter, rdkitconverter_kwargs +): """Declares global objects that are available to the pool of workers for a trajectory""" global fp, residues, display_progress, pcount, converter_kwargs @@ -57,6 +57,7 @@ def declare_shared_objs_for_mol(fingerprint, pmol, resid_list): class ProgressCounter: """Tracks the progress of the fingerprint analysis accross the pool of workers""" + def __init__(self): self.lock = mp.Lock() self.counter = mp.Value(c_int32) @@ -65,6 +66,7 @@ def __init__(self): class Progress: """Handles tracking the progress of the ProgressCounter and updating the tqdm progress bar, from within an independent thread""" + def __init__(self, pcount, *args, **kwargs): self.pbar = tqdm(*args, **kwargs) self.pcount = pcount diff --git a/prolif/plotting/network.py b/prolif/plotting/network.py index 8b77064..a98bb4d 100644 --- a/prolif/plotting/network.py +++ b/prolif/plotting/network.py @@ -20,13 +20,15 @@ from rdkit.Chem import rdDepictor from ..residue import ResidueId from ..utils import requires + try: from IPython.display import HTML except ModuleNotFoundError: pass else: - warnings.filterwarnings("ignore", # pragma: no cover - "Consider using IPython.display.IFrame instead") + warnings.filterwarnings( + "ignore", "Consider using IPython.display.IFrame instead" # pragma: no cover + ) class LigNetwork: @@ -74,6 +76,7 @@ class LigNetwork: :attr:`LigNetwork.RESIDUE_TYPES` by adding or modifying the dictionaries inplace. """ + COLORS = { "interactions": { "Hydrophobic": "#59e382", @@ -106,41 +109,40 @@ class LigNetwork: "Acidic": "#e35959", "Basic": "#5979e3", "Polar": "#59bee3", - "Sulfur": "#e3ce59" - } + "Sulfur": "#e3ce59", + }, } RESIDUE_TYPES = { - 'ALA': "Aliphatic", - 'GLY': "Aliphatic", - 'ILE': "Aliphatic", - 'LEU': "Aliphatic", - 'PRO': "Aliphatic", - 'VAL': "Aliphatic", - 'PHE': "Aromatic", - 'TRP': "Aromatic", - 'TYR': "Aromatic", - 'ASP': "Acidic", - 'GLU': "Acidic", - 'ARG': "Basic", - 'HIS': "Basic", - 'HID': "Basic", - 'HIE': "Basic", - 'HIP': "Basic", - 'HSD': "Basic", - 'HSE': "Basic", - 'HSP': "Basic", - 'LYS': "Basic", - 'SER': "Polar", - 'THR': "Polar", - 'ASN': "Polar", - 'GLN': "Polar", - 'CYS': "Sulfur", - 'CYM': "Sulfur", - 'CYX': "Sulfur", - 'MET': "Sulfur", + "ALA": "Aliphatic", + "GLY": "Aliphatic", + "ILE": "Aliphatic", + "LEU": "Aliphatic", + "PRO": "Aliphatic", + "VAL": "Aliphatic", + "PHE": "Aromatic", + "TRP": "Aromatic", + "TYR": "Aromatic", + "ASP": "Acidic", + "GLU": "Acidic", + "ARG": "Basic", + "HIS": "Basic", + "HID": "Basic", + "HIE": "Basic", + "HIP": "Basic", + "HSD": "Basic", + "HSE": "Basic", + "HSP": "Basic", + "LYS": "Basic", + "SER": "Polar", + "THR": "Polar", + "ASN": "Polar", + "GLN": "Polar", + "CYS": "Sulfur", + "CYM": "Sulfur", + "CYX": "Sulfur", + "MET": "Sulfur", } - _LIG_PI_INTERACTIONS = ["EdgeToFace", "FaceToFace", "PiStacking", - "PiCation"] + _LIG_PI_INTERACTIONS = ["EdgeToFace", "FaceToFace", "PiStacking", "PiCation"] _JS_TEMPLATE = """ var ifp, legend, nodes, edges, legend_buttons; function drawGraph(_id, nodes, edges, options) { @@ -189,8 +191,16 @@ class LigNetwork: """ - def __init__(self, df, lig_mol, match3D=True, kekulize=False, molsize=35, - rotation=0, carbon=.16): + def __init__( + self, + df, + lig_mol, + match3D=True, + kekulize=False, + molsize=35, + rotation=0, + carbon=0.16, + ): self.df = df mol = deepcopy(lig_mol) Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_SETAROMATICITY) @@ -212,23 +222,20 @@ def __init__(self, df, lig_mol, match3D=True, kekulize=False, molsize=35, xyz = np.concatenate([xy, z], axis=1) if carbon: self._carbon = { - 'label': " ", - 'shape': "dot", - 'color': self.COLORS["atoms"]["C"], - 'size': molsize * carbon, + "label": " ", + "shape": "dot", + "color": self.COLORS["atoms"]["C"], + "size": molsize * carbon, } else: - self._carbon = { - 'label': " ", - 'shape': "text" - } + self._carbon = {"label": " ", "shape": "text"} self.xyz = molsize * xyz self.mol = mol self._multiplier = molsize self.options = {} self._max_interaction_width = 6 - self._avoidOverlap = .8 - self._springConstant = .1 + self._avoidOverlap = 0.8 + self._springConstant = 0.1 self._bond_color = "black" self._default_atom_color = "grey" self._default_residue_color = "#dbdbdb" @@ -239,13 +246,10 @@ def __init__(self, df, lig_mol, match3D=True, kekulize=False, molsize=35, for interaction, color in self.COLORS["interactions"].items(): if interaction in interactions: temp[color].append(interaction) - self._interaction_types = {i: "/".join(t) - for c, t in temp.items() - for i in t} + self._interaction_types = {i: "/".join(t) for c, t in temp.items() for i in t} @classmethod - def from_ifp(cls, ifp, lig, kind="aggregate", frame=0, threshold=.3, - **kwargs): + def from_ifp(cls, ifp, lig, kind="aggregate", frame=0, threshold=0.3, **kwargs): """Helper method to create a ligand interaction diagram from an IFP DataFrame obtained with ``fp.to_dataframe(return_atoms=True)`` @@ -275,44 +279,46 @@ def from_ifp(cls, ifp, lig, kind="aggregate", frame=0, threshold=.3, Other arguments passed to the :class:`LigNetwork` class """ if kind == "aggregate": - data = (pd.get_dummies(ifp.applymap(lambda x: x[0]) - .astype(object), - prefix_sep=", ") - .rename(columns=lambda x: - x.translate({ord(c): None for c in "()'"})) - .mean()) + data = ( + pd.get_dummies( + ifp.applymap(lambda x: x[0]).astype(object), prefix_sep=", " + ) + .rename(columns=lambda x: x.translate({ord(c): None for c in "()'"})) + .mean() + ) index = [i.split(", ") for i in data.index] - index = [[j for j in i[:-1]+[int(float(i[-1]))]] for i in index] + index = [[j for j in i[:-1] + [int(float(i[-1]))]] for i in index] data.index = pd.MultiIndex.from_tuples( - index, - names=["ligand", "protein", "interaction", "atom"]) + index, names=["ligand", "protein", "interaction", "atom"] + ) data = data.to_frame() data.rename(columns={data.columns[-1]: "weight"}, inplace=True) # merge different ligand atoms before applying the threshold data = data.join( data.groupby(level=["ligand", "protein", "interaction"]).sum(), - rsuffix="_total") + rsuffix="_total", + ) # threshold and keep most occuring atom - data = (data - .loc[data["weight_total"] >= threshold] - .drop(columns="weight_total") - .sort_values("weight", ascending=False) - .groupby(level=["ligand", "protein", "interaction"]) - .head(1) - .sort_index()) + data = ( + data.loc[data["weight_total"] >= threshold] + .drop(columns="weight_total") + .sort_values("weight", ascending=False) + .groupby(level=["ligand", "protein", "interaction"]) + .head(1) + .sort_index() + ) return cls(data, lig, **kwargs) elif kind == "frame": - data = (ifp - .loc[ifp.index == frame] - .T - .applymap(lambda x: x[0]) - .dropna() - .astype(int) - .reset_index()) + data = ( + ifp.loc[ifp.index == frame] + .T.applymap(lambda x: x[0]) + .dropna() + .astype(int) + .reset_index() + ) data.rename(columns={data.columns[-1]: "atom"}, inplace=True) data["weight"] = 1 - data.set_index(["ligand", "protein", "interaction", "atom"], - inplace=True) + data.set_index(["ligand", "protein", "interaction", "atom"], inplace=True) return cls(data, lig, **kwargs) else: raise ValueError(f'{kind!r} must be "aggregate" or "frame"') @@ -329,8 +335,9 @@ def _make_lig_node(self, atom): return charge = atom.GetFormalCharge() if charge != 0: - charge = "{}{}".format('' if abs(charge) == 1 else str(charge), - '+' if charge > 0 else '-') + charge = "{}{}".format( + "" if abs(charge) == 1 else str(charge), "+" if charge > 0 else "-" + ) label = f"{elem}{charge}" shape = "ellipse" else: @@ -340,22 +347,23 @@ def _make_lig_node(self, atom): node = self._make_carbon() else: node = { - 'label': label, - 'shape': shape, - 'color': "white", - 'font': { - 'color': self.COLORS["atoms"].get(elem, - self._default_atom_color) - } + "label": label, + "shape": shape, + "color": "white", + "font": { + "color": self.COLORS["atoms"].get(elem, self._default_atom_color) + }, + } + node.update( + { + "id": idx, + "x": float(self.xyz[idx, 0]), + "y": float(self.xyz[idx, 1]), + "fixed": True, + "group": "ligand", + "borderWidth": 0, } - node.update({ - 'id': idx, - 'x': float(self.xyz[idx, 0]), - 'y': float(self.xyz[idx, 1]), - 'fixed': True, - 'group': "ligand", - 'borderWidth': 0 - }) + ) self.nodes[idx] = node def _make_lig_edge(self, bond): @@ -365,18 +373,20 @@ def _make_lig_edge(self, bond): return btype = bond.GetBondTypeAsDouble() if btype == 1: - self.edges.append({ - 'from': idx[0], 'to': idx[1], - 'color': self._bond_color, - 'physics': False, - 'group': "ligand", - 'width': 4, - }) + self.edges.append( + { + "from": idx[0], + "to": idx[1], + "color": self._bond_color, + "physics": False, + "group": "ligand", + "width": 4, + } + ) else: self._make_non_single_bond(idx, btype) - def _make_non_single_bond(self, ids, btype, bdist=.06, - dash=[10]): + def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=[10]): """Prepare double, triple and aromatic bonds""" xyz = self.xyz[ids] d = xyz[1, :2] - xyz[0, :2] @@ -392,32 +402,48 @@ def _make_non_single_bond(self, ids, btype, bdist=.06, _id = hash(xy.tobytes()) nodes.append(_id) self.nodes[_id] = { - 'id': _id, 'x': xy[0], 'y': xy[1], - 'shape': "text", "label": " ", - 'fixed': True, 'physics': False} + "id": _id, + "x": xy[0], + "y": xy[1], + "shape": "text", + "label": " ", + "fixed": True, + "physics": False, + } l1, l2, r1, r2 = nodes - self.edges.extend([ - {'from': l1, 'to': l2, - 'color': self._bond_color, - 'physics': False, - 'dashes': dashes, - 'group': "ligand", - 'width': 4}, - {'from': r1, 'to': r2, - 'color': self._bond_color, - 'physics': False, - 'dashes': dashes, - 'group': "ligand", - 'width': 4} - ]) + self.edges.extend( + [ + { + "from": l1, + "to": l2, + "color": self._bond_color, + "physics": False, + "dashes": dashes, + "group": "ligand", + "width": 4, + }, + { + "from": r1, + "to": r2, + "color": self._bond_color, + "physics": False, + "dashes": dashes, + "group": "ligand", + "width": 4, + }, + ] + ) if btype == 3: - self.edges.append({ - 'from': ids[0], 'to': ids[1], - 'color': self._bond_color, - 'physics': False, - 'group': "ligand", - 'width': 4 - }) + self.edges.append( + { + "from": ids[0], + "to": ids[1], + "color": self._bond_color, + "physics": False, + "group": "ligand", + "width": 4, + } + ) def _make_interactions(self, mass=2): """Prepare lig-prot interactions""" @@ -426,44 +452,46 @@ def _make_interactions(self, mass=2): resname = ResidueId.from_string(prot_res).name restype = self.RESIDUE_TYPES.get(resname) restypes[prot_res] = restype - color = self.COLORS["residues"].get(restype, - self._default_residue_color) + color = self.COLORS["residues"].get(restype, self._default_residue_color) node = { - 'id': prot_res, - 'label': prot_res, - 'color': color, - 'shape': "box", - 'borderWidth': 0, - 'physics': True, - 'mass': mass, - 'group': "protein", - 'residue_type': restype, + "id": prot_res, + "label": prot_res, + "color": color, + "shape": "box", + "borderWidth": 0, + "physics": True, + "mass": mass, + "group": "protein", + "residue_type": restype, } self.nodes[prot_res] = node - for ((lig_res, prot_res, interaction, lig_id), - (weight,)) in self.df.iterrows(): + for ((lig_res, prot_res, interaction, lig_id), (weight,)) in self.df.iterrows(): if interaction in self._LIG_PI_INTERACTIONS: centroid = self._get_ring_centroid(lig_id) origin = str((lig_res, prot_res, interaction)) - self.nodes[origin] = {'id': origin, - 'x': centroid[0], 'y': centroid[1], - 'shape': "text", "label": " ", - 'fixed': True, 'physics': False} + self.nodes[origin] = { + "id": origin, + "x": centroid[0], + "y": centroid[1], + "shape": "text", + "label": " ", + "fixed": True, + "physics": False, + } else: origin = int(lig_id) edge = { - 'from': origin, 'to': prot_res, - 'title': interaction, - 'interaction_type': self._interaction_types[interaction], - 'color': self.COLORS["interactions"].get( - interaction, - self._default_interaction_color), - 'smooth': { - 'type': 'cubicBezier', - 'roundness': .2}, - 'dashes': [10], - 'width': weight * self._max_interaction_width, - 'group': "interaction", + "from": origin, + "to": prot_res, + "title": interaction, + "interaction_type": self._interaction_types[interaction], + "color": self.COLORS["interactions"].get( + interaction, self._default_interaction_color + ), + "smooth": {"type": "cubicBezier", "roundness": 0.2}, + "dashes": [10], + "width": weight * self._max_interaction_width, + "group": "interaction", } self.edges.append(edge) @@ -473,8 +501,9 @@ def _get_ring_centroid(self, index): if index in r: break else: - raise ValueError("No ring containing this atom index was found in " - "the given molecule") + raise ValueError( + "No ring containing this atom index was found in " "the given molecule" + ) return self.xyz[list(r)].mean(axis=0) def _patch_hydrogens(self): @@ -493,7 +522,7 @@ def _patch_hydrogens(self): for idx, nH in to_patch.items(): node = self.nodes[idx] h_str = "H" if nH == 1 else f"H{nH}" - label = re.sub(r'(\w+)(.*)', fr'\1{h_str}\2', node["label"]) + label = re.sub(r"(\w+)(.*)", rf"\1{h_str}\2", node["label"]) node["label"] = label node["shape"] = "ellipse" @@ -512,8 +541,7 @@ def _make_graph_data(self): self._patch_hydrogens() self.nodes = list(self.nodes.values()) - def _get_js(self, width="100%", height="500px", div_id="mynetwork", - fontsize=20): + def _get_js(self, width="100%", height="500px", div_id="mynetwork", fontsize=20): """Returns the JavaScript code to draw the network""" self.width = width self.height = height @@ -529,13 +557,15 @@ def _get_js(self, width="100%", height="500px", div_id="mynetwork", "avoidOverlap": self._avoidOverlap, "springConstant": self._springConstant, } - } + }, } options.update(self.options) - js = self._JS_TEMPLATE % dict(div_id=div_id, - nodes=json.dumps(self.nodes), - edges=json.dumps(self.edges), - options=json.dumps(options)) + js = self._JS_TEMPLATE % dict( + div_id=div_id, + nodes=json.dumps(self.nodes), + edges=json.dumps(self.edges), + options=json.dumps(options), + ) js += self._get_legend() return js @@ -547,41 +577,44 @@ def _get_legend(self, height="90px"): available = {} buttons = [] map_color_restype = {c: t for t, c in self.COLORS["residues"].items()} - map_color_interactions = {self.COLORS["interactions"][i]: t - for i, t in self._interaction_types.items()} + map_color_interactions = { + self.COLORS["interactions"][i]: t + for i, t in self._interaction_types.items() + } # residues for node in self.nodes: if node.get("group", "") == "protein": color = node["color"] available[color] = map_color_restype.get(color, "Unknown") - available = {k: v for k, v in sorted(available.items(), - key=lambda item: item[1])} + available = { + k: v for k, v in sorted(available.items(), key=lambda item: item[1]) + } for i, (color, restype) in enumerate(available.items()): - buttons.append({ - "index": i, - "label": restype, - "color": color, - "group": "residues" - }) + buttons.append( + {"index": i, "label": restype, "color": color, "group": "residues"} + ) # interactions available.clear() for edge in self.edges: if edge.get("group", "") == "interaction": color = edge["color"] available[color] = map_color_interactions.get(color, "Unknown") - available = {k: v for k, v in sorted(available.items(), - key=lambda item: item[1])} + available = { + k: v for k, v in sorted(available.items(), key=lambda item: item[1]) + } for i, (color, interaction) in enumerate(available.items()): - buttons.append({ + buttons.append( + { "index": i, "label": interaction, "color": color, - "group": "interactions" - }) + "group": "interactions", + } + ) # JS code if all("px" in h for h in [self.height, height]): - h1 = int(re.findall(r'(\d+)\w+', self.height)[0]) - h2 = int(re.findall(r'(\d+)\w+', height)[0]) + h1 = int(re.findall(r"(\d+)\w+", self.height)[0]) + h2 = int(re.findall(r"(\d+)\w+", height)[0]) self.height = f"{h1+h2}px" return """ legend_buttons = %(buttons)s; @@ -683,17 +716,21 @@ def _get_legend(self, height="90px"): }); legend.appendChild(div_residues); legend.appendChild(div_interactions); - """ % dict(div_id="networklegend", - buttons=json.dumps(buttons)) + """ % dict( + div_id="networklegend", buttons=json.dumps(buttons) + ) @requires("IPython.display") def display(self, **kwargs): """Prepare and display the network""" html = self._get_html(**kwargs) - iframe = ('') - return HTML(iframe.format(width=self.width, height=self.height, - doc=escape(html))) + iframe = ( + '' + ) + return HTML( + iframe.format(width=self.width, height=self.height, doc=escape(html)) + ) @requires("IPython.display") def show(self, filename, **kwargs): @@ -701,10 +738,13 @@ def show(self, filename, **kwargs): html = self._get_html(**kwargs) with open(filename, "w") as f: f.write(html) - iframe = ('') - return HTML(iframe.format(width=self.width, height=self.height, - filename=filename)) + iframe = ( + '' + ) + return HTML( + iframe.format(width=self.width, height=self.height, filename=filename) + ) def save(self, fp, **kwargs): """Save the network to an HTML file diff --git a/prolif/rdkitmol.py b/prolif/rdkitmol.py index b82a579..eed076c 100644 --- a/prolif/rdkitmol.py +++ b/prolif/rdkitmol.py @@ -5,6 +5,7 @@ from rdkit import Chem from rdkit.Chem.rdMolTransforms import ComputeCentroid + class BaseRDKitMol(Chem.Mol): """Base molecular class that behaves like an RDKit :class:`~rdkit.Chem.rdchem.Mol` with extra attributes (see below). @@ -24,6 +25,7 @@ class BaseRDKitMol(Chem.Mol): xyz : numpy.ndarray XYZ coordinates of all atoms in the molecule """ + @property def centroid(self): return ComputeCentroid(self.GetConformer()) diff --git a/prolif/residue.py b/prolif/residue.py index df5f1b7..b248770 100644 --- a/prolif/residue.py +++ b/prolif/residue.py @@ -11,7 +11,7 @@ from .rdkitmol import BaseRDKitMol -_RE_RESID = re.compile(r'([A-Z]{,3})?(\d*)\.?(\w)?') +_RE_RESID = re.compile(r"([A-Z]{,3})?(\d*)\.?(\w)?") NoneType = type(None) @@ -27,10 +27,8 @@ class ResidueId: chain : str or None, optionnal 1-letter protein chain """ - def __init__(self, - name: str = "UNK", - number: int = 0, - chain: Optional[str] = None): + + def __init__(self, name: str = "UNK", number: int = 0, chain: Optional[str] = None): self.name = name or "UNK" self.number = number or 0 self.chain = chain or None @@ -129,12 +127,13 @@ class Residue(BaseRDKitMol): The name of the residue can be converted to a string by using ``str(Residue)`` """ + def __init__(self, mol): super().__init__(mol) FastFindRings(self) self.resid = ResidueId.from_atom(self.GetAtomWithIdx(0)) - def __repr__(self): # pragma: no cover + def __repr__(self): # pragma: no cover name = ".".join([self.__class__.__module__, self.__class__.__name__]) return f"<{name} {self.resid} at {id(self):#x}>" @@ -162,10 +161,12 @@ class ResidueGroup(UserDict): You can also use the :meth:`~prolif.residue.ResidueGroup.select` method to access a subset of a ResidueGroup. """ + def __init__(self, residues: List[Residue]): self._residues = np.asarray(residues, dtype=object) - resinfo = [(r.resid.name, r.resid.number, r.resid.chain) - for r in self._residues] + resinfo = [ + (r.resid.name, r.resid.number, r.resid.chain) for r in self._residues + ] try: name, number, chain = zip(*resinfo) except ValueError: @@ -181,8 +182,10 @@ def __init__(self, residues: List[Residue]): def __getitem__(self, key): # bool is a subclass of int but shouldn't be used here if isinstance(key, bool): - raise KeyError("Expected a ResidueId, int, or str, " - f"got {type(key).__name__!r} instead") + raise KeyError( + "Expected a ResidueId, int, or str, " + f"got {type(key).__name__!r} instead" + ) if isinstance(key, int): return self._residues[key] elif isinstance(key, str): @@ -190,8 +193,9 @@ def __getitem__(self, key): return self.data[key] elif isinstance(key, ResidueId): return self.data[key] - raise KeyError("Expected a ResidueId, int, or str, " - f"got {type(key).__name__!r} instead") + raise KeyError( + "Expected a ResidueId, int, or str, " f"got {type(key).__name__!r} instead" + ) def select(self, mask): """Locate a subset of a ResidueGroup based on a boolean mask @@ -235,7 +239,7 @@ def select(self, mask): """ return ResidueGroup(self._residues[mask]) - def __repr__(self): # pragma: no cover + def __repr__(self): # pragma: no cover name = ".".join([self.__class__.__module__, self.__class__.__name__]) return f"<{name} with {self.n_residues} residues at {id(self):#x}>" diff --git a/prolif/utils.py b/prolif/utils.py index a68220a..ed64542 100644 --- a/prolif/utils.py +++ b/prolif/utils.py @@ -21,7 +21,7 @@ from .residue import ResidueId -_90_deg_to_rad = pi/2 +_90_deg_to_rad = pi / 2 def requires(module): # pragma: no cover @@ -32,8 +32,11 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) raise ModuleNotFoundError( f"The module {module!r} is required to use {func.__name__!r} " - "but it is not installed!") + "but it is not installed!" + ) + return wrapper + return inner @@ -102,7 +105,7 @@ def angle_between_limits(angle, min_angle, max_angle, ring=False): angle %= _90_deg_to_rad elif angle > _90_deg_to_rad: angle = _90_deg_to_rad - (angle % _90_deg_to_rad) - return (min_angle <= angle <= max_angle) + return min_angle <= angle <= max_angle def get_residues_near_ligand(lig, prot, cutoff=6.0): @@ -152,12 +155,12 @@ def split_mol_by_residues(mol): for res in SplitMolByPDBResidues(mol).values(): for frag in GetMolFrags(res, asMols=True, sanitizeFrags=False): # count number of unique residues in the fragment - resids = {a.GetIdx(): ResidueId.from_atom(a) - for a in frag.GetAtoms()} + resids = {a.GetIdx(): ResidueId.from_atom(a) for a in frag.GetAtoms()} if len(set(resids.values())) > 1: # split on peptide bonds - bonds = [b.GetIdx() for b in frag.GetBonds() - if is_peptide_bond(b, resids)] + bonds = [ + b.GetIdx() for b in frag.GetBonds() if is_peptide_bond(b, resids) + ] mols = FragmentOnBonds(frag, bonds, addDummies=False) mols = GetMolFrags(mols, asMols=True, sanitizeFrags=False) residues.extend(mols) @@ -181,8 +184,14 @@ def is_peptide_bond(bond, resids): return resids[bond.GetBeginAtomIdx()] != resids[bond.GetEndAtomIdx()] -def to_dataframe(ifp, interactions, index_col="Frame", dtype=None, - drop_empty=True, return_atoms=False): +def to_dataframe( + ifp, + interactions, + index_col="Frame", + dtype=None, + drop_empty=True, + return_atoms=False, +): """Converts IFPs to a pandas DataFrame Parameters @@ -243,8 +252,9 @@ def to_dataframe(ifp, interactions, index_col="Frame", dtype=None, has_atom_indices = isinstance(value[0], Iterable) break if return_atoms and not has_atom_indices: - raise ValueError("The IFP either doesn't contain atom indices or is " - "formatted incorrectly") + raise ValueError( + "The IFP either doesn't contain atom indices or is " "formatted incorrectly" + ) # create empty array for each residue pair interaction that doesn't exist # in a particular frame if has_atom_indices and return_atoms: @@ -269,31 +279,32 @@ def to_dataframe(ifp, interactions, index_col="Frame", dtype=None, data[key].append(arr) index = pd.Series(index, name=index_col) # create dataframes - values = np.array([np.hstack([np.ravel(a[i]) for a in data.values()]) - for i in range(len(index))]) + values = np.array( + [np.hstack([np.ravel(a[i]) for a in data.values()]) for i in range(len(index))] + ) if has_atom_indices and return_atoms: columns = pd.MultiIndex.from_tuples( - [(str(k[0]), str(k[1]), i, a) - for k in keys - for i in interactions - for a in ["ligand", "protein"]], - names=["ligand", "protein", "interaction", "atom"]) + [ + (str(k[0]), str(k[1]), i, a) + for k in keys + for i in interactions + for a in ["ligand", "protein"] + ], + names=["ligand", "protein", "interaction", "atom"], + ) else: columns = pd.MultiIndex.from_tuples( - [(str(k[0]), str(k[1]), i) - for k in keys - for i in interactions], - names=["ligand", "protein", "interaction"]) + [(str(k[0]), str(k[1]), i) for k in keys for i in interactions], + names=["ligand", "protein", "interaction"], + ) df = pd.DataFrame(values, columns=columns, index=index) if has_atom_indices and return_atoms: - df = (df.groupby(axis=1, level=["ligand", "protein", "interaction"]) - .agg(tuple)) + df = df.groupby(axis=1, level=["ligand", "protein", "interaction"]).agg(tuple) if dtype: df = df.astype(dtype) if drop_empty: if has_atom_indices and return_atoms: - mask = df.apply(lambda s: - ~(s.isin([(None, None)]).all()), axis=0) + mask = df.apply(lambda s: ~(s.isin([(None, None)]).all()), axis=0) else: mask = (df != empty_value).any(axis=0) df = df.loc[:, mask] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7c294e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[tool.black] +line-length = 88 +extend-exclude = ''' +( + ^/versioneer.py + | ^/docs/conf.py +) +''' + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/setup.py b/setup.py index 47d8c3e..6abbbfe 100644 --- a/setup.py +++ b/setup.py @@ -15,4 +15,4 @@ if re.match(r"^20[0-1][0-9]\.", rdkit_version): raise ValueError("ProLIF requires a version of RDKit >= 2020") -setup(version=versioneer.get_version()) \ No newline at end of file +setup(version=versioneer.get_version()) diff --git a/tests/mol2factory.py b/tests/mol2factory.py index 3a75c5f..6801305 100644 --- a/tests/mol2factory.py +++ b/tests/mol2factory.py @@ -13,56 +13,74 @@ def from_mol2(f): u.atoms.types = np.array([x.upper() for x in u.atoms.types], dtype=object) return Molecule.from_mda(u, force=True) + def benzene(): return from_mol2("benzene.mol2") + def cation(): return from_mol2("cation.mol2") + def cation_false(): return from_mol2("cation_false.mol2") + def anion(): return from_mol2("anion.mol2") + def ftf(): return from_mol2("facetoface.mol2") + def etf(): return from_mol2("edgetoface.mol2") + def chlorine(): return from_mol2("chlorine.mol2") + def bromine(): return from_mol2("bromine.mol2") + def hb_donor(): return from_mol2("donor.mol2") + def hb_acceptor(): return from_mol2("acceptor.mol2") + def hb_acceptor_false(): return from_mol2("acceptor_false.mol2") + def xb_donor(): return from_mol2("xbond_donor.mol2") + def xb_acceptor(): return from_mol2("xbond_acceptor.mol2") + def xb_acceptor_false_xar(): return from_mol2("xbond_acceptor_false_xar.mol2") + def xb_acceptor_false_axd(): return from_mol2("xbond_acceptor_false_axd.mol2") + def ligand(): return from_mol2("ligand.mol2") + def metal(): return from_mol2("metal.mol2") + def metal_false(): return from_mol2("metal_false.mol2") diff --git a/tests/notebooks/viz-pi-stacking.ipynb b/tests/notebooks/viz-pi-stacking.ipynb index fe619b9..48eabc4 100644 --- a/tests/notebooks/viz-pi-stacking.ipynb +++ b/tests/notebooks/viz-pi-stacking.ipynb @@ -12,7 +12,7 @@ "from MDAnalysis.transformations import translate, rotateby\n", "import prolif as plf\n", "from rdkit.Geometry import Point3D\n", - "from ipywidgets import interactive, HBox, Layout,VBox" + "from ipywidgets import interactive, HBox, Layout, VBox" ] }, { @@ -27,13 +27,14 @@ "u1.segments.segids = np.array([\"U1\"], dtype=object)\n", "u1.transfer_to_memory()\n", "\n", - "def create(xyz=[0, 0, 0], rotation=[0,0,0]):\n", + "\n", + "def create(xyz=[0, 0, 0], rotation=[0, 0, 0]):\n", " u2 = u1.copy()\n", " u2.segments.segids = np.array([\"U2\"], dtype=object)\n", " tr = translate(xyz)\n", - " rotx = rotateby(rotation[0], [1,0,0], ag=u2.atoms)\n", - " roty = rotateby(rotation[1], [0,1,0], ag=u2.atoms)\n", - " rotz = rotateby(rotation[2], [0,0,1], ag=u2.atoms)\n", + " rotx = rotateby(rotation[0], [1, 0, 0], ag=u2.atoms)\n", + " roty = rotateby(rotation[1], [0, 1, 0], ag=u2.atoms)\n", + " rotz = rotateby(rotation[2], [0, 0, 1], ag=u2.atoms)\n", " u2.trajectory.add_transformations(tr, rotx, roty, rotz)\n", " u2.transfer_to_memory()\n", " u = mda.Merge(u1.atoms, u2.atoms)\n", @@ -47,28 +48,30 @@ "outputs": [], "source": [ "fp = plf.Fingerprint()\n", - "rad90 = np.pi/2\n", + "rad90 = np.pi / 2\n", + "\n", "\n", "def cap_angle(angle, cap=rad90):\n", " if angle >= np.pi:\n", " angle %= cap\n", " elif angle > cap:\n", - " angle = (cap - (angle % cap))\n", + " angle = cap - (angle % cap)\n", " return angle\n", "\n", + "\n", "def measure(u):\n", " ag1 = u.select_atoms(\"segid U1\")\n", " ring1 = ag1.select_atoms(\"type C.2\").positions.astype(float)\n", " c1 = plf.utils.get_centroid(ring1)\n", " c1 = Point3D(*c1)\n", " n1 = plf.utils.get_ring_normal_vector(c1, ring1)\n", - " \n", + "\n", " ag2 = u.select_atoms(\"segid U2\")\n", " ring2 = ag2.select_atoms(\"type C.2\").positions.astype(float)\n", " c2 = plf.utils.get_centroid(ring2)\n", " c2 = Point3D(*c2)\n", " n2 = plf.utils.get_ring_normal_vector(c2, ring2)\n", - " \n", + "\n", " planes_angle = n1.AngleTo(n2)\n", " c1c2 = c1.DirectionVector(c2)\n", " c2c1 = c2.DirectionVector(c1)\n", @@ -76,15 +79,17 @@ " n2c2c1 = n2.AngleTo(c2c1)\n", " proj = plf.interactions.EdgeToFace._get_intersect_point(n1, c1, n2, c2)\n", " pdist = min(c1.Distance(proj), c2.Distance(proj))\n", - " \n", + "\n", " m1 = plf.Molecule.from_mda(ag1)\n", " m2 = plf.Molecule.from_mda(ag2)\n", - " \n", - " print(f'''\n", + "\n", + " print(\n", + " f\"\"\"\n", "centroid distance: {c1.Distance(c2):.3f} pdist: {pdist:.3f}\n", "planes: {np.degrees(cap_angle(planes_angle)):.3f}°\n", "n1c1c2: {np.degrees(cap_angle(n1c1c2)):.3f}° n2c2c1: {np.degrees(cap_angle(n2c2c1)):.3f}°\n", - "FTF: {fp.facetoface(m1, m2)} ETF: {fp.edgetoface(m1, m2)}''')\n", + "FTF: {fp.facetoface(m1, m2)} ETF: {fp.edgetoface(m1, m2)}\"\"\"\n", + " )\n", " return c1, n1, c2, n2, proj" ] }, @@ -100,9 +105,10 @@ "v = nv.show_mdanalysis(u.atoms)\n", "v.center(\"*\")\n", "v._set_size(\"100%\", \"400px\")\n", - "v.camera = 'orthographic'\n", + "v.camera = \"orthographic\"\n", "shapes = {}\n", "\n", + "\n", "def view(dx=0, dy=1.5, dz=4.5, ax=30, ay=0, az=0):\n", " new = create(xyz=[dx, dy, dz], rotation=[ax, ay, az])\n", " u.atoms.positions = new.atoms.positions\n", @@ -113,15 +119,22 @@ " comp.clear()\n", " except:\n", " pass\n", - " shapes[\"c1c2\"] = v.shape.add_cylinder(list(c1), list(c2), [1,0,0], .1)\n", - " shapes[\"n1\"] = v.shape.add_cylinder(list(c1), list(c1 + n1 + n1), [0,1,0], .1)\n", - " shapes[\"n2\"] = v.shape.add_cylinder(list(c2), list(c2 + n2 + n2), [0,0,1], .1)\n", - " shapes[\"proj\"] = v.shape.add_sphere(list(proj), [.8,.2,.6], .3)\n", - " \n", - "widget=interactive(view, \n", - " dx=(-7, 7, .5), dy=(-7, 7, .5), dz=(-7, 7, .5),\n", - " ax=(0, 180, 5), ay=(0, 180, 5), az=(0, 180, 5))\n", - "controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))\n", + " shapes[\"c1c2\"] = v.shape.add_cylinder(list(c1), list(c2), [1, 0, 0], 0.1)\n", + " shapes[\"n1\"] = v.shape.add_cylinder(list(c1), list(c1 + n1 + n1), [0, 1, 0], 0.1)\n", + " shapes[\"n2\"] = v.shape.add_cylinder(list(c2), list(c2 + n2 + n2), [0, 0, 1], 0.1)\n", + " shapes[\"proj\"] = v.shape.add_sphere(list(proj), [0.8, 0.2, 0.6], 0.3)\n", + "\n", + "\n", + "widget = interactive(\n", + " view,\n", + " dx=(-7, 7, 0.5),\n", + " dy=(-7, 7, 0.5),\n", + " dz=(-7, 7, 0.5),\n", + " ax=(0, 180, 5),\n", + " ay=(0, 180, 5),\n", + " az=(0, 180, 5),\n", + ")\n", + "controls = HBox(widget.children[:-1], layout=Layout(flex_flow=\"row wrap\"))\n", "output = widget.children[-1]\n", "display(VBox([controls, output, v]))" ] diff --git a/tests/plotting/test_network.py b/tests/plotting/test_network.py index 008033c..cd13308 100644 --- a/tests/plotting/test_network.py +++ b/tests/plotting/test_network.py @@ -9,7 +9,7 @@ @contextmanager -def TempFilename(mode='w'): +def TempFilename(mode="w"): f = NamedTemporaryFile(delete=False, mode=mode) try: f.close() @@ -50,8 +50,9 @@ def test_integration_agg(self, lignetwork_data): def test_kwargs(self, lignetwork_data): lig, df = lignetwork_data - net = LigNetwork.from_ifp(df, lig, kekulize=True, match3D=False, - rotation=42, carbon=0) + net = LigNetwork.from_ifp( + df, lig, kekulize=True, match3D=False, rotation=42, carbon=0 + ) with StringIO() as buffer: net.save(buffer) buffer.seek(0) diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index 956c562..7e00cec 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -38,16 +38,11 @@ def test_wrapper_repr(): assert _repr.startswith("<") and ".Dummy at " in _repr -@pytest.mark.parametrize("returned", [ - True, - (True,), - (True, 4) -]) +@pytest.mark.parametrize("returned", [True, (True,), (True, 4)]) def test_wrapper_incorrect_return(returned): mod = _InteractionWrapper(Return().detect) assert mod.__wrapped__(returned) == returned - with pytest.raises(TypeError, - match="Incorrect function signature"): + with pytest.raises(TypeError, match="Incorrect function signature"): mod(returned) @@ -64,8 +59,7 @@ def fp_simple(self): def test_init(self, fp_simple): assert "Hydrophobic" in fp_simple.interactions.keys() - assert (hasattr(fp_simple, "hydrophobic") - and callable(fp_simple.hydrophobic)) + assert hasattr(fp_simple, "hydrophobic") and callable(fp_simple.hydrophobic) assert "Dummy" not in fp_simple.interactions.keys() assert hasattr(fp_simple, "dummy") and callable(fp_simple.dummy) assert "_BaseHBond" not in fp_simple.interactions.keys() @@ -91,31 +85,37 @@ def test_bitvector(self, fp): assert bv.sum() > 0 def test_bitvector_atoms(self, fp): - bv, lig_ix, prot_ix = fp.bitvector_atoms(ligand_mol, - protein_mol["ASP129.A"]) + bv, lig_ix, prot_ix = fp.bitvector_atoms(ligand_mol, protein_mol["ASP129.A"]) assert len(bv) == fp.n_interactions assert len(lig_ix) == fp.n_interactions assert len(prot_ix) == fp.n_interactions assert bv.sum() > 0 ids = np.where(bv == 1)[0] - assert (lig_ix[ids[0]] is not None and prot_ix[ids[0]] is not None) + assert lig_ix[ids[0]] is not None and prot_ix[ids[0]] is not None def test_run_residues(self, fp_simple): - fp_simple.run(u.trajectory[0:1], ligand_ag, protein_ag, - residues="all", progress=False) + fp_simple.run( + u.trajectory[0:1], ligand_ag, protein_ag, residues="all", progress=False + ) lig_id = ResidueId.from_string("LIG1.G") assert hasattr(fp_simple, "ifp") assert len(fp_simple.ifp) == 1 res = ResidueId.from_string("LYS387.B") assert (lig_id, res) in fp_simple.ifp[0].keys() - fp_simple.run(u.trajectory[1:2], ligand_ag, protein_ag, - residues=["ASP129.A"], progress=False) + fp_simple.run( + u.trajectory[1:2], + ligand_ag, + protein_ag, + residues=["ASP129.A"], + progress=False, + ) assert hasattr(fp_simple, "ifp") assert len(fp_simple.ifp) == 1 res = ResidueId.from_string("ASP129.A") assert (lig_id, res) in fp_simple.ifp[0].keys() - fp_simple.run(u.trajectory[:3], ligand_ag, protein_ag, - residues=None, progress=False) + fp_simple.run( + u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + ) assert hasattr(fp_simple, "ifp") assert len(fp_simple.ifp) == 3 assert len(fp_simple.ifp[0]) > 1 @@ -131,8 +131,9 @@ def test_generate(self, fp_simple): assert bv[0] is np.True_ def test_run(self, fp_simple): - fp_simple.run(u.trajectory[0:1], ligand_ag, protein_ag, - residues=None, progress=False) + fp_simple.run( + u.trajectory[0:1], ligand_ag, protein_ag, residues=None, progress=False + ) assert hasattr(fp_simple, "ifp") ifp = fp_simple.ifp[0] ifp.pop("Frame") @@ -150,27 +151,29 @@ def test_run_from_iterable(self, fp_simple): def test_to_df(self, fp_simple): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_dataframe() - fp_simple.run(u.trajectory[:3], ligand_ag, protein_ag, - residues=None, progress=False) + fp_simple.run( + u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + ) df = fp_simple.to_dataframe() assert isinstance(df, DataFrame) assert len(df) == 3 def test_to_df_kwargs(self, fp_simple): - fp_simple.run(u.trajectory[:3], ligand_ag, protein_ag, - residues=None, progress=False) + fp_simple.run( + u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + ) df = fp_simple.to_dataframe(dtype=np.uint8) assert df.dtypes[0].type is np.uint8 df = fp_simple.to_dataframe(drop_empty=False) - resids = set([key for d in fp_simple.ifp for key in d.keys() - if key != "Frame"]) + resids = set([key for d in fp_simple.ifp for key in d.keys() if key != "Frame"]) assert df.shape == (3, len(resids)) def test_to_bv(self, fp_simple): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_bitvectors() - fp_simple.run(u.trajectory[:3], ligand_ag, protein_ag, - residues=None, progress=False) + fp_simple.run( + u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + ) bvs = fp_simple.to_bitvectors() assert isinstance(bvs[0], ExplicitBitVect) assert len(bvs) == 3 @@ -224,8 +227,9 @@ def test_pickle(self, fp, fp_pkled): assert d1.keys() == d2.keys() d1.pop("Frame", None) d2.pop("Frame") - for (fp1, fpal1, fpap1), (fp2, fpal2, fpap2) in zip(d1.values(), - d2.values()): + for (fp1, fpal1, fpap1), (fp2, fpal2, fpap2) in zip( + d1.values(), d2.values() + ): assert (fp1 == fp2).all() assert fpal1 == fpal2 assert fpap1 == fpap2 @@ -235,11 +239,11 @@ def test_pickle_custom_interaction(self, fp_unpkl): assert callable(fp_unpkl.dummy) def test_run_multiproc_serial_same(self, fp): - fp.run(u.trajectory[0:100:10], ligand_ag, protein_ag, - n_jobs=1, progress=False) + fp.run(u.trajectory[0:100:10], ligand_ag, protein_ag, n_jobs=1, progress=False) serial = fp.to_dataframe() - fp.run(u.trajectory[0:100:10], ligand_ag, protein_ag, - n_jobs=None, progress=False) + fp.run( + u.trajectory[0:100:10], ligand_ag, protein_ag, n_jobs=None, progress=False + ) multi = fp.to_dataframe() assert serial.equals(multi) @@ -255,20 +259,26 @@ def test_run_iter_multiproc_serial_same(self, fp): def test_converter_kwargs_raises_error(self, fp: Fingerprint): with pytest.raises( - ValueError, - match="converter_kwargs must be a list of 2 dicts" + ValueError, match="converter_kwargs must be a list of 2 dicts" ): fp.run( - u.trajectory[0:5], ligand_ag, protein_ag, n_jobs=1, progress=False, - converter_kwargs=[dict(force=True)] + u.trajectory[0:5], + ligand_ag, + protein_ag, + n_jobs=1, + progress=False, + converter_kwargs=[dict(force=True)], ) - + @pytest.mark.parametrize("n_jobs", [1, 2]) def test_converter_kwargs(self, fp: Fingerprint, n_jobs: int): u = mda.Universe.from_smiles("O=C=O.O=C=O") lig, prot = u.atoms.fragments fp.run( - u.trajectory, lig, prot, n_jobs=n_jobs, - converter_kwargs=[dict(force=True), dict(force=True)] + u.trajectory, + lig, + prot, + n_jobs=n_jobs, + converter_kwargs=[dict(force=True), dict(force=True)], ) - assert fp.ifp \ No newline at end of file + assert fp.ifp diff --git a/tests/test_interactions.py b/tests/test_interactions.py index 5e06739..4bbd85c 100644 --- a/tests/test_interactions.py +++ b/tests/test_interactions.py @@ -6,8 +6,7 @@ from MDAnalysis.transformations import translate, rotateby from rdkit import RDLogger from prolif.fingerprint import Fingerprint -from prolif.interactions import (_INTERACTIONS, Interaction, VdWContact, - get_mapindex) +from prolif.interactions import _INTERACTIONS, Interaction, VdWContact, get_mapindex from rdkit import Chem, RDLogger from . import mol2factory @@ -24,7 +23,8 @@ benzene.transfer_to_memory() interaction_instances = { - name: cls() for name, cls in _INTERACTIONS.items() + name: cls() + for name, cls in _INTERACTIONS.items() if name not in ["Interaction", "_Distance"] } @@ -50,75 +50,80 @@ class TestInteractions: def fingerprint(self): return Fingerprint() - @pytest.mark.parametrize("func_name, lig_mol, prot_mol, expected", [ - ("cationic", "cation", "anion", True), - ("cationic", "anion", "cation", False), - ("cationic", "cation", "benzene", False), - ("anionic", "cation", "anion", False), - ("anionic", "anion", "cation", True), - ("anionic", "anion", "benzene", False), - ("cationpi", "cation", "benzene", True), - ("cationpi", "cation_false", "benzene", False), - ("cationpi", "benzene", "cation", False), - ("cationpi", "cation", "cation", False), - ("cationpi", "benzene", "benzene", False), - ("pication", "benzene", "cation", True), - ("pication", "benzene", "cation_false", False), - ("pication", "cation", "benzene", False), - ("pication", "cation", "cation", False), - ("pication", "benzene", "benzene", False), - ("pistacking", "benzene", "etf", True), - ("pistacking", "etf", "benzene", True), - ("pistacking", "ftf", "benzene", True), - ("pistacking", "benzene", "ftf", True), - ("facetoface", "benzene", "ftf", True), - ("facetoface", "ftf", "benzene", True), - ("facetoface", "benzene", "etf", False), - ("facetoface", "etf", "benzene", False), - ("edgetoface", "benzene", "etf", True), - ("edgetoface", "etf", "benzene", True), - ("edgetoface", "benzene", "ftf", False), - ("edgetoface", "ftf", "benzene", False), - ("hydrophobic", "benzene", "etf", True), - ("hydrophobic", "benzene", "ftf", True), - ("hydrophobic", "benzene", "chlorine", False), - ("hydrophobic", "benzene", "bromine", True), - ("hydrophobic", "benzene", "anion", False), - ("hydrophobic", "benzene", "cation", False), - ("hbdonor", "hb_donor", "hb_acceptor", True), - ("hbdonor", "hb_donor", "hb_acceptor_false", False), - ("hbdonor", "hb_acceptor", "hb_donor", False), - ("hbacceptor", "hb_acceptor", "hb_donor", True), - ("hbacceptor", "hb_acceptor_false", "hb_donor", False), - ("hbacceptor", "hb_donor", "hb_acceptor", False), - ("xbdonor", "xb_donor", "xb_acceptor", True), - ("xbdonor", "xb_donor", "xb_acceptor_false_xar", False), - ("xbdonor", "xb_donor", "xb_acceptor_false_axd", False), - ("xbdonor", "xb_acceptor", "xb_donor", False), - ("xbacceptor", "xb_acceptor", "xb_donor", True), - ("xbacceptor", "xb_acceptor_false_xar", "xb_donor", False), - ("xbacceptor", "xb_acceptor_false_axd", "xb_donor", False), - ("xbacceptor", "xb_donor", "xb_acceptor", False), - ("metaldonor", "metal", "ligand", True), - ("metaldonor", "metal_false", "ligand", False), - ("metaldonor", "ligand", "metal", False), - ("metalacceptor", "ligand", "metal", True), - ("metalacceptor", "ligand", "metal_false", False), - ("metalacceptor", "metal", "ligand", False), - ("vdwcontact", "benzene", "etf", True), - ("vdwcontact", "hb_acceptor", "metal_false", False), - ], indirect=["lig_mol", "prot_mol"]) + @pytest.mark.parametrize( + "func_name, lig_mol, prot_mol, expected", + [ + ("cationic", "cation", "anion", True), + ("cationic", "anion", "cation", False), + ("cationic", "cation", "benzene", False), + ("anionic", "cation", "anion", False), + ("anionic", "anion", "cation", True), + ("anionic", "anion", "benzene", False), + ("cationpi", "cation", "benzene", True), + ("cationpi", "cation_false", "benzene", False), + ("cationpi", "benzene", "cation", False), + ("cationpi", "cation", "cation", False), + ("cationpi", "benzene", "benzene", False), + ("pication", "benzene", "cation", True), + ("pication", "benzene", "cation_false", False), + ("pication", "cation", "benzene", False), + ("pication", "cation", "cation", False), + ("pication", "benzene", "benzene", False), + ("pistacking", "benzene", "etf", True), + ("pistacking", "etf", "benzene", True), + ("pistacking", "ftf", "benzene", True), + ("pistacking", "benzene", "ftf", True), + ("facetoface", "benzene", "ftf", True), + ("facetoface", "ftf", "benzene", True), + ("facetoface", "benzene", "etf", False), + ("facetoface", "etf", "benzene", False), + ("edgetoface", "benzene", "etf", True), + ("edgetoface", "etf", "benzene", True), + ("edgetoface", "benzene", "ftf", False), + ("edgetoface", "ftf", "benzene", False), + ("hydrophobic", "benzene", "etf", True), + ("hydrophobic", "benzene", "ftf", True), + ("hydrophobic", "benzene", "chlorine", False), + ("hydrophobic", "benzene", "bromine", True), + ("hydrophobic", "benzene", "anion", False), + ("hydrophobic", "benzene", "cation", False), + ("hbdonor", "hb_donor", "hb_acceptor", True), + ("hbdonor", "hb_donor", "hb_acceptor_false", False), + ("hbdonor", "hb_acceptor", "hb_donor", False), + ("hbacceptor", "hb_acceptor", "hb_donor", True), + ("hbacceptor", "hb_acceptor_false", "hb_donor", False), + ("hbacceptor", "hb_donor", "hb_acceptor", False), + ("xbdonor", "xb_donor", "xb_acceptor", True), + ("xbdonor", "xb_donor", "xb_acceptor_false_xar", False), + ("xbdonor", "xb_donor", "xb_acceptor_false_axd", False), + ("xbdonor", "xb_acceptor", "xb_donor", False), + ("xbacceptor", "xb_acceptor", "xb_donor", True), + ("xbacceptor", "xb_acceptor_false_xar", "xb_donor", False), + ("xbacceptor", "xb_acceptor_false_axd", "xb_donor", False), + ("xbacceptor", "xb_donor", "xb_acceptor", False), + ("metaldonor", "metal", "ligand", True), + ("metaldonor", "metal_false", "ligand", False), + ("metaldonor", "ligand", "metal", False), + ("metalacceptor", "ligand", "metal", True), + ("metalacceptor", "ligand", "metal_false", False), + ("metalacceptor", "metal", "ligand", False), + ("vdwcontact", "benzene", "etf", True), + ("vdwcontact", "hb_acceptor", "metal_false", False), + ], + indirect=["lig_mol", "prot_mol"], + ) def test_interaction(self, fingerprint, func_name, lig_mol, prot_mol, expected): interaction = getattr(fingerprint, func_name) assert interaction(lig_mol, prot_mol) is expected def test_warning_supersede(self): old = id(_INTERACTIONS["Hydrophobic"]) - with pytest.warns(UserWarning, - match="interaction has been superseded"): + with pytest.warns(UserWarning, match="interaction has been superseded"): + class Hydrophobic(Interaction): def detect(self): pass + new = id(_INTERACTIONS["Hydrophobic"]) assert old != new # fix dummy Hydrophobic class being reused in later unrelated tests @@ -129,25 +134,22 @@ class Hydrophobic(prolif.interactions.Hydrophobic): def test_error_no_detect(self): class _Dummy(Interaction): pass - with pytest.raises(TypeError, - match="Can't instantiate abstract class _Dummy"): + + with pytest.raises(TypeError, match="Can't instantiate abstract class _Dummy"): _Dummy() - @pytest.mark.parametrize("index", [ - 0, 1, 3, 42, 78 - ]) + @pytest.mark.parametrize("index", [0, 1, 3, 42, 78]) def test_get_mapindex(self, index): parent_index = get_mapindex(ligand_mol[0], index) assert parent_index == index def test_vdwcontact_tolerance_error(self): - with pytest.raises(ValueError, - match="`tolerance` must be 0 or positive"): + with pytest.raises(ValueError, match="`tolerance` must be 0 or positive"): VdWContact(tolerance=-1) - @pytest.mark.parametrize("lig_mol, prot_mol", [ - ("benzene", "cation") - ], indirect=["lig_mol", "prot_mol"]) + @pytest.mark.parametrize( + "lig_mol, prot_mol", [("benzene", "cation")], indirect=["lig_mol", "prot_mol"] + ) def test_vdwcontact_cache(self, lig_mol, prot_mol): vdw = VdWContact() assert vdw._vdw_cache == {} @@ -156,101 +158,109 @@ def test_vdwcontact_cache(self, lig_mol, prot_mol): vdw_dist = vdwradii[lig] + vdwradii[res] + vdw.tolerance assert vdw_dist == value - @pytest.mark.parametrize(["interaction_qmol", "smiles", "expected"], [ - ("Hydrophobic.lig_pattern", "C", 1), - ("Hydrophobic.lig_pattern", "C=[SH2]", 1), - ("Hydrophobic.lig_pattern", "c1cscc1", 5), - ("Hydrophobic.lig_pattern", "CSC", 3), - ("Hydrophobic.lig_pattern", "CS(C)(C)C", 4), - ("Hydrophobic.lig_pattern", "FC(F)(F)F", 0), - ("Hydrophobic.lig_pattern", "BrI", 2), - ("Hydrophobic.lig_pattern", "C=O", 0), - ("Hydrophobic.lig_pattern", "C=N", 0), - ("Hydrophobic.lig_pattern", "CF", 0), - ("_BaseHBond.donor", "[OH2]", 2), - ("_BaseHBond.donor", "[NH3]", 3), - ("_BaseHBond.donor", "[NH4+]", 4), - ("_BaseHBond.donor", "[SH2]", 2), - ("_BaseHBond.donor", "O=C=O", 0), - ("_BaseHBond.donor", "c1c[nH+]ccc1", 0), - ("_BaseHBond.donor", "c1c[nH]cc1", 1), - ("_BaseHBond.acceptor", "O", 1), - ("_BaseHBond.acceptor", "N", 1), - ("_BaseHBond.acceptor", "[NH4+]", 0), - ("_BaseHBond.acceptor", "N-C=O", 1), - ("_BaseHBond.acceptor", "N-C=[SH2]", 0), - ("_BaseHBond.acceptor", "[nH+]1ccccc1", 0), - ("_BaseHBond.acceptor", "n1ccccc1", 1), - ("_BaseHBond.acceptor", "Nc1ccccc1", 0), - ("_BaseHBond.acceptor", "o1cccc1", 1), - ("_BaseHBond.acceptor", "COC=O", 1), - ("_BaseHBond.acceptor", "c1ccccc1Oc1ccccc1", 0), - ("_BaseHBond.acceptor", "FC", 1), - ("_BaseHBond.acceptor", "Fc1ccccc1", 1), - ("_BaseHBond.acceptor", "FCF", 0), - ("_BaseXBond.donor", "CCl", 1), - ("_BaseXBond.donor", "c1ccccc1Cl", 1), - ("_BaseXBond.donor", "NCl", 1), - ("_BaseXBond.donor", "c1cccc[n+]1Cl", 1), - ("_BaseXBond.acceptor", "[NH3]", 3), - ("_BaseXBond.acceptor", "[NH+]C", 0), - ("_BaseXBond.acceptor", "c1ccccc1", 12), - ("Cationic.lig_pattern", "[NH4+]", 1), - ("Cationic.lig_pattern", "[Ca+2]", 1), - ("Cationic.lig_pattern", "CC(=[NH2+])N", 2), - ("Cationic.lig_pattern", "NC(=[NH2+])N", 3), - ("Cationic.prot_pattern", "[Cl-]", 1), - ("Cationic.prot_pattern", "CC(=O)[O-]", 2), - ("Cationic.prot_pattern", "CS(=O)[O-]", 2), - ("Cationic.prot_pattern", "CP(=O)[O-]", 2), - ("_BaseCationPi.cation", "[NH4+]", 1), - ("_BaseCationPi.cation", "[Ca+2]", 1), - ("_BaseCationPi.cation", "CC(=[NH2+])N", 2), - ("_BaseCationPi.cation", "NC(=[NH2+])N", 3), - ("_BaseCationPi.pi_ring", "c1ccccc1", 1), - ("_BaseCationPi.pi_ring", "c1cocc1", 1), - ("EdgeToFace.pi_ring", "c1ccccc1", 1), - ("EdgeToFace.pi_ring", "c1cocc1", 1), - ("FaceToFace.pi_ring", "c1ccccc1", 1), - ("FaceToFace.pi_ring", "c1cocc1", 1), - ("_BaseMetallic.lig_pattern", "[Mg]", 1), - ("_BaseMetallic.prot_pattern", "O", 1), - ("_BaseMetallic.prot_pattern", "N", 1), - ("_BaseMetallic.prot_pattern", "[NH+]", 0), - ("_BaseMetallic.prot_pattern", "N-C=[SH2]", 0), - ("_BaseMetallic.prot_pattern", "[nH+]1ccccc1", 0), - ("_BaseMetallic.prot_pattern", "Nc1ccccc1", 0), - ("_BaseMetallic.prot_pattern", "o1cccc1", 0), - ("_BaseMetallic.prot_pattern", "COC=O", 2), - ], indirect=["interaction_qmol"]) + @pytest.mark.parametrize( + ["interaction_qmol", "smiles", "expected"], + [ + ("Hydrophobic.lig_pattern", "C", 1), + ("Hydrophobic.lig_pattern", "C=[SH2]", 1), + ("Hydrophobic.lig_pattern", "c1cscc1", 5), + ("Hydrophobic.lig_pattern", "CSC", 3), + ("Hydrophobic.lig_pattern", "CS(C)(C)C", 4), + ("Hydrophobic.lig_pattern", "FC(F)(F)F", 0), + ("Hydrophobic.lig_pattern", "BrI", 2), + ("Hydrophobic.lig_pattern", "C=O", 0), + ("Hydrophobic.lig_pattern", "C=N", 0), + ("Hydrophobic.lig_pattern", "CF", 0), + ("_BaseHBond.donor", "[OH2]", 2), + ("_BaseHBond.donor", "[NH3]", 3), + ("_BaseHBond.donor", "[NH4+]", 4), + ("_BaseHBond.donor", "[SH2]", 2), + ("_BaseHBond.donor", "O=C=O", 0), + ("_BaseHBond.donor", "c1c[nH+]ccc1", 0), + ("_BaseHBond.donor", "c1c[nH]cc1", 1), + ("_BaseHBond.acceptor", "O", 1), + ("_BaseHBond.acceptor", "N", 1), + ("_BaseHBond.acceptor", "[NH4+]", 0), + ("_BaseHBond.acceptor", "N-C=O", 1), + ("_BaseHBond.acceptor", "N-C=[SH2]", 0), + ("_BaseHBond.acceptor", "[nH+]1ccccc1", 0), + ("_BaseHBond.acceptor", "n1ccccc1", 1), + ("_BaseHBond.acceptor", "Nc1ccccc1", 0), + ("_BaseHBond.acceptor", "o1cccc1", 1), + ("_BaseHBond.acceptor", "COC=O", 1), + ("_BaseHBond.acceptor", "c1ccccc1Oc1ccccc1", 0), + ("_BaseHBond.acceptor", "FC", 1), + ("_BaseHBond.acceptor", "Fc1ccccc1", 1), + ("_BaseHBond.acceptor", "FCF", 0), + ("_BaseXBond.donor", "CCl", 1), + ("_BaseXBond.donor", "c1ccccc1Cl", 1), + ("_BaseXBond.donor", "NCl", 1), + ("_BaseXBond.donor", "c1cccc[n+]1Cl", 1), + ("_BaseXBond.acceptor", "[NH3]", 3), + ("_BaseXBond.acceptor", "[NH+]C", 0), + ("_BaseXBond.acceptor", "c1ccccc1", 12), + ("Cationic.lig_pattern", "[NH4+]", 1), + ("Cationic.lig_pattern", "[Ca+2]", 1), + ("Cationic.lig_pattern", "CC(=[NH2+])N", 2), + ("Cationic.lig_pattern", "NC(=[NH2+])N", 3), + ("Cationic.prot_pattern", "[Cl-]", 1), + ("Cationic.prot_pattern", "CC(=O)[O-]", 2), + ("Cationic.prot_pattern", "CS(=O)[O-]", 2), + ("Cationic.prot_pattern", "CP(=O)[O-]", 2), + ("_BaseCationPi.cation", "[NH4+]", 1), + ("_BaseCationPi.cation", "[Ca+2]", 1), + ("_BaseCationPi.cation", "CC(=[NH2+])N", 2), + ("_BaseCationPi.cation", "NC(=[NH2+])N", 3), + ("_BaseCationPi.pi_ring", "c1ccccc1", 1), + ("_BaseCationPi.pi_ring", "c1cocc1", 1), + ("EdgeToFace.pi_ring", "c1ccccc1", 1), + ("EdgeToFace.pi_ring", "c1cocc1", 1), + ("FaceToFace.pi_ring", "c1ccccc1", 1), + ("FaceToFace.pi_ring", "c1cocc1", 1), + ("_BaseMetallic.lig_pattern", "[Mg]", 1), + ("_BaseMetallic.prot_pattern", "O", 1), + ("_BaseMetallic.prot_pattern", "N", 1), + ("_BaseMetallic.prot_pattern", "[NH+]", 0), + ("_BaseMetallic.prot_pattern", "N-C=[SH2]", 0), + ("_BaseMetallic.prot_pattern", "[nH+]1ccccc1", 0), + ("_BaseMetallic.prot_pattern", "Nc1ccccc1", 0), + ("_BaseMetallic.prot_pattern", "o1cccc1", 0), + ("_BaseMetallic.prot_pattern", "COC=O", 2), + ], + indirect=["interaction_qmol"], + ) def test_smarts_matches(self, interaction_qmol, smiles, expected): mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) if isinstance(interaction_qmol, list): - n_matches = sum(len(mol.GetSubstructMatches(qmol)) - for qmol in interaction_qmol) + n_matches = sum( + len(mol.GetSubstructMatches(qmol)) for qmol in interaction_qmol + ) else: n_matches = len(mol.GetSubstructMatches(interaction_qmol)) assert n_matches == expected - @pytest.mark.parametrize(["xyz", "rotation", "pi_type", "expected"], [ - ([0, 2.5, 4.0], [0, 0, 0], "facetoface", True), - ([0, 3, 4.5], [0, 0, 0], "facetoface", False), - ([0, 2, 4.5], [30, 0, 0], "facetoface", True), - ([0, 2, 4.5], [150, 0, 0], "facetoface", True), - ([0, 2, -4.5], [30, 0, 0], "facetoface", True), - ([0, 2, -4.5], [150, 0, 0], "facetoface", True), - ([1, 1.5, 3.5], [30, 15, 80], "facetoface", True), - ([1, 2.5, 4.5], [30, 15, 65], "facetoface", True), - ([0, 1.5, 4.5], [60, 0, 0], "edgetoface", True), - ([0, 2, 5], [60, 0, 0], "edgetoface", True), - ([0, 1.5, 4.5], [90, 0, 0], "edgetoface", True), - ([0, 1.5, -4.5], [90, 0, 0], "edgetoface", True), - ([0, 6, -.5], [110, 0, 0], "edgetoface", True), - ([0, 4.5, -.5], [105, 0, 0], "edgetoface", True), - ([0, 1.5, 4.5], [105, 0, 0], "edgetoface", False), - ([0, 1.5, -4.5], [75, 0, 0], "edgetoface", False), - ]) + @pytest.mark.parametrize( + ["xyz", "rotation", "pi_type", "expected"], + [ + ([0, 2.5, 4.0], [0, 0, 0], "facetoface", True), + ([0, 3, 4.5], [0, 0, 0], "facetoface", False), + ([0, 2, 4.5], [30, 0, 0], "facetoface", True), + ([0, 2, 4.5], [150, 0, 0], "facetoface", True), + ([0, 2, -4.5], [30, 0, 0], "facetoface", True), + ([0, 2, -4.5], [150, 0, 0], "facetoface", True), + ([1, 1.5, 3.5], [30, 15, 80], "facetoface", True), + ([1, 2.5, 4.5], [30, 15, 65], "facetoface", True), + ([0, 1.5, 4.5], [60, 0, 0], "edgetoface", True), + ([0, 2, 5], [60, 0, 0], "edgetoface", True), + ([0, 1.5, 4.5], [90, 0, 0], "edgetoface", True), + ([0, 1.5, -4.5], [90, 0, 0], "edgetoface", True), + ([0, 6, -0.5], [110, 0, 0], "edgetoface", True), + ([0, 4.5, -0.5], [105, 0, 0], "edgetoface", True), + ([0, 1.5, 4.5], [105, 0, 0], "edgetoface", False), + ([0, 1.5, -4.5], [75, 0, 0], "edgetoface", False), + ], + ) def test_pi_stacking(self, xyz, rotation, pi_type, expected, fingerprint): r1, r2 = self.create_rings(xyz, rotation) assert getattr(fingerprint, pi_type)(r1, r2) is expected @@ -264,9 +274,9 @@ def create_rings(xyz, rotation): r2 = benzene.copy() r2.segments.segids = np.array(["U2"], dtype=object) tr = translate(xyz) - rotx = rotateby(rotation[0], [1,0,0], ag=r2.atoms) - roty = rotateby(rotation[1], [0,1,0], ag=r2.atoms) - rotz = rotateby(rotation[2], [0,0,1], ag=r2.atoms) + rotx = rotateby(rotation[0], [1, 0, 0], ag=r2.atoms) + roty = rotateby(rotation[1], [0, 1, 0], ag=r2.atoms) + rotz = rotateby(rotation[2], [0, 0, 1], ag=r2.atoms) r2.trajectory.add_transformations(tr, rotx, roty, rotz) return prolif.Molecule.from_mda(benzene), prolif.Molecule.from_mda(r2) diff --git a/tests/test_molecule.py b/tests/test_molecule.py index ae119dc..f7e5488 100644 --- a/tests/test_molecule.py +++ b/tests/test_molecule.py @@ -3,8 +3,7 @@ from MDAnalysis import SelectionError from numpy.testing import assert_array_equal from prolif.datafiles import datapath -from prolif.molecule import (Molecule, mol2_supplier, pdbqt_supplier, - sdf_supplier) +from prolif.molecule import Molecule, mol2_supplier, pdbqt_supplier, sdf_supplier from prolif.residue import ResidueId from rdkit import Chem @@ -24,14 +23,15 @@ def test_from_mda(self): rdkit_mol = Molecule(ligand_rdkit) mda_mol = Molecule.from_mda(u, "resname LIG") assert rdkit_mol[0].resid == mda_mol[0].resid - assert (rdkit_mol.HasSubstructMatch(mda_mol) and - mda_mol.HasSubstructMatch(rdkit_mol)) + assert rdkit_mol.HasSubstructMatch(mda_mol) and mda_mol.HasSubstructMatch( + rdkit_mol + ) def test_from_mda_empty_ag(self): ag = u.select_atoms("resname FOO") with pytest.raises(SelectionError, match="AtomGroup is empty"): Molecule.from_mda(ag) - + def test_from_rdkit(self): rdkit_mol = Molecule(ligand_rdkit) newmol = Molecule.from_rdkit(ligand_rdkit) @@ -47,13 +47,7 @@ def test_from_rdkit_resid_args(self): newmol = Molecule.from_rdkit(mol, "FOO", 42, "A") assert newmol[0].resid == ResidueId("FOO", 42, "A") - @pytest.mark.parametrize("key", [ - 0, - 42, - -1, - "LYS49.A", - ResidueId("LYS", 49, "A") - ]) + @pytest.mark.parametrize("key", [0, 42, -1, "LYS49.A", ResidueId("LYS", 49, "A")]) def test_getitem(self, mol, key): assert mol[key].resid is mol.residues[key].resid @@ -94,9 +88,11 @@ class TestPDBQTSupplier(SupplierBase): def suppl(self): path = datapath / "vina" pdbqts = sorted(path.glob("*.pdbqt")) - template = Chem.MolFromSmiles("C[NH+]1CC(C(=O)NC2(C)OC3(O)C4CCCN4C(=O)" - "C(Cc4ccccc4)N3C2=O)C=C2c3cccc4[nH]cc" - "(c34)CC21") + template = Chem.MolFromSmiles( + "C[NH+]1CC(C(=O)NC2(C)OC3(O)C4CCCN4C(=O)" + "C(Cc4ccccc4)N3C2=O)C=C2c3cccc4[nH]cc" + "(c34)CC21" + ) return pdbqt_supplier(pdbqts, template) def test_pdbqt_hydrogens_stay_in_mol(self): @@ -119,7 +115,8 @@ def test_pdbqt_hydrogens_stay_in_mol(self): pdbqt_mol = rwmol.GetMol() mol = pdbqt_supplier._adjust_hydrogens(template, pdbqt_mol) hydrogens = [ - idx for atom in mol.GetAtoms() + idx + for atom in mol.GetAtoms() if atom.HasProp("_MDAnalysis_index") and (idx := atom.GetIntProp("_MDAnalysis_index")) in indices ] diff --git a/tests/test_residues.py b/tests/test_residues.py index c41ada1..235a6e3 100644 --- a/tests/test_residues.py +++ b/tests/test_residues.py @@ -8,20 +8,23 @@ class TestResidueId: - @pytest.mark.parametrize("name, number, chain", [ - ("ALA", None, None), - ("ALA", 1, None), - ("ALA", 0, None), - ("ALA", None, "B"), - ("ALA", 1, "B"), - (None, 1, "B"), - (None, None, "B"), - (None, 1, None), - (None, None, None), - ("", None, None), - (None, None, ""), - ("", None, ""), - ]) + @pytest.mark.parametrize( + "name, number, chain", + [ + ("ALA", None, None), + ("ALA", 1, None), + ("ALA", 0, None), + ("ALA", None, "B"), + ("ALA", 1, "B"), + (None, 1, "B"), + (None, None, "B"), + (None, 1, None), + (None, None, None), + ("", None, None), + (None, None, ""), + ("", None, ""), + ], + ) def test_init(self, name, number, chain): resid = ResidueId(name, number, chain) name = name or "UNK" @@ -31,21 +34,24 @@ def test_init(self, name, number, chain): assert resid.number == number assert resid.chain == chain - @pytest.mark.parametrize("name, number, chain", [ - ("ALA", None, None), - ("ALA", 1, None), - ("ALA", 0, None), - ("ALA", None, "B"), - ("ALA", 1, "B"), - ("DA", 1, None), - (None, 1, "B"), - (None, None, "B"), - (None, 1, None), - (None, None, None), - ("", None, None), - (None, None, ""), - ("", None, ""), - ]) + @pytest.mark.parametrize( + "name, number, chain", + [ + ("ALA", None, None), + ("ALA", 1, None), + ("ALA", 0, None), + ("ALA", None, "B"), + ("ALA", 1, "B"), + ("DA", 1, None), + (None, 1, "B"), + (None, None, "B"), + (None, 1, None), + (None, None, None), + ("", None, None), + (None, None, ""), + ("", None, ""), + ], + ) def test_from_atom(self, name, number, chain): atom = Chem.Atom(1) mi = Chem.AtomPDBResidueInfo() @@ -71,20 +77,23 @@ def test_from_atom_no_mi(self): assert resid.number is 0 assert resid.chain is None - @pytest.mark.parametrize("resid_str, expected", [ - ("ALA", ("ALA", 0, None)), - ("ALA1", ("ALA", 1, None)), - ("ALA.B", ("ALA", 0, "B")), - ("ALA1.B", ("ALA", 1, "B")), - ("1.B", ("UNK", 1, "B")), - (".B", ("UNK", 0, "B")), - (".0", ("UNK", 0, "0")), - ("1", ("UNK", 1, None)), - ("", ("UNK", 0, None)), - ("DA2.A", ("DA", 2, "A")), - ("DA2", ("DA", 2, None)), - ("DA", ("DA", 0, None)), - ]) + @pytest.mark.parametrize( + "resid_str, expected", + [ + ("ALA", ("ALA", 0, None)), + ("ALA1", ("ALA", 1, None)), + ("ALA.B", ("ALA", 0, "B")), + ("ALA1.B", ("ALA", 1, "B")), + ("1.B", ("UNK", 1, "B")), + (".B", ("UNK", 0, "B")), + (".0", ("UNK", 0, "0")), + ("1", ("UNK", 1, None)), + ("", ("UNK", 0, None)), + ("DA2.A", ("DA", 2, "A")), + ("DA2", ("DA", 2, None)), + ("DA", ("DA", 0, None)), + ], + ) def test_from_string(self, resid_str, expected): resid = ResidueId.from_string(resid_str) assert resid == ResidueId(*expected) @@ -95,22 +104,20 @@ def test_eq(self): res2 = ResidueId(name, number, chain) assert res1 == res2 - @pytest.mark.parametrize("res1, res2", [ - ("ALA1.A", "ALA1.B"), - ("ALA2.A", "ALA3.A"), - ("ALA4.A", "ALA1.B"), - ]) + @pytest.mark.parametrize( + "res1, res2", + [ + ("ALA1.A", "ALA1.B"), + ("ALA2.A", "ALA3.A"), + ("ALA4.A", "ALA1.B"), + ], + ) def test_lt(self, res1, res2): res1 = ResidueId.from_string(res1) res2 = ResidueId.from_string(res2) assert res1 < res2 - @pytest.mark.parametrize("resid_str", [ - "ALA1.A", - "DA2.B", - "HIS3", - "UNK0" - ]) + @pytest.mark.parametrize("resid_str", ["ALA1.A", "DA2.B", "HIS3", "UNK0"]) def test_repr(self, resid_str): resid = ResidueId.from_string(resid_str) expected = f"ResidueId({resid.name}, {resid.number}, {resid.chain})" @@ -139,7 +146,9 @@ class TestResidueGroup: def residues(self): sequence = "ARNDCQEGHILKMFPSTWYV" protein = Chem.MolFromSequence(sequence) - residues = [Residue(res) for res in Chem.SplitMolByPDBResidues(protein).values()] + residues = [ + Residue(res) for res in Chem.SplitMolByPDBResidues(protein).values() + ] return residues def test_init(self, residues): @@ -149,8 +158,7 @@ def test_init(self, residues): for (resid, rg_res), res in zip(rg.items(), residues): assert rg_res is res assert resid is rg_res.resid - resinfo = [(r.resid.name, r.resid.number, r.resid.chain) - for r in residues] + resinfo = [(r.resid.name, r.resid.number, r.resid.chain) for r in residues] name, number, chain = zip(*resinfo) assert_equal(rg.name, name) assert_equal(rg.number, number) @@ -169,14 +177,17 @@ def test_n_residues(self, residues): assert rg.n_residues == len(rg) assert rg.n_residues == 20 - @pytest.mark.parametrize("ix, resid, resid_str", [ - (0, ("ALA", 1, "A"), "ALA1.A"), - (4, ("CYS", 5, "A"), "CYS5.A"), - (6, ("GLU", 7, "A"), "GLU7.A"), - (9, ("ILE", 10, "A"), "ILE10.A"), - (19, ("VAL", 20, "A"), "VAL20.A"), - (-1, ("VAL", 20, "A"), "VAL20.A"), - ]) + @pytest.mark.parametrize( + "ix, resid, resid_str", + [ + (0, ("ALA", 1, "A"), "ALA1.A"), + (4, ("CYS", 5, "A"), "CYS5.A"), + (6, ("GLU", 7, "A"), "GLU7.A"), + (9, ("ILE", 10, "A"), "ILE10.A"), + (19, ("VAL", 20, "A"), "VAL20.A"), + (-1, ("VAL", 20, "A"), "VAL20.A"), + ], + ) def test_getitem(self, residues, ix, resid, resid_str): rg = ResidueGroup(residues) resid = ResidueId(*resid) @@ -185,11 +196,9 @@ def test_getitem(self, residues, ix, resid, resid_str): def test_getitem_keyerror(self): rg = ResidueGroup([]) - with pytest.raises(KeyError, - match="Expected a ResidueId, int, or str"): + with pytest.raises(KeyError, match="Expected a ResidueId, int, or str"): rg[True] - with pytest.raises(KeyError, - match="Expected a ResidueId, int, or str"): + with pytest.raises(KeyError, match="Expected a ResidueId, int, or str"): rg[1.5] def test_select(self): @@ -206,7 +215,7 @@ def test_select(self): assert rg.select((rg.chain == "B") ^ (rg.name == "ALA")).n_residues == 103 # not assert rg.select(~(rg.chain == "B")).n_residues == 212 - + def test_select_sameas_getitem(self): rg = protein_mol.residues sel = rg.select((rg.name == "LYS") & (rg.number == 49))[0] diff --git a/tests/test_utils.py b/tests/test_utils.py index 9628aae..f87f9bd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,38 +5,44 @@ import pytest from numpy.testing import assert_equal from prolif.residue import Residue, ResidueGroup, ResidueId -from prolif.utils import (angle_between_limits, get_centroid, - get_residues_near_ligand, is_peptide_bond, - pandas_series_to_bv, split_mol_by_residues, - to_bitvectors, to_dataframe) +from prolif.utils import ( + angle_between_limits, + get_centroid, + get_residues_near_ligand, + is_peptide_bond, + pandas_series_to_bv, + split_mol_by_residues, + to_bitvectors, + to_dataframe, +) from rdkit import Chem from .test_base import ligand_mol, protein_mol def test_centroid(): - xyz = np.array([(0, 0, 0), - (0, 0, 0), - (0, 0, 0), - (2, 2, 2), - (2, 2, 2), - (2, 2, 2)], - dtype=np.float32) + xyz = np.array( + [(0, 0, 0), (0, 0, 0), (0, 0, 0), (2, 2, 2), (2, 2, 2), (2, 2, 2)], + dtype=np.float32, + ) ctd = get_centroid(xyz) assert ctd.shape == (3,) assert_equal(ctd, [1, 1, 1]) -@pytest.mark.parametrize("angle, mina, maxa, ring, expected", [ - (0, 0, 30, False, True), - (30, 0, 30, False, True), - (10, 0, 30, False, True), - (60, 0, 30, False, False), - (150, 0, 30, False, False), - (150, 0, 30, True, True), - (60, 0, 30, True, False), - (120, 0, 30, True, False), -]) +@pytest.mark.parametrize( + "angle, mina, maxa, ring, expected", + [ + (0, 0, 30, False, True), + (30, 0, 30, False, True), + (10, 0, 30, False, True), + (60, 0, 30, False, False), + (150, 0, 30, False, False), + (150, 0, 30, True, True), + (60, 0, 30, True, False), + (120, 0, 30, True, False), + ], +) def test_angle_limits(angle, mina, maxa, ring, expected): angle = radians(angle) mina = radians(mina) @@ -46,17 +52,60 @@ def test_angle_limits(angle, mina, maxa, ring, expected): def test_pocket_residues(): resids = get_residues_near_ligand(ligand_mol, protein_mol) - residues = ["TYR38.A", "TYR40.A", "GLN41.A", "VAL102.A", "SER106.A", - "TYR109.A", "THR110.A", "TRP115.A", "TRP125.A", "LEU126.A", - "ASP129.A", "ILE130.A", "THR131.A", "CYS133.A", "THR134.A", - "ILE137.A", "ILE180.A", "GLU198.A", "CYS199.A", "VAL200.A", - "VAL201.A", "ASN202.A", "THR203.A", "TYR208.A", "THR209.A", - "VAL210.A", "TYR211.A", "SER212.A", "THR213.A", "VAL214.A", - "GLY215.A", "ALA216.A", "PHE217.A", "TRP327.B", "PHE330.B", - "PHE331.B", "ILE333.B", "SER334.B", "LEU335.B", "MET337.B", - "PRO338.B", "LEU348.B", "ALA349.B", "ILE350.B", "PHE351.B", - "ASP352.B", "PHE353.B", "PHE354.B", "THR355.B", "TRP356.B", - "GLY358.B", "TYR359.B"] + residues = [ + "TYR38.A", + "TYR40.A", + "GLN41.A", + "VAL102.A", + "SER106.A", + "TYR109.A", + "THR110.A", + "TRP115.A", + "TRP125.A", + "LEU126.A", + "ASP129.A", + "ILE130.A", + "THR131.A", + "CYS133.A", + "THR134.A", + "ILE137.A", + "ILE180.A", + "GLU198.A", + "CYS199.A", + "VAL200.A", + "VAL201.A", + "ASN202.A", + "THR203.A", + "TYR208.A", + "THR209.A", + "VAL210.A", + "TYR211.A", + "SER212.A", + "THR213.A", + "VAL214.A", + "GLY215.A", + "ALA216.A", + "PHE217.A", + "TRP327.B", + "PHE330.B", + "PHE331.B", + "ILE333.B", + "SER334.B", + "LEU335.B", + "MET337.B", + "PRO338.B", + "LEU348.B", + "ALA349.B", + "ILE350.B", + "PHE351.B", + "ASP352.B", + "PHE353.B", + "PHE354.B", + "THR355.B", + "TRP356.B", + "GLY358.B", + "TYR359.B", + ] for res in residues: r = ResidueId.from_string(res) assert r in resids @@ -65,8 +114,9 @@ def test_pocket_residues(): def test_split_residues(): sequence = "ARNDCQEGHILKMFPSTWYV" prot = Chem.MolFromSequence(sequence) - rg = ResidueGroup([Residue(res) - for res in Chem.SplitMolByPDBResidues(prot).values()]) + rg = ResidueGroup( + [Residue(res) for res in Chem.SplitMolByPDBResidues(prot).values()] + ) residues = [Residue(mol) for mol in split_mol_by_residues(prot)] residues.sort(key=lambda x: x.resid) for molres, res in zip(residues, rg.values()): @@ -97,19 +147,31 @@ def test_series_to_bv(): assert bv.GetNumOnBits() == 3 -ifp = [{"Frame": 0, +ifp = [ + { + "Frame": 0, ("LIG", "ALA1"): np.array([True, False, False]), - ("LIG", "GLU2"): np.array([False, True, False])}, - {"Frame": 1, + ("LIG", "GLU2"): np.array([False, True, False]), + }, + { + "Frame": 1, ("LIG", "ALA1"): np.array([True, True, False]), - ("LIG", "ASP3"): np.array([False, True, False])}] - -ifp_atoms = [{"Frame": 0, - ("LIG", "ALA1"): [[1, 0, 0], [0, None, None], [1, None, None]], - ("LIG", "GLU2"): [[0, 1, 0], [None, 1, None], [None, 3, None]]}, - {"Frame": 1, - ("LIG", "ALA1"): [[1, 1, 0], [2, 2, None], [4, 5, None]], - ("LIG", "ASP3"): [[0, 1, 0], [None, 8, None], [None, 10, None]]}] + ("LIG", "ASP3"): np.array([False, True, False]), + }, +] + +ifp_atoms = [ + { + "Frame": 0, + ("LIG", "ALA1"): [[1, 0, 0], [0, None, None], [1, None, None]], + ("LIG", "GLU2"): [[0, 1, 0], [None, 1, None], [None, 3, None]], + }, + { + "Frame": 1, + ("LIG", "ALA1"): [[1, 1, 0], [2, 2, None], [4, 5, None]], + ("LIG", "ASP3"): [[0, 1, 0], [None, 8, None], [None, 10, None]], + }, +] def test_to_df(): @@ -141,9 +203,14 @@ def test_to_df_atom_pairs(): assert df[("LIG", "ASP3", "B")][0] == (None, None) -@pytest.mark.parametrize("dtype", [ - np.uint8, np.int16, np.bool_, -]) +@pytest.mark.parametrize( + "dtype", + [ + np.uint8, + np.int16, + np.bool_, + ], +) def test_to_df_dtype(dtype): df = to_dataframe(ifp, ["A", "B", "C"], dtype=dtype) assert df.dtypes[0].type is dtype @@ -159,8 +226,7 @@ def test_to_df_drop_empty(): def test_to_df_raise_dtype_return_atoms(): with pytest.raises( - ValueError, - match="`dtype` cannot be used with `return_atoms=True`" + ValueError, match="`dtype` cannot be used with `return_atoms=True`" ): to_dataframe(ifp_atoms, ["A", "B", "C"], dtype=int, return_atoms=True) diff --git a/tests/test_wrapper_docs.py b/tests/test_wrapper_docs.py index f2655f5..becdde6 100644 --- a/tests/test_wrapper_docs.py +++ b/tests/test_wrapper_docs.py @@ -2,8 +2,7 @@ import pytest from prolif.fingerprint import Fingerprint, _Docstring -interaction_list = [i for i in Fingerprint.list_available() - if "Dummy" not in i] +interaction_list = [i for i in Fingerprint.list_available() if "Dummy" not in i] class Wrapper: @@ -13,6 +12,7 @@ class Wrapper: class Dummy: """Dummy class docs""" + def do_something(self): """Method docstring""" return 1 From c22d939b102a886606bfe95225835a6dda2a00be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 20:21:57 +0100 Subject: [PATCH 02/11] add missing entries in changelog --- CHANGELOG.md | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc9d03a..543081c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ## [1.1.0] - 2022-11-XX + ### Added - `Fingerprint.run` now has a `converter_kwargs` parameter that can pass kwargs to the underlying RDKitConverter from MDAnalysis (Issue #57). +- Formatting with `black`. ### Changed - The SMARTS for the following groups have been updated to a more accurate definition @@ -23,61 +25,74 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Cation: include amidine and guanidine, - Metal ligand: exclude amides and some amines. - The Pi stacking interactions have been changed for a more accurate implementation - (PR #97). + (PR #97, PR #98). - The `pdbqt_supplier` will not add explicit hydrogen atoms anymore, to avoid detecting - hydrogen bonds with "random" hydrogens that weren't in the PDBQT file. -- When using the `pdbqt_supplier`, irrelevant warnings and logs have been disabled. -- Updated the minimal RDKit version to `2021.03.1` + hydrogen bonds with "random" hydrogens that weren't in the PDBQT file (PR #99). +- When using the `pdbqt_supplier`, irrelevant warnings and logs have been disabled (PR #99). +- Updated the minimal RDKit version to `2021.03.1` ### Fixed - Dead link in the quickstart notebook for the MDAnalysis quickstart (PR #75, @radifar). -- The `pdbqt_supplier` now correctly preserves hydrogens from the input PDBQT file. +- The `pdbqt_supplier` now correctly preserves hydrogens from the input PDBQT file (PR #99). + ## [1.0.0] - 2022-06-07 + ### Added - Support for multiprocessing, enabled by default (Issue #46). The number of processes can be controlled through `n_jobs` in `fp.run` and `fp.run_from_iterable`. - New interaction: van der Waals contact, based on the sum of vdW radii of two atoms. - Saving/loading the fingerprint object as a pickle with `fp.to_pickle` and `Fingerprint.from_pickle` (Issue #40). + ### Changed - Molecule suppliers can now be indexed, reused and can return their length, instead of being single-use generators. + ### Fixed - ProLIF can now be installed through pip and conda (Issue #6). - If no interaction is detected in the first frame, `to_dataframe` will not complain about a `KeyError` anymore (Issue #44). - When creating a `plf.Fingerprint`, unknown interactions will no longer fail silently. + ## [0.3.4] - 2021-09-28 + ### Added - Added our J. Cheminformatics article to the citation page of the documentation and the `CITATION.cff` file. + ### Changed - Improved the documentation on how to properly restrict interactions to ignore the protein backbone (Issue #22), how to fix the empty dataframe issue when no bond information is present in the PDB file (Issue #15), how to save the LigNetwork diagram (Issue #21), and some clarifications on using `fp.generate` + ### Fixed - Mixing residue type with interaction type in the interactive legend of the LigNetwork would incorrectly display/hide some residues on the canvas (#PR 23) - MOL2 files starting with a comment (`#`) would lead to an error + ## [0.3.3] - 2021-06-11 + ### Changed - Custom interactions must return three values: a boolean for the interaction, and the indices of residue atoms responsible for the interaction + ### Fixed - Custom interactions that only returned a single value instead of three would raise an uninformative error message ## [0.3.2] - 2021-06-11 + ### Added - LigNetwork: an interaction diagram with atomistic details for the ligand and residue-level details for the protein, fully interactive in a browser/notebook, inspired from LigPlot (PR #19) - `fp.generate`: a method to get the IFP between two `prolif.Molecule` objects (PR #19) + ### Changed - Default residue name and number: `UNK` and `0` are now the default values if `None` or `''` is given @@ -87,11 +102,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 recalculating the IFP if one wants to display it with atomic details (PR #19) - Changed the values returned by `fp.bitvector_atoms`: the atom indices have been separated in two lists, one for the ligand and one for the protein (PR #19) + ### Fixed - Residues with a resnumber of `0` are not converted to `None` anymore (Issue #13) - Fingerprint instantiated with an unknown interaction name will now raise a `NameError` + ## [0.3.1] - 2021-02-02 + ### Added - Integration with Zenodo to automatically generate a DOI for new releases - Citation page @@ -99,6 +117,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - PDBQT, MOL2 and SDF molecule suppliers to make it easier for users to use docking results as input (Issue #11) - `Molecule.from_rdkit` classmethod to easily prepare RDKit molecules for ProLIF + ### Changed - The visualisation notebook now displays the protein with py3Dmol. Some examples for creating and displaying a graph from the interaction dataframe have been added @@ -109,25 +128,33 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the `Fingerprint.run_from_iterable` method, which uses the new supplier functions to quickly generate a fingerprint. - Sorted the output of `Fingerprint.list_available` + ### Fixed - `Fingerprint.to_dataframe` is now much faster (Issue #7) - `ResidueId.from_string` method now supports 1-letter and 2-letter codes for RNA/DNA (Issue #8) + ## [0.3.0] - 2020-12-23 + ### Added - Reading input directly from RDKit Mol as well as MDAnalysis AtomGroup objects - Proper documentation and tests - CI through GitHub Actions - Publishing to PyPI triggered by GitHub releases + ### Changed - All the API and the underlying code have been modified - Repository has been moved from GitHub user @cbouy to organisation @chemosim-lab + ### Removed - Custom MOL2 file reader - Command-line interface + ### Fixed - Interactions not detected properly + ## [0.2.1] - 2019-10-02 + Base version for this changelog \ No newline at end of file From 8690833f8d2c82e7a0071978dfbed17c6095a708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:04:49 +0100 Subject: [PATCH 03/11] add vdwcontact to defaults --- CHANGELOG.md | 2 ++ prolif/fingerprint.py | 1 + prolif/interactions.py | 6 +++--- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 543081c..815dfe0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Metal ligand: exclude amides and some amines. - The Pi stacking interactions have been changed for a more accurate implementation (PR #97, PR #98). +- The Van der Waals contact has been added to the default interactions, and the `tolerance` + parameter has been set to 0. - The `pdbqt_supplier` will not add explicit hydrogen atoms anymore, to avoid detecting hydrogen bonds with "random" hydrogens that weren't in the PDBQT file (PR #99). - When using the `pdbqt_supplier`, irrelevant warnings and logs have been disabled (PR #99). diff --git a/prolif/fingerprint.py b/prolif/fingerprint.py index 0e083fa..0548def 100644 --- a/prolif/fingerprint.py +++ b/prolif/fingerprint.py @@ -217,6 +217,7 @@ def __init__( "Cationic", "CationPi", "PiCation", + "VdWContact", ], ): self._set_interactions(interactions) diff --git a/prolif/interactions.py b/prolif/interactions.py index 9eba6fe..e2b6d1d 100644 --- a/prolif/interactions.py +++ b/prolif/interactions.py @@ -675,7 +675,7 @@ class VdWContact(Interaction): ValueError : ``tolerance`` parameter cannot be negative """ - def __init__(self, tolerance=0.5): + def __init__(self, tolerance=0.0): if tolerance >= 0: self.tolerance = tolerance else: @@ -689,10 +689,10 @@ def detect(self, ligand, residue): lig = la.GetSymbol().upper() res = ra.GetSymbol().upper() try: - vdw = self._vdw_cache[(lig, res)] + vdw = self._vdw_cache[frozenset((lig, res))] except KeyError: vdw = vdwradii[lig] + vdwradii[res] + self.tolerance - self._vdw_cache[(lig, res)] = vdw + self._vdw_cache[frozenset((lig, res))] = vdw dist = lxyz.GetAtomPosition(la.GetIdx()).Distance( rxyz.GetAtomPosition(ra.GetIdx()) ) From 10b733679d7ccc4be45bf5b2b1ba1e1844ee265a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:05:15 +0100 Subject: [PATCH 04/11] fix lignetwork colors --- prolif/plotting/network.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/prolif/plotting/network.py b/prolif/plotting/network.py index a98bb4d..f40fca2 100644 --- a/prolif/plotting/network.py +++ b/prolif/plotting/network.py @@ -91,6 +91,7 @@ class LigNetwork: "PiStacking": "#b559e3", "EdgeToFace": "#b559e3", "FaceToFace": "#b559e3", + "VdWContact": "#59e3ad", }, "atoms": { "C": "black", @@ -243,10 +244,16 @@ def __init__( # regroup interactions of the same color temp = defaultdict(list) interactions = set(df.index.get_level_values("interaction").unique()) - for interaction, color in self.COLORS["interactions"].items(): - if interaction in interactions: - temp[color].append(interaction) - self._interaction_types = {i: "/".join(t) for c, t in temp.items() for i in t} + for interaction in interactions: + color = self.COLORS["interactions"].get( + interaction, self._default_interaction_color + ) + temp[color].append(interaction) + self._interaction_types = { + interaction: "/".join(interaction_group) + for interaction_group in temp.values() + for interaction in interaction_group + } @classmethod def from_ifp(cls, ifp, lig, kind="aggregate", frame=0, threshold=0.3, **kwargs): @@ -484,7 +491,9 @@ def _make_interactions(self, mass=2): "from": origin, "to": prot_res, "title": interaction, - "interaction_type": self._interaction_types[interaction], + "interaction_type": self._interaction_types.get( + interaction, interaction + ), "color": self.COLORS["interactions"].get( interaction, self._default_interaction_color ), @@ -578,7 +587,7 @@ def _get_legend(self, height="90px"): buttons = [] map_color_restype = {c: t for t, c in self.COLORS["residues"].items()} map_color_interactions = { - self.COLORS["interactions"][i]: t + self.COLORS["interactions"].get(i, self._default_interaction_color): t for i, t in self._interaction_types.items() } # residues @@ -598,7 +607,7 @@ def _get_legend(self, height="90px"): for edge in self.edges: if edge.get("group", "") == "interaction": color = edge["color"] - available[color] = map_color_interactions.get(color, "Unknown") + available[color] = map_color_interactions[color] available = { k: v for k, v in sorted(available.items(), key=lambda item: item[1]) } From 90ba8970ee273f4ca874efd825647c0687683e38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:05:49 +0100 Subject: [PATCH 05/11] add some color options for the py3Dmol script --- docs/notebooks/visualisation.ipynb | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/notebooks/visualisation.ipynb b/docs/notebooks/visualisation.ipynb index 9f9ac21..c554258 100644 --- a/docs/notebooks/visualisation.ipynb +++ b/docs/notebooks/visualisation.ipynb @@ -43,7 +43,7 @@ "outputs": [], "source": [ "# get lig-prot interactions with atom info\n", - "fp = plf.Fingerprint([\"HBDonor\", \"HBAcceptor\", \"Cationic\", \"PiStacking\"])\n", + "fp = plf.Fingerprint()\n", "fp.run(u.trajectory[0:1], lig, prot)\n", "df = fp.to_dataframe(return_atoms=True)\n", "df.T" @@ -105,9 +105,13 @@ "import py3Dmol\n", "\n", "colors = {\n", - " \"HBAcceptor\": \"blue\",\n", - " \"HBDonor\": \"red\",\n", - " \"Cationic\": \"green\",\n", + " \"Hydrophobic\": \"green\",\n", + " \"HBAcceptor\": \"cyan\",\n", + " \"HBDonor\": \"cyan\",\n", + " \"XBDonor\": \"orange\",\n", + " \"XBAcceptor\": \"orange\",\n", + " \"Cationic\": \"red\",\n", + " \"Anionic\": \"blue\",\n", " \"PiStacking\": \"purple\",\n", "}\n", "\n", @@ -170,7 +174,7 @@ " {\n", " \"start\": dict(x=p1.x, y=p1.y, z=p1.z),\n", " \"end\": dict(x=p2.x, y=p2.y, z=p2.z),\n", - " \"color\": colors[interaction],\n", + " \"color\": colors.get(interaction, \"grey\"),\n", " \"radius\": 0.15,\n", " \"dashed\": True,\n", " \"fromCap\": 1,\n", @@ -262,7 +266,6 @@ "source": [ "import networkx as nx\n", "from pyvis.network import Network\n", - "from tqdm.auto import tqdm\n", "from matplotlib import cm, colors\n", "from IPython.display import IFrame" ] @@ -613,7 +616,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.13" } }, "nbformat": 4, From 2554f507a8228534565b2b6be212a7f821b03dcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:09:21 +0100 Subject: [PATCH 06/11] update citation page to Zenodo --- docs/source/citation.rst | 32 ++------------------------------ 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/docs/source/citation.rst b/docs/source/citation.rst index c887a46..77838cb 100644 --- a/docs/source/citation.rst +++ b/docs/source/citation.rst @@ -8,33 +8,5 @@ If you use ProLIF in your research, please cite the following `paper `_. From 820c93dc39e663da1a51e628c9024943b621a486 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:48:46 +0100 Subject: [PATCH 07/11] fix issue #89 to_dataframe on empty ifp --- CHANGELOG.md | 2 ++ prolif/utils.py | 6 +++++- tests/test_utils.py | 6 ++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 815dfe0..5716596 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Dead link in the quickstart notebook for the MDAnalysis quickstart (PR #75, @radifar). - The `pdbqt_supplier` now correctly preserves hydrogens from the input PDBQT file (PR #99). +- If no interaction was detected, `to_dataframe` would error without giving a helpful message. It + now returns a dataframe with the correct number of frames in the index and no column. ## [1.0.0] - 2022-06-07 diff --git a/prolif/utils.py b/prolif/utils.py index ed64542..4f21fee 100644 --- a/prolif/utils.py +++ b/prolif/utils.py @@ -246,6 +246,7 @@ def to_dataframe( # residue pairs keys = sorted(set([k for d in ifp for k in d.keys() if k != index_col])) # check if each interaction value is a list of atom indices or smthg else + has_atom_indices = False for d in ifp: for key, value in d.items(): if key != index_col: @@ -253,7 +254,7 @@ def to_dataframe( break if return_atoms and not has_atom_indices: raise ValueError( - "The IFP either doesn't contain atom indices or is " "formatted incorrectly" + "The IFP either doesn't contain atom indices or is formatted incorrectly" ) # create empty array for each residue pair interaction that doesn't exist # in a particular frame @@ -279,6 +280,9 @@ def to_dataframe( data[key].append(arr) index = pd.Series(index, name=index_col) # create dataframes + if not data: + warnings.warn("No interaction detected") + return pd.DataFrame([], index=index) values = np.array( [np.hstack([np.ravel(a[i]) for a in data.values()]) for i in range(len(index))] ) diff --git a/tests/test_utils.py b/tests/test_utils.py index f87f9bd..7ffde8a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -243,6 +243,12 @@ def test_to_df_no_interaction_in_first_frame(ifp): to_dataframe(fp, ["A", "B", "C"]) +def test_to_df_empty_ifp(): + ifp = [{"Frame": 0}, {"Frame": 1}] + df = to_dataframe(ifp, ["A"]) + assert df.to_numpy().shape == (2, 0) + + def test_to_bv(): df = to_dataframe(ifp, ["A", "B", "C"]) bvs = to_bitvectors(df) From d759977d785ae9f1546be4f5ae5db77cee7a8661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 22:55:37 +0100 Subject: [PATCH 08/11] switch to conftest --- tests/conftest.py | 46 ++++++++++++++++++++++++++++++++++++++ tests/test_base.py | 14 +----------- tests/test_fingerprint.py | 34 +++++++++++++--------------- tests/test_interactions.py | 5 ++--- tests/test_molecule.py | 12 +++++----- tests/test_residues.py | 6 ++--- tests/test_utils.py | 4 +--- 7 files changed, 75 insertions(+), 46 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5fc1dae --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,46 @@ +import pytest +from MDAnalysis import Universe +from rdkit import Chem + +from prolif.datafiles import TOP, TRAJ +from prolif.molecule import Molecule + + +@pytest.fixture(scope="session") +def u(): + return Universe(TOP, TRAJ) + + +@pytest.fixture(scope="session") +def rdkit_mol(): + return Chem.MolFromPDBFile(TOP, removeHs=False) + + +@pytest.fixture(scope="session") +def ligand_ag(u): + return u.select_atoms("resname LIG") + + +@pytest.fixture(scope="session") +def ligand_rdkit(ligand_ag): + return ligand_ag.convert_to.rdkit() + + +@pytest.fixture(scope="session") +def ligand_mol(ligand_ag): + return Molecule.from_mda(ligand_ag) + + +@pytest.fixture(scope="session") +def protein_ag(u): + return u.select_atoms("protein") + + +@pytest.fixture(scope="session") +def protein_rdkit(protein_ag): + return protein_ag.convert_to.rdkit() + + +@pytest.fixture(scope="session") +def protein_mol(protein_ag): + return Molecule.from_mda(protein_ag) diff --git a/tests/test_base.py b/tests/test_base.py index 88aa73d..07cc7da 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,25 +1,13 @@ import pytest -from MDAnalysis import Universe from numpy.testing import assert_array_almost_equal -from prolif.datafiles import TOP, TRAJ -from prolif.molecule import Molecule from prolif.rdkitmol import BaseRDKitMol from rdkit import Chem from rdkit.Chem.rdMolTransforms import ComputeCentroid -u = Universe(TOP, TRAJ) -rdkit_mol = Chem.MolFromPDBFile(TOP, removeHs=False) -ligand_ag = u.select_atoms("resname LIG") -ligand_rdkit = ligand_ag.convert_to.rdkit() -ligand_mol = Molecule.from_mda(ligand_ag) -protein_ag = u.select_atoms("protein") -protein_rdkit = protein_ag.convert_to.rdkit() -protein_mol = Molecule.from_mda(protein_ag) - class TestBaseRDKitMol: @pytest.fixture(scope="class") - def mol(self): + def mol(self, rdkit_mol): return BaseRDKitMol(rdkit_mol) def test_init(self, mol): diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index 7e00cec..2723220 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -11,8 +11,6 @@ from prolif.residue import ResidueId from rdkit.DataStructs import ExplicitBitVect -from .test_base import ligand_ag, ligand_mol, protein_ag, protein_mol, u - class Dummy(Interaction): def detect(self, res1, res2): @@ -79,12 +77,12 @@ def test_wrapped(self, fp_simple): assert fp_simple.dummy("foo", "bar") == 1 assert fp_simple.dummy.__wrapped__("foo", "bar") == (True, 4, 2) - def test_bitvector(self, fp): + def test_bitvector(self, fp, ligand_mol, protein_mol): bv = fp.bitvector(ligand_mol, protein_mol["ASP129.A"]) assert len(bv) == fp.n_interactions assert bv.sum() > 0 - def test_bitvector_atoms(self, fp): + def test_bitvector_atoms(self, fp, ligand_mol, protein_mol): bv, lig_ix, prot_ix = fp.bitvector_atoms(ligand_mol, protein_mol["ASP129.A"]) assert len(bv) == fp.n_interactions assert len(lig_ix) == fp.n_interactions @@ -93,7 +91,7 @@ def test_bitvector_atoms(self, fp): ids = np.where(bv == 1)[0] assert lig_ix[ids[0]] is not None and prot_ix[ids[0]] is not None - def test_run_residues(self, fp_simple): + def test_run_residues(self, fp_simple, u, ligand_ag, protein_ag): fp_simple.run( u.trajectory[0:1], ligand_ag, protein_ag, residues="all", progress=False ) @@ -123,14 +121,14 @@ def test_run_residues(self, fp_simple): assert (lig_id, res) in fp_simple.ifp[0].keys() u.trajectory[0] - def test_generate(self, fp_simple): + def test_generate(self, fp_simple, ligand_mol, protein_mol): ifp = fp_simple.generate(ligand_mol, protein_mol) key = (ResidueId("LIG", 1, "G"), ResidueId("VAL", 201, "A")) bv = ifp[key] assert isinstance(bv, np.ndarray) assert bv[0] is np.True_ - def test_run(self, fp_simple): + def test_run(self, fp_simple, u, ligand_ag, protein_ag): fp_simple.run( u.trajectory[0:1], ligand_ag, protein_ag, residues=None, progress=False ) @@ -142,13 +140,13 @@ def test_run(self, fp_simple): assert isinstance(data[1], list) assert isinstance(data[2], list) - def test_run_from_iterable(self, fp_simple): + def test_run_from_iterable(self, fp_simple, protein_mol): path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = list(sdf_supplier(path)) fp_simple.run_from_iterable(lig_suppl[:2], protein_mol, progress=False) assert len(fp_simple.ifp) == 2 - def test_to_df(self, fp_simple): + def test_to_df(self, fp_simple, u, ligand_ag, protein_ag): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_dataframe() fp_simple.run( @@ -158,7 +156,7 @@ def test_to_df(self, fp_simple): assert isinstance(df, DataFrame) assert len(df) == 3 - def test_to_df_kwargs(self, fp_simple): + def test_to_df_kwargs(self, fp_simple, u, ligand_ag, protein_ag): fp_simple.run( u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False ) @@ -168,7 +166,7 @@ def test_to_df_kwargs(self, fp_simple): resids = set([key for d in fp_simple.ifp for key in d.keys() if key != "Frame"]) assert df.shape == (3, len(resids)) - def test_to_bv(self, fp_simple): + def test_to_bv(self, fp_simple, u, ligand_ag, protein_ag): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_bitvectors() fp_simple.run( @@ -195,7 +193,7 @@ def test_unknown_interaction(self): Fingerprint(["Cationic", "foo"]) @pytest.fixture - def fp_unpkl(self, fp): + def fp_unpkl(self, fp, protein_mol): path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = list(sdf_supplier(path)) fp.run_from_iterable(lig_suppl[:2], protein_mol, progress=False) @@ -203,7 +201,7 @@ def fp_unpkl(self, fp): return Fingerprint.from_pickle(pkl) @pytest.fixture - def fp_unpkl_file(self, fp): + def fp_unpkl_file(self, fp, protein_mol): path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = list(sdf_supplier(path)) fp.run_from_iterable(lig_suppl[:2], protein_mol, progress=False) @@ -216,7 +214,7 @@ def fp_unpkl_file(self, fp): def fp_pkled(self, request): return request.getfixturevalue(request.param) - def test_pickle(self, fp, fp_pkled): + def test_pickle(self, fp, fp_pkled, protein_mol): path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = list(sdf_supplier(path)) fp.run_from_iterable(lig_suppl[:2], protein_mol, progress=False) @@ -238,7 +236,7 @@ def test_pickle_custom_interaction(self, fp_unpkl): assert hasattr(fp_unpkl, "dummy") assert callable(fp_unpkl.dummy) - def test_run_multiproc_serial_same(self, fp): + def test_run_multiproc_serial_same(self, fp, u, ligand_ag, protein_ag): fp.run(u.trajectory[0:100:10], ligand_ag, protein_ag, n_jobs=1, progress=False) serial = fp.to_dataframe() fp.run( @@ -247,7 +245,7 @@ def test_run_multiproc_serial_same(self, fp): multi = fp.to_dataframe() assert serial.equals(multi) - def test_run_iter_multiproc_serial_same(self, fp): + def test_run_iter_multiproc_serial_same(self, fp, protein_mol): run = fp.run_from_iterable path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = sdf_supplier(path) @@ -257,7 +255,7 @@ def test_run_iter_multiproc_serial_same(self, fp): multi = fp.to_dataframe() assert serial.equals(multi) - def test_converter_kwargs_raises_error(self, fp: Fingerprint): + def test_converter_kwargs_raises_error(self, fp, u, ligand_ag, protein_ag): with pytest.raises( ValueError, match="converter_kwargs must be a list of 2 dicts" ): @@ -271,7 +269,7 @@ def test_converter_kwargs_raises_error(self, fp: Fingerprint): ) @pytest.mark.parametrize("n_jobs", [1, 2]) - def test_converter_kwargs(self, fp: Fingerprint, n_jobs: int): + def test_converter_kwargs(self, fp, n_jobs): u = mda.Universe.from_smiles("O=C=O.O=C=O") lig, prot = u.atoms.fragments fp.run( diff --git a/tests/test_interactions.py b/tests/test_interactions.py index 4bbd85c..2b898e9 100644 --- a/tests/test_interactions.py +++ b/tests/test_interactions.py @@ -10,7 +10,6 @@ from rdkit import Chem, RDLogger from . import mol2factory -from .test_base import ligand_mol, protein_mol # disable rdkit warnings lg = RDLogger.logger() @@ -139,7 +138,7 @@ class _Dummy(Interaction): _Dummy() @pytest.mark.parametrize("index", [0, 1, 3, 42, 78]) - def test_get_mapindex(self, index): + def test_get_mapindex(self, index, ligand_mol): parent_index = get_mapindex(ligand_mol[0], index) assert parent_index == index @@ -280,7 +279,7 @@ def create_rings(xyz, rotation): r2.trajectory.add_transformations(tr, rotx, roty, rotz) return prolif.Molecule.from_mda(benzene), prolif.Molecule.from_mda(r2) - def test_edgetoface_phe331(self): + def test_edgetoface_phe331(self, ligand_mol, protein_mol): fp = Fingerprint() lig, phe331 = ligand_mol[0], protein_mol["PHE331.B"] assert fp.edgetoface(lig, phe331) is True diff --git a/tests/test_molecule.py b/tests/test_molecule.py index f7e5488..c3c4943 100644 --- a/tests/test_molecule.py +++ b/tests/test_molecule.py @@ -7,19 +7,19 @@ from prolif.residue import ResidueId from rdkit import Chem -from .test_base import TestBaseRDKitMol, ligand_rdkit, rdkit_mol, u +from .test_base import TestBaseRDKitMol class TestMolecule(TestBaseRDKitMol): @pytest.fixture(scope="class") - def mol(self): + def mol(self, rdkit_mol): return Molecule(rdkit_mol) def test_mapindex(self, mol): for atom in mol.GetAtoms(): assert atom.GetUnsignedProp("mapindex") == atom.GetIdx() - def test_from_mda(self): + def test_from_mda(self, u, ligand_rdkit): rdkit_mol = Molecule(ligand_rdkit) mda_mol = Molecule.from_mda(u, "resname LIG") assert rdkit_mol[0].resid == mda_mol[0].resid @@ -27,12 +27,12 @@ def test_from_mda(self): rdkit_mol ) - def test_from_mda_empty_ag(self): + def test_from_mda_empty_ag(self, u): ag = u.select_atoms("resname FOO") with pytest.raises(SelectionError, match="AtomGroup is empty"): Molecule.from_mda(ag) - def test_from_rdkit(self): + def test_from_rdkit(self, ligand_rdkit): rdkit_mol = Molecule(ligand_rdkit) newmol = Molecule.from_rdkit(ligand_rdkit) assert rdkit_mol[0].resid == newmol[0].resid @@ -95,7 +95,7 @@ def suppl(self): ) return pdbqt_supplier(pdbqts, template) - def test_pdbqt_hydrogens_stay_in_mol(self): + def test_pdbqt_hydrogens_stay_in_mol(self, ligand_rdkit): template = Chem.RemoveHs(ligand_rdkit) indices = [] rwmol = Chem.RWMol(ligand_rdkit) diff --git a/tests/test_residues.py b/tests/test_residues.py index 235a6e3..16fa34e 100644 --- a/tests/test_residues.py +++ b/tests/test_residues.py @@ -4,7 +4,7 @@ from rdkit import Chem from rdkit.Chem import AllChem -from .test_base import TestBaseRDKitMol, protein_mol +from .test_base import TestBaseRDKitMol class TestResidueId: @@ -201,7 +201,7 @@ def test_getitem_keyerror(self): with pytest.raises(KeyError, match="Expected a ResidueId, int, or str"): rg[1.5] - def test_select(self): + def test_select(self, protein_mol): rg = protein_mol.residues assert rg.select(rg.name == "LYS").n_residues == 16 assert rg.select(rg.number == 300).n_residues == 1 @@ -216,7 +216,7 @@ def test_select(self): # not assert rg.select(~(rg.chain == "B")).n_residues == 212 - def test_select_sameas_getitem(self): + def test_select_sameas_getitem(self, protein_mol): rg = protein_mol.residues sel = rg.select((rg.name == "LYS") & (rg.number == 49))[0] assert sel.resid is rg["LYS49.A"].resid diff --git a/tests/test_utils.py b/tests/test_utils.py index 7ffde8a..04e4c86 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,8 +17,6 @@ ) from rdkit import Chem -from .test_base import ligand_mol, protein_mol - def test_centroid(): xyz = np.array( @@ -50,7 +48,7 @@ def test_angle_limits(angle, mina, maxa, ring, expected): assert angle_between_limits(angle, mina, maxa, ring) is expected -def test_pocket_residues(): +def test_pocket_residues(ligand_mol, protein_mol): resids = get_residues_near_ligand(ligand_mol, protein_mol) residues = [ "TYR38.A", From a75b305cffc4cca238b2612adf72d7ef66fd18c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 23:00:15 +0100 Subject: [PATCH 09/11] sort imports --- prolif/__init__.py | 8 ++++---- prolif/plotting/network.py | 10 ++++++---- setup.py | 6 ++++-- tests/mol2factory.py | 1 + tests/plotting/test_network.py | 8 +++++--- tests/test_base.py | 3 ++- tests/test_fingerprint.py | 3 ++- tests/test_interactions.py | 12 ++++++------ tests/test_molecule.py | 6 ++++-- tests/test_residues.py | 3 ++- tests/test_utils.py | 3 ++- tests/test_wrapper_docs.py | 3 ++- 12 files changed, 40 insertions(+), 26 deletions(-) diff --git a/prolif/__init__.py b/prolif/__init__.py index 6b650cd..e3165b3 100644 --- a/prolif/__init__.py +++ b/prolif/__init__.py @@ -1,9 +1,9 @@ -from .molecule import Molecule, pdbqt_supplier, mol2_supplier, sdf_supplier -from .residue import ResidueId -from .fingerprint import Fingerprint -from .utils import get_residues_near_ligand, to_dataframe, to_bitvectors from . import datafiles from ._version import get_versions +from .fingerprint import Fingerprint +from .molecule import Molecule, mol2_supplier, pdbqt_supplier, sdf_supplier +from .residue import ResidueId +from .utils import get_residues_near_ligand, to_bitvectors, to_dataframe __version__ = get_versions()["version"] del get_versions diff --git a/prolif/plotting/network.py b/prolif/plotting/network.py index f40fca2..3ae002d 100644 --- a/prolif/plotting/network.py +++ b/prolif/plotting/network.py @@ -8,16 +8,18 @@ :members: """ -from copy import deepcopy -from collections import defaultdict -import warnings import json import re +import warnings +from collections import defaultdict +from copy import deepcopy from html import escape -import pandas as pd + import numpy as np +import pandas as pd from rdkit import Chem from rdkit.Chem import rdDepictor + from ..residue import ResidueId from ..utils import requires diff --git a/setup.py b/setup.py index 6abbbfe..672f6fd 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,9 @@ +import os +import re + from setuptools import setup + import versioneer -import re -import os GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS", False) diff --git a/tests/mol2factory.py b/tests/mol2factory.py index 6801305..a2d46b1 100644 --- a/tests/mol2factory.py +++ b/tests/mol2factory.py @@ -1,6 +1,7 @@ import numpy as np from MDAnalysis import Universe from MDAnalysis.topology.guessers import guess_atom_element + from prolif.datafiles import datapath from prolif.molecule import Molecule diff --git a/tests/plotting/test_network.py b/tests/plotting/test_network.py index cd13308..3b5a11e 100644 --- a/tests/plotting/test_network.py +++ b/tests/plotting/test_network.py @@ -1,11 +1,13 @@ -from io import StringIO -from tempfile import NamedTemporaryFile import os from contextlib import contextmanager +from io import StringIO +from tempfile import NamedTemporaryFile + import MDAnalysis as mda +import pytest + import prolif as plf from prolif.plotting.network import LigNetwork -import pytest @contextmanager diff --git a/tests/test_base.py b/tests/test_base.py index 07cc7da..8de5bd1 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,9 +1,10 @@ import pytest from numpy.testing import assert_array_almost_equal -from prolif.rdkitmol import BaseRDKitMol from rdkit import Chem from rdkit.Chem.rdMolTransforms import ComputeCentroid +from prolif.rdkitmol import BaseRDKitMol + class TestBaseRDKitMol: @pytest.fixture(scope="class") diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index 2723220..47b24d7 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -4,12 +4,13 @@ import numpy as np import pytest from pandas import DataFrame +from rdkit.DataStructs import ExplicitBitVect + from prolif.datafiles import datapath from prolif.fingerprint import Fingerprint, _InteractionWrapper from prolif.interactions import _INTERACTIONS, Interaction from prolif.molecule import sdf_supplier from prolif.residue import ResidueId -from rdkit.DataStructs import ExplicitBitVect class Dummy(Interaction): diff --git a/tests/test_interactions.py b/tests/test_interactions.py index 2b898e9..98f5df6 100644 --- a/tests/test_interactions.py +++ b/tests/test_interactions.py @@ -1,13 +1,13 @@ -import prolif -import pytest -import numpy as np import MDAnalysis as mda +import numpy as np +import pytest from MDAnalysis.topology.tables import vdwradii -from MDAnalysis.transformations import translate, rotateby -from rdkit import RDLogger +from MDAnalysis.transformations import rotateby, translate +from rdkit import Chem, RDLogger + +import prolif from prolif.fingerprint import Fingerprint from prolif.interactions import _INTERACTIONS, Interaction, VdWContact, get_mapindex -from rdkit import Chem, RDLogger from . import mol2factory diff --git a/tests/test_molecule.py b/tests/test_molecule.py index c3c4943..00deaf0 100644 --- a/tests/test_molecule.py +++ b/tests/test_molecule.py @@ -1,11 +1,13 @@ -import pytest from copy import deepcopy + +import pytest from MDAnalysis import SelectionError from numpy.testing import assert_array_equal +from rdkit import Chem + from prolif.datafiles import datapath from prolif.molecule import Molecule, mol2_supplier, pdbqt_supplier, sdf_supplier from prolif.residue import ResidueId -from rdkit import Chem from .test_base import TestBaseRDKitMol diff --git a/tests/test_residues.py b/tests/test_residues.py index 16fa34e..2c630d9 100644 --- a/tests/test_residues.py +++ b/tests/test_residues.py @@ -1,9 +1,10 @@ import pytest from numpy.testing import assert_equal -from prolif.residue import Residue, ResidueGroup, ResidueId from rdkit import Chem from rdkit.Chem import AllChem +from prolif.residue import Residue, ResidueGroup, ResidueId + from .test_base import TestBaseRDKitMol diff --git a/tests/test_utils.py b/tests/test_utils.py index 04e4c86..6951d4d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,8 @@ import numpy as np import pytest from numpy.testing import assert_equal +from rdkit import Chem + from prolif.residue import Residue, ResidueGroup, ResidueId from prolif.utils import ( angle_between_limits, @@ -15,7 +17,6 @@ to_bitvectors, to_dataframe, ) -from rdkit import Chem def test_centroid(): diff --git a/tests/test_wrapper_docs.py b/tests/test_wrapper_docs.py index becdde6..b2e46dc 100644 --- a/tests/test_wrapper_docs.py +++ b/tests/test_wrapper_docs.py @@ -1,5 +1,6 @@ -import prolif.interactions import pytest + +import prolif.interactions from prolif.fingerprint import Fingerprint, _Docstring interaction_list = [i for i in Fingerprint.list_available() if "Dummy" not in i] From a36b4707bb0511ac534d7040f16bd0cdcee264ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 23:02:19 +0100 Subject: [PATCH 10/11] remove docs wrapper for fp. --- prolif/fingerprint.py | 29 ----------------------- tests/test_wrapper_docs.py | 47 -------------------------------------- 2 files changed, 76 deletions(-) delete mode 100644 tests/test_wrapper_docs.py diff --git a/prolif/fingerprint.py b/prolif/fingerprint.py index 0548def..47b76ab 100644 --- a/prolif/fingerprint.py +++ b/prolif/fingerprint.py @@ -48,25 +48,6 @@ from .utils import get_residues_near_ligand, to_bitvectors, to_dataframe -class _Docstring: - """Descriptor that replaces the documentation shown when calling - ``fp.hydrophobic?`` and other interaction methods""" - - def __init__(self): - self._docs = {} - - def __set__(self, instance, func): - # add function's docstring to memory - cls = func.__self__.__class__ - self._docs[cls.__name__] = cls.__doc__ - - def __get__(self, instance, owner): - if instance is None: - return self - # fetch docstring of last accessed Fingerprint method - return self._docs[type(instance)._current_func] - - class _InteractionWrapper: """Modifies the return signature of an interaction ``detect`` method by forcing it to return only the first element when multiple values are @@ -107,9 +88,6 @@ class _InteractionWrapper: """ - __doc__ = _Docstring() - _current_func = "" - def __init__(self, func): self.__wrapped__ = func # add docstring to descriptor @@ -242,13 +220,6 @@ def _set_interactions(self, interactions): if name in interactions: self.interactions[name] = func - def __getattribute__(self, name): - # trick to get the correct docstring when calling `fp.hydrophobic?` - attr = super().__getattribute__(name) - if isinstance(attr, _InteractionWrapper): - type(attr)._current_func = attr.__wrapped__.__self__.__class__.__name__ - return attr - def __repr__(self): # pragma: no cover name = ".".join([self.__class__.__module__, self.__class__.__name__]) params = f"{self.n_interactions} interactions: {list(self.interactions.keys())}" diff --git a/tests/test_wrapper_docs.py b/tests/test_wrapper_docs.py deleted file mode 100644 index b2e46dc..0000000 --- a/tests/test_wrapper_docs.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest - -import prolif.interactions -from prolif.fingerprint import Fingerprint, _Docstring - -interaction_list = [i for i in Fingerprint.list_available() if "Dummy" not in i] - - -class Wrapper: - __doc__ = _Docstring() - _current_func = "" - - -class Dummy: - """Dummy class docs""" - - def do_something(self): - """Method docstring""" - return 1 - - -@pytest.fixture -def wrap(): - method = Dummy().do_something - wrap = Wrapper() - wrap.__doc__ = method - return wrap - - -def test_getter(wrap): - assert Wrapper._current_func == "" - Wrapper._current_func = "Dummy" # simulate __getattribute__ - wrap.__doc__ - assert Wrapper.__doc__._docs["Dummy"] == "Dummy class docs" - - -@pytest.fixture(scope="module") -def fp(): - return Fingerprint() - - -@pytest.mark.parametrize("int_name", interaction_list) -def test_fp_docs(fp, int_name): - meth = getattr(fp, int_name.lower()) - assert type(meth)._current_func == int_name - cls = getattr(prolif.interactions, int_name) - assert type(meth).__doc__._docs[int_name] == cls.__doc__ From 65acfa08f972e06442327856bdc78e807afc50e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Thu, 17 Nov 2022 23:27:52 +0100 Subject: [PATCH 11/11] fix rdkit version check on install --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 672f6fd..4c25136 100644 --- a/setup.py +++ b/setup.py @@ -6,15 +6,16 @@ import versioneer GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS", False) +READTHEDOCS = os.environ.get("READTHEDOCS", False) # manually check RDKit version try: from rdkit import __version__ as rdkit_version except ImportError: - if not GITHUB_ACTIONS: + if not (GITHUB_ACTIONS or READTHEDOCS): raise ImportError("ProLIF requires RDKit but it is not installed") else: - if re.match(r"^20[0-1][0-9]\.", rdkit_version): - raise ValueError("ProLIF requires a version of RDKit >= 2020") + if re.match(r"^(20[0-1][0-9])|(2020)", rdkit_version): + raise ValueError("ProLIF requires a version of RDKit >= 2021") setup(version=versioneer.get_version())