feat: use .walk file in complexity command #10

Draft
wants to merge 11 commits into main
3 changes: 3 additions & 0 deletions .devcontainer.json
@@ -27,6 +27,9 @@
     ],
     "settings": {
         "python.analysis.typeCheckingMode": "strict",
+        // "python.analysis.exclude": [
+        //     "**/.vscode-remote/**"
+        // ],
         "python.condaPath": "/opt/conda/condabin/conda",
         "python.terminal.activateEnvironment": true,
         "python.terminal.activateEnvInCurrentTerminal": true,
4 changes: 2 additions & 2 deletions panct/complexity.py
@@ -92,9 +92,9 @@ def main(
     if file_type == "gfa":
         if region_str is not None:
             log.warning("Regions are ignored when processing GFA")
-        exclude = []
+        exclude = set()
         if reference != "":
-            exclude = [reference]
+            exclude.add(reference)
         node_table = gutils.NodeTable(graph_file, exclude)
         metric_results = []
         for m in metrics_list:
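The switch from a list to a set matters because `exclude` is consulted for every sample encountered downstream: a set answers membership queries in O(1) on average, where a list scans every element. A minimal sketch of the pattern, with made-up sample names:

```python
# Sketch only; the sample and reference names here are hypothetical.
exclude = set()
reference = "GRCh38"
if reference != "":
    exclude.add(reference)

for sample in ["HG002", "GRCh38", "HG005"]:
    if sample in exclude:  # O(1) average lookup in a set
        continue
    print(sample)  # prints HG002, then HG005
```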
47 changes: 35 additions & 12 deletions panct/data/walks.py
@@ -19,13 +19,13 @@ class Walks(Data):
 
     Attributes
     ----------
-    data : dict[str, Counter[tuple[str, int]]]
+    data : dict[int, Counter[tuple[str, int]]]
         A mapping of node IDs to Counters of (sample label, haplotype ID) tuples
     log: Logger
         A logging instance for recording debug statements.
     """
 
-    def __init__(self, data: dict[str, Counter[tuple[str, int]]], log: Logger = None):
+    def __init__(self, data: dict[int, Counter[tuple[str, int]]], log: Logger = None):
         super().__init__(log=log)
         self.data = data
 
@@ -34,7 +34,12 @@ def __len__(self):

     @classmethod
     def read(
-        cls: Type[Walks], fname: Path | str, region: str = None, log: Logger = None
+        cls: Type[Walks],
+        fname: Path | str,
+        region: str = None,
+        nodes: set[int] = None,
+        exclude_samples: set[str] = set(),
+        log: Logger = None,
     ) -> Walks:
         """
         Extract walks from a .walk file
@@ -46,6 +51,10 @@ def read(
         region: str, optional
             A region string denoting the start and end node IDs in the form
             of f'{start}-{end}'
+        nodes: set[int], optional
+            A subset of nodes to load. Defaults to all nodes.
+        exclude_samples: set[str], optional
+            If specified, these samples will not be loaded
         log: Logger, optional
             A Logger object to use for debugging statements
@@ -54,7 +63,7 @@
         Walks
             A Walks object holding the walks for each requested node
         """
-        nodes = {}
+        final_nodes = {}
         parse_samp = lambda samp: (samp[0], int(samp[1]))
         # Try to read the file with tabix
         if Path(fname).suffix == ".gz" and region is not None:
@@ -66,11 +75,21 @@
             for line in f.fetch(region=region_str):
                 samples = line.strip().split("\t")
                 node = int(samples.pop(0))
-                nodes[node] = Counter(
-                    parse_samp(samp.rsplit(":", 1)) for samp in samples
-                )
+                if nodes is not None and node not in nodes:
+                    continue
+                final_nodes[node] = Counter(
+                    s
+                    for samp in samples
+                    if (s := parse_samp(samp.rsplit(":", 1)))[0]
+                    not in exclude_samples
+                )
-            return cls(nodes, log)
+            if (
+                log is not None
+                and nodes is not None
+                and len(final_nodes) < len(nodes)
+            ):
+                log.warning("Couldn't load all requested nodes")
+            return cls(final_nodes, log)
         except ValueError:
             pass
         # If we couldn't parse with tabix, then fall back to slow loading
@@ -88,9 +107,13 @@
             for line in f:
                 samples = str(line.strip())
                 node = int(samples.split("\t", maxsplit=1)[0])
-                if node < start or node > end:
+                if (node < start or node > end) or (
+                    nodes is not None and node not in nodes
+                ):
                     continue
-                nodes[node] = Counter(
-                    parse_samp(samp.rsplit(":", 1)) for samp in samples.split("\t")[1:]
-                )
-        return cls(nodes, log)
+                final_nodes[node] = Counter(
+                    s
+                    for samp in samples.split("\t")[1:]
+                    if (s := parse_samp(samp.rsplit(":", 1)))[0] not in exclude_samples
+                )
+        return cls(final_nodes, log)
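The new filtering comprehension leans on the walrus operator so each `sample:haplotype` token is parsed exactly once, then both filtered and counted in the same pass. A self-contained sketch of the same pattern on a made-up `.walk` line:

```python
from collections import Counter

# Hypothetical .walk line: a node ID followed by sample:haplotype tokens.
line = "1\tHG002:1\tHG002:2\tGRCh38:0\tHG002:1"
fields = line.strip().split("\t")
node = int(fields.pop(0))

exclude_samples = {"GRCh38"}
parse_samp = lambda samp: (samp[0], int(samp[1]))

# (s := ...) binds the parsed (sample, haplotype) tuple so the condition
# can filter on the sample name while Counter tallies the tuple itself.
counts = Counter(
    s
    for samp in fields
    if (s := parse_samp(samp.rsplit(":", 1)))[0] not in exclude_samples
)
print(node, counts)  # 1 Counter({('HG002', 1): 2, ('HG002', 2): 1})
```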
11 changes: 10 additions & 1 deletion panct/gbz_utils.py
@@ -140,4 +140,13 @@ def load_node_table_from_gbz(
     gfa_file = extract_region_from_gbz(gbz_file, region, reference)
     if gfa_file is None:
         return gutils.NodeTable()
-    return gutils.NodeTable(gfa_file=gfa_file, exclude_samples=[reference])
+    walk_file = gbz_file.with_suffix(".walk")
+    if not walk_file.exists():
+        walk_file = walk_file.with_suffix(".walk.gz")
+    if not walk_file.exists():
+        walk_file = None
+    return gutils.NodeTable(
+        gfa_file=gfa_file,
+        exclude_samples=set((reference,)),
+        walk_file=walk_file,
+    )
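The lookup above prefers an uncompressed `.walk` next to the `.gbz` and falls back to `.walk.gz`, passing `None` when neither exists. A small sketch of that probing logic as a standalone helper (the function name is hypothetical, not from the PR):

```python
from pathlib import Path

def find_walk_file(gbz_file: Path) -> Path | None:
    """Probe for a companion walk file next to a .gbz file."""
    walk_file = gbz_file.with_suffix(".walk")  # basic.gbz -> basic.walk
    if not walk_file.exists():
        # with_suffix swaps the final suffix: basic.walk -> basic.walk.gz
        walk_file = walk_file.with_suffix(".walk.gz")
    if not walk_file.exists():
        return None
    return walk_file
```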
103 changes: 76 additions & 27 deletions panct/graph_utils.py
@@ -3,9 +3,12 @@
"""

from pathlib import Path
from collections import Counter

import numpy as np

from .data import Walks


class Node:
"""
@@ -79,12 +82,17 @@ class NodeTable:
         Get list of nodes from the walk
     """
 
-    def __init__(self, gfa_file: Path = None, exclude_samples: list[str] = []):
+    def __init__(
+        self,
+        gfa_file: Path = None,
+        exclude_samples: set[str] = set(),
+        walk_file: Path = None,
+    ):
         self.nodes = {}  # node ID -> Node
         self.numwalks = 0
         self.walk_lengths = []
         if gfa_file is not None:
-            self.load_from_gfa(gfa_file, exclude_samples)
+            self.load_from_gfa(gfa_file, exclude_samples, walk_file)
 
     def add_node(self, node: Node):
         """
@@ -192,7 +200,21 @@ def get_nodes_from_walk(self, walk_string: str) -> list[str]:
         ws = walk_string.replace(">", ":").replace("<", ":").strip(":")
         return ws.split(":")
 
-    def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []):
+    def load_from_gfa(
+        self, gfa_file: Path, exclude_samples: set[str] = set(), walk_file: Path = None
+    ):
+        """
+        Load a NodeTable from a GFA file
+
+        Parameters
+        ----------
+        gfa_file : Path
+            Path to the GFA file
+        exclude_samples : set[str], optional
+            Set of samples to exclude
+        walk_file : Path, optional
+            Path to the .walk file
+        """
         # First parse all the nodes
         with open(gfa_file, "r") as f:
             for line in f:
@@ -213,28 +235,55 @@ def load_from_gfa(self, gfa_file: Path, exclude_samples: list[str] = []):
                 self.add_node(Node(nodeid, length=nodelen))
 
         # try to find the .walk file
-        walk_file = Path("")
-        if gfa_file.suffix == ".gz":
-            walk_file = gfa_file.with_suffix("").with_suffix(".walk")
-        else:
-            walk_file = gfa_file.with_suffix("")
-        if not walk_file.exists():
-            walk_file = walk_file.with_suffix(".walk.gz")
-
-        # TODO: get nodes from .walk file and add with self.add_walk()
-        # if walk_file.exists():
-        # else:
-
-        # Second pass to get the walks
-        with open(gfa_file, "r") as f:
-            for line in f:
-                linetype = line.split()[0]
-                if linetype != "W":
-                    continue
-                sampid = line.split()[1]
-                if sampid in exclude_samples:
-                    continue
-                hapid = line.split()[2]
-                walk = line.split()[6]
-                nodes = self.get_nodes_from_walk(walk)
-                self.add_walk(f"{sampid}:{hapid}", nodes)
+        if walk_file is None:
+            if gfa_file.suffix == ".gz":
+                walk_file = gfa_file.with_suffix("").with_suffix(".walk")
+            else:
+                walk_file = gfa_file.with_suffix(".walk")
+            if not walk_file.exists():
+                walk_file = walk_file.with_suffix(".walk.gz")
+
+        if walk_file.exists():
+            node_set = set(int(n) for n in self.nodes.keys())
+            # find smallest and largest node for processing walks
+            smallest_node = min(node_set, default="")
+            largest_node = max(node_set, default="")
+            # Get nodes from .walk file and add with self.add_walk()
+            walks = Walks.read(
+                walk_file,
+                region=f"{smallest_node}-{largest_node}",
+                nodes=node_set,
+                exclude_samples=exclude_samples,
+                log=None,  # TODO: pass Logger
+            )
+            # check that all of the nodes were loaded properly
+            # TODO: remove this check? or implement a fail-safe
+            assert len(walks.data) == len(node_set)
+            assert len(node_set) == len(self.nodes)
+            walk_lengths = Counter()
+            all_samples = set()
+            for node, node_val in self.nodes.items():
+                node_int = int(node)
+                samples = set(f"{hap[0]}:{hap[1]}" for hap in walks.data[node_int])
+                all_samples.update(samples)
+                node_val.samples.update(samples)
+                for sampid, hapid in walks.data[node_int]:
+                    # how many times did this haplotype pass through this node?
+                    num_times = walks.data[node_int][(sampid, hapid)]
+                    walk_lengths[f"{sampid}:{hapid}"] += node_val.length * num_times
+            self.numwalks += len(all_samples)
+            self.walk_lengths.extend(walk_lengths.values())
+        else:
+            # Second pass over gfa file to get the walks
+            with open(gfa_file, "r") as f:
+                for line in f:
+                    linetype = line.split()[0]
+                    if linetype != "W":
+                        continue
+                    sampid = line.split()[1]
+                    if sampid in exclude_samples:
+                        continue
+                    hapid = line.split()[2]
+                    walk = line.split()[6]
+                    nodes = self.get_nodes_from_walk(walk)
+                    self.add_walk(f"{sampid}:{hapid}", nodes)
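The fast path above never re-reads the GFA W-lines: a haplotype's walk length is recovered as the sum over nodes of node length times the number of times that haplotype visits the node, which is exactly what the per-node Counters record. A sketch of that accumulation with made-up node lengths and visit counts:

```python
from collections import Counter

# Hypothetical node lengths (from GFA S-lines) and per-node visit counts
# per (sample, haplotype) tuple, shaped like what Walks.read returns.
node_lengths = {1: 10, 2: 5, 3: 20}
visits = {
    1: Counter({("HG002", 1): 1, ("HG002", 2): 1}),
    2: Counter({("HG002", 1): 2}),  # a loop: haplotype 1 passes twice
    3: Counter({("HG002", 2): 1}),
}

# walk length = sum over nodes of node length * number of visits
walk_lengths = Counter()
for node, counts in visits.items():
    for (sampid, hapid), num_times in counts.items():
        walk_lengths[f"{sampid}:{hapid}"] += node_lengths[node] * num_times

print(walk_lengths)  # Counter({'HG002:2': 30, 'HG002:1': 20})
```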
43 changes: 43 additions & 0 deletions tests/test_complexity.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 from pathlib import Path
 from logging import getLogger

@@ -62,6 +63,25 @@ def test_basic_stdout(capfd):
     assert result.exit_code == 0
 
 
+def test_basic_stdout_wo_walk_file(capfd):
+    """
+    panct complexity basic.gfa
+    """
+    in_file = DATADIR / "basic.gfa"
+    tmp_file = Path("basic.gfa")
+    shutil.copyfile(in_file, tmp_file)
+    expected = expected_basic_output
+
+    cmd = f"complexity {tmp_file}"
+    result = runner.invoke(app, cmd.split(" "), catch_exceptions=False)
+    captured = capfd.readouterr()
+    # check that the output is the same as what we expect
+    assert captured.out == expected
+    assert result.exit_code == 0
+
+    tmp_file.unlink()
+
+
 def test_basic_stdout_region(capfd):
     """
     panct complexity --region chrTest:0-1 tests/data/basic.gbz
@@ -80,6 +100,29 @@ def test_basic_stdout_region(capfd):
     assert result.exit_code == 0
 
 
+def test_basic_stdout_region_wo_walk_file(capfd):
+    """
+    panct complexity --region chrTest:0-1 basic.gbz
+    """
+    region = ("chrTest", 0, 1)
+    region_str = f"{region[0]}:{region[1]}-{region[2]}"
+    in_file = DATADIR / "basic.gbz"
+    tmp_file = Path("basic.gbz")
+    shutil.copyfile(in_file, tmp_file)
+    expected = expected_basic_output
+    expected = prefix_expected_with_region(expected, (region,))
+
+    cmd = f"complexity --region {region_str} {tmp_file}"
+    result = runner.invoke(app, cmd.split(" "), catch_exceptions=False)
+    captured = capfd.readouterr()
+    # check that the output is the same as what we expect
+    assert captured.out == expected
+    assert result.exit_code == 0
+
+    tmp_file.unlink()
+    tmp_file.with_suffix(".gbz.db").unlink()
+
+
 def test_basic_regions_bed(capfd):
     """
     panct complexity --out basic.tsv --region tests/data/basic.bed tests/data/basic.gbz
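Both `_wo_walk_file` tests copy only the graph file into the working directory so the loader cannot find a companion `.walk` file and must take the fallback path. The same isolation could be had with pytest's `tmp_path` fixture; a sketch (the test name is hypothetical):

```python
import shutil
from pathlib import Path

def test_isolated_gbz(tmp_path: Path):
    # Copy only the .gbz into a fresh directory: no basic.walk(.gz) nearby,
    # so NodeTable must fall back to parsing the GFA walks directly.
    src = Path("tests/data/basic.gbz")
    dst = tmp_path / src.name
    shutil.copyfile(src, dst)
    assert not dst.with_suffix(".walk").exists()
```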
9 changes: 9 additions & 0 deletions tests/test_data.py
@@ -85,10 +85,19 @@ def test_parse_walks_file(self):
         nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2")
         assert nodes.data == expected.data
 
+        nodes = Walks.read(DATADIR / "basic.walk.gz", region="-2", nodes={1, 2})
+        assert nodes.data == expected.data
+
         del expected.data[2]
 
         nodes = Walks.read(DATADIR / "basic.walk", region="1-1")
         assert nodes.data == expected.data
 
         nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-1")
         assert nodes.data == expected.data
+
+        nodes = Walks.read(DATADIR / "basic.walk.gz", region="1-2", nodes=set((1,)))
+        assert nodes.data == expected.data
+
+        nodes = Walks.read(DATADIR / "basic.walk.gz", nodes=set((1,)))
+        assert nodes.data == expected.data
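These region strings follow the `f'{start}-{end}'` convention documented in `Walks.read`, and the `region="-2"` case suggests an empty start leaves the lower bound open. A tiny sketch of one way such strings could be parsed (hypothetical helper, not from the PR):

```python
def parse_node_region(region: str) -> tuple[int, int]:
    """Parse a '{start}-{end}' node-ID region; an empty bound is open."""
    start_s, _, end_s = region.partition("-")
    start = int(start_s) if start_s else 0
    end = int(end_s) if end_s else 2**63 - 1  # effectively unbounded
    return start, end

assert parse_node_region("-2") == (0, 2)
assert parse_node_region("1-1") == (1, 1)
```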
6 changes: 4 additions & 2 deletions tests/test_graph_utils.py
@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 
 import pytest
@@ -55,7 +54,10 @@ def test_node_table():
     nt = NodeTable(gfa_file=DATADIR / "basic_noseq.gfa")
     assert nt.get_mean_walk_length() == 95 / 4
     assert nt.numwalks == 4
-    nt = NodeTable(gfa_file=DATADIR / "basic_noseq.gfa", exclude_samples=["GRCh38"])
+    nt = NodeTable(
+        gfa_file=DATADIR / "basic_noseq.gfa",
+        exclude_samples=set(("GRCh38",)),
+    )
     assert nt.get_mean_walk_length() == 70 / 3
     assert nt.numwalks == 3
