
Commit

Added read mapping benchmark
gsc74 committed Jan 21, 2024
1 parent 4cb0261 commit a6792e7
Showing 12 changed files with 284 additions and 12 deletions.
3 changes: 3 additions & 0 deletions data/v1.3/Gen_Graph.py
@@ -388,4 +388,7 @@ def process_fasta(fasta):
print("Walks for " + graph + " generated successfully")
print("--- Time : %s seconds ---" % (time.time() - start_time))

# compress all gfa files
os.system("cd Graphs && gzip *.gfa")

print("--- Total Graph Generation Time : %s seconds ---" % (time.time() - start_time_))
82 changes: 82 additions & 0 deletions data/v1.3/Map_Reads.py
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

import subprocess
import os
import multiprocessing
import re
import time
import sys
import getopt


par_threads = 6

## Pass arguments
argv = sys.argv[1:]

if len(argv) == 0:
    print("help: Map_Reads.py -h")
    sys.exit(2)

try:
    opts, args = getopt.getopt(argv, 'ht:')
except getopt.GetoptError:
    print("usage: Map_Reads.py -t <threads>")
    sys.exit(2)

for opt, arg in opts:
    if opt == '-h':
        print("usage: Map_Reads.py -t <threads>")
        sys.exit()
    elif opt == '-t':
        par_threads = int(arg)

print("Mapping threads : " + str(par_threads))

total_threads = multiprocessing.cpu_count()
map_threads = total_threads/par_threads
map_threads = int(map_threads)

Reads = ['PacBio', 'ONT']
Graph = 'Graphs/MHC-CHM13.0.gfa.gz'
R = ['0', '1000', '10000', '100000', '1000000', '2000000000']

Metadata = []
for read in Reads:
    Read = 'Reads/MHC_CHM13_' + read + '_filt.fq.gz'
    for r in R:
        Metadata.append([read, r, Read])

# recreate the Mapped_Reads directory
if os.path.exists('Mapped_Reads'):
    os.system('rm -rf Mapped_Reads')
    os.system('mkdir Mapped_Reads')
if not os.path.exists('Mapped_Reads'):
    os.system('mkdir Mapped_Reads')

def Map_Reads(Metadata):
    read, r, Read = Metadata
    Output = 'Mapped_Reads/' + read + '_' + r + '.gaf'
    cmd = 'minichain -t' + str(map_threads) + ' -cx lr -b1 -R' + r + ' ' + Graph + ' ' + Read + ' > ' + Output
    out = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT)
    out = out.decode("utf-8")

    print(out)
    # parse the "R: ..., NR_R: ..., NR_NR: ..." values from the 6th line of minichain's stderr
    out = out.split('\n')
    val = re.findall(r'R: (\d+\.\d+)', out[5])
    R, NR_R, NR_NR = val[0], val[1], val[2]

    # write the fractions to a per-run file
    with open('Mapped_Reads/' + read + '_' + r + '.txt', 'w') as f:
        f.write('Read\tr\tR\tNR_R\tNR_NR\n')
        f.write(read + '\t' + r + '\t' + R + '\t' + NR_R + '\t' + NR_NR + '\n')

time_start = time.time()
# run in parallel
pool = multiprocessing.Pool(processes=par_threads)
pool.map(Map_Reads, Metadata)
pool.close()
pool.join()
time_end = time.time()
print('Mapping time: ' + str(time_end - time_start))
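
Note: a minimal, self-contained sketch (illustrative, not part of the commit) of the parsing step above, assuming the benchmark line on minichain's stderr keeps the "R: ..., NR_R: ..., NR_NR: ..." format printed by main.c; the example line and values are made up.

import re

line = "[M::main] R: 0.950000, NR_R: 0.030000, NR_NR: 0.020000"  # hypothetical stderr line
# 'R: ' also matches inside 'NR_R: ' and 'NR_NR: ', so findall returns all three numbers
vals = re.findall(r'R: (\d+\.\d+)', line)
R, NR_R, NR_NR = vals
print(R, NR_R, NR_NR)  # -> 0.950000 0.030000 0.020000
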
93 changes: 93 additions & 0 deletions data/v1.3/Plot_Map.py
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from rich import print as rprint
from rich.console import Console
from rich.table import Table
from pylatexenc.latex2text import LatexNodes2Text


Reads = ['PacBio', 'ONT']
R = ['0', '1000', '10000', '100000', '1000000', '2000000000']
R_labels = [r'$0$', r'$10^3$', r'$10^4$', r'$10^5$', r'$10^6$', r'$\infty$']
R_labels_ = [LatexNodes2Text().latex_to_text(r) for r in R_labels]

R_ = dict()
NR_R = dict()
NR_NR = dict()

# Read data from file
for read in Reads:
    R_[read] = []
    NR_R[read] = []
    NR_NR[read] = []
    for r in R:
        with open('Mapped_Reads/' + read + '_' + r + '.txt', 'r') as f:
            for line in f:
                line = line.strip()
                line = line.split('\t')
                if line[0] == 'Read':
                    continue
                R_[read].append(float(line[2]))
                NR_R[read].append(float(line[3]))
                NR_NR[read].append(float(line[4]))

# Plot a single bar for R, R+NR_R, R+NR_R+NR_NR for each read
print('Plotting')
for read in Reads:
    x = np.arange(len(R))
    fig, ax = plt.subplots()
    # Increase the figure size
    fig.set_size_inches(6, 4)

    # print the per-read table with rich
    table = Table(title="Reads: " + read)
    table.add_column("Recombination Penalty", justify="right", style="cyan", no_wrap=True)
    table.add_column("Complete Support", justify="right", style="blue", no_wrap=True)
    table.add_column("Partial Support", justify="right", style="magenta", no_wrap=True)
    table.add_column("No Support", justify="right", style="green", no_wrap=True)
    for i in range(len(R)):
        table.add_row(R_labels_[i], str(R_[read][i]), str(NR_R[read][i]), str(NR_NR[read][i]))
    console = Console(record=True)
    console.print(table, justify="center")

    # save the table to an SVG file
    console.save_svg("Mapped_Reads/" + read + "_table.svg", title="Effect of Haplotype-aware Chaining")

    # cumulative heights: total, partial+complete, complete (drawn largest first)
    val_1 = []
    val_2 = []
    val_3 = []
    for i in range(len(R)):
        val_1.append(NR_NR[read][i] + NR_R[read][i] + R_[read][i])
        val_2.append(R_[read][i] + NR_R[read][i])
        val_3.append(R_[read][i])
    ax.bar(x, val_1, label='No Support', zorder=3)
    ax.bar(x, val_2, label='Partial Support', zorder=3)
    ax.bar(x, val_3, label='Complete Support', zorder=3)

    ax.tick_params(axis='both', which='major', labelsize=11)
    ax.tick_params(axis='both', which='minor', labelsize=11)
    ax.set_ylabel('Chains supported by reads', fontsize=12)
    ax.set_xticks(x)
    ax.set_xticklabels(R_labels, rotation=45)
    ax.set_xlabel('Recombination penalty', fontsize=12)
    plt.grid(True)

    # put the legend on top of the plot, with the entries in reverse order
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(0.5, 1.2), loc='upper center', borderaxespad=0., ncol=3)

    fig.tight_layout()
    plt.savefig('Mapped_Reads/' + read + '.pdf')
    plt.close()
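
Note: the plot above emulates a stacked bar by drawing three overlapping bars with cumulative heights, tallest first, so each shorter bar covers the lower portion of the previous one. A minimal standalone sketch of the same trick, with made-up values:

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(3)
complete = np.array([0.70, 0.80, 0.90])   # made-up fractions
partial  = np.array([0.20, 0.15, 0.08])
none     = np.array([0.10, 0.05, 0.02])

fig, ax = plt.subplots()
ax.bar(x, complete + partial + none, label='No Support')       # full height
ax.bar(x, complete + partial,        label='Partial Support')  # drawn on top
ax.bar(x, complete,                  label='Complete Support') # drawn last
ax.legend()
plt.savefig('stacked_demo.pdf')
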
Binary file added data/v1.3/Reads/MHC_CHM13_ONT_filt.fq.gz
Binary file not shown.
Binary file added data/v1.3/Reads/MHC_CHM13_PacBio_filt.fq.gz
Binary file not shown.
8 changes: 6 additions & 2 deletions data/v1.3/Reproduce.py
@@ -82,15 +82,19 @@

# create a conda environment named MC and install the python packages numpy, scipy, matplotlib, networkx, biopython, getopt, seaborn and pandas
# check whether a conda environment named MC already exists
os.system("source ~/.bashrc && conda create --force -n MC -y && conda activate MC && conda install -c conda-forge -y numpy scipy matplotlib networkx biopython seaborn pandas")

os.system("source ~/.bashrc && conda create --force -n MC -y && conda activate MC && conda install -c conda-forge -y numpy scipy matplotlib networkx biopython seaborn pandas rich pylatexenc")

map_threads = 6
# Generate the graph
os.system("python3 Gen_Graph.py -t " + str(threads))
# Simulate queries
os.system("source ~/.bashrc && conda activate MC && python3 Simulate_query.py -t " + str(threads))
# Map the queries
os.system("source ~/.bashrc && conda activate MC && python3 Map_Graph.py -t " + str(threads))
# Map the reads
os.system("source ~/.bashrc && conda activate MC && python3 Map_Reads.py -t " + str(map_threads))
# Plot the results
os.system("source ~/.bashrc && conda activate MC && python3 Plot.py")
# Plot the results for mapping
os.system("source ~/.bashrc && conda activate MC && python3 Plot_Map.py")

65 changes: 65 additions & 0 deletions graphUtils.cpp
@@ -886,6 +886,10 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str

if (is_hprc_gfa)
{
if (benchmark) tau_1 = 0.99999f;
// chain vtxs
std::vector<std::vector<std::vector<int32_t>>> chains(num_cid);

// Recombinations count
std::vector<int> min_loc(num_cid, std::numeric_limits<int>::max()), max_loc(num_cid, std::numeric_limits<int>::min());
count++; // Read count
@@ -894,6 +898,8 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
std::vector<float> acc_cid;
acc_cid.resize(num_cid);


bool is_haplotype = false;
for (int cid = 0; cid < num_cid; cid++)
{
int N = M[cid].size(); // #Anchors
@@ -1200,6 +1206,8 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
}
}

if (benchmark) chains[cid].push_back(temp_chain); // record this chain for the benchmark

// Store the anchors if chain is disjoint and clear the keys
if (flag == true && temp_chain.size() > 0 ) // minimap2 min_cnt
{
@@ -1236,6 +1244,7 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
// exit(0);
if (qname_.size() == 3)
{
is_haplotype = true;
float precision = 0.0f;
float recall = 0.0f;
float f1_score = 0.0f;
@@ -1583,6 +1592,62 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
max = std::max(max, min_loc[max_f1_cid_]);
min = std::min(min, min_loc[max_f1_cid_]);
max_sum += max;

if (benchmark && !is_haplotype)
{
accuracy = -1.0f;
max_f1_cid_ = 0;
int64_t max_score = std::numeric_limits<int64_t>::min();
for (int cid = 0; cid < num_cid; cid++)
{
if (max_score < best_chains[cid].second)
{
max_score = best_chains[cid].second;
max_f1_cid_ = cid;
}
}
// find count correct over all cids
std::set<int32_t> walk_vtx;
for (auto walk:fwd_walk_map[walk_map[0]])
{
walk_vtx.insert(walk);
}

for (auto walk:rev_walk_map[walk_map[0]])
{
walk_vtx.insert(walk);
}

int32_t loc_correct = 0;
int32_t loc_not_correct = 0;
float frac_correct_loc = 0.0f;
std::set<int32_t> chain_vtx;
for (auto chain_num:chains[max_f1_cid_]) // If any chain is part of CHM13#0 then do count_correct++
{
for (auto idx:chain_num)
{
chain_vtx.insert(idx_Anchor[max_f1_cid_][idx].x>>32);
}

// collect chain vertices that are not on the walk (set difference)
std::vector<int32_t> common_vtx;
std::set_difference(chain_vtx.begin(), chain_vtx.end(), walk_vtx.begin(), walk_vtx.end(), std::inserter(common_vtx, common_vtx.begin()));
// if the difference is empty, the chain is fully supported by the walk
float frac_correct_ = ((float)chain_vtx.size() - (float)common_vtx.size())/(float)chain_vtx.size();
if (common_vtx.size() == 0)
{
loc_correct = 1;
}else
{
frac_correct_loc = frac_correct_;
loc_not_correct = 1;
}
}
count_correct += loc_correct;
count_not_correct += loc_not_correct;
frac_correct += frac_correct_loc;
}

} else {
for (int cid = 0; cid < num_cid; cid++)
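Note: a small Python sketch (illustrative only, with made-up vertex IDs) of the support test added above: a chain counts as completely supported when every vertex it touches lies on the reference walk, and otherwise the supported fraction is recorded.

chain_vtx = {3, 5, 8, 9}         # vertices covered by the best-scoring chain
walk_vtx = {1, 3, 5, 8, 9, 12}   # vertices on the CHM13#0 walk (fwd + rev)

unsupported = chain_vtx - walk_vtx  # mirrors std::set_difference in the C++ code
if not unsupported:
    loc_correct = 1                 # complete support
else:
    loc_not_correct = 1             # partial support
    frac_correct_loc = (len(chain_vtx) - len(unsupported)) / len(chain_vtx)
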
4 changes: 4 additions & 0 deletions graphUtils.h
@@ -145,6 +145,10 @@ class graphUtils
std::vector<int> ref_query;
float accuracy = 0.0f;
int num_walks = 0;
bool benchmark;
int32_t count_correct = 0;
int32_t count_not_correct = 0;
float frac_correct = 0.0f;

graphUtils(gfa_t *g); // This is constructor

4 changes: 3 additions & 1 deletion index.c
@@ -211,14 +211,15 @@ mg_idx_t *mg_index_core(gfa_t *g, int k, int w, int b, int n_threads)

/* Pass parameters */
params* par;
void pass_par(bool &param_z, int32_t &scale_factor, char* &graph_name, int &G, int32_t &recomb)
void pass_par(bool &param_z, int32_t &scale_factor, char* &graph_name, int &G, int32_t &recomb, bool &benchmark)
{
par = new params();
par->param_z = param_z;
par->scale_factor = scale_factor;
par->graph_name = graph_name;
par->G = G;
par->recomb = recomb;
par->benchmark = benchmark;
}

mg_idx_t *mg_index(gfa_t *g, const mg_idxopt_t *io, int n_threads, mg_mapopt_t *mo)
@@ -240,6 +241,7 @@ mg_idx_t *mg_index(gfa_t *g, const mg_idxopt_t *io, int n_threads, mg_mapopt_t *
graphOp->param_z = par->param_z;
graphOp->graph_name = par->graph_name;
graphOp->read_graph();
graphOp->benchmark = par->benchmark;
omp_set_dynamic(1);
omp_set_num_threads(0);
graphOp->scale_factor = par->scale_factor;
Expand Down
27 changes: 21 additions & 6 deletions main.c
@@ -287,7 +287,7 @@ int main(int argc, char *argv[])

// pass_par(z, G);
// minichain
pass_par(z, scale_factor, argv[o.ind], G, recomb);
pass_par(z, scale_factor, argv[o.ind], G, recomb, benchmark);

g = gfa_read(argv[o.ind]);
if (g == 0) {
@@ -313,20 +313,35 @@
exit(EXIT_FAILURE);
}

int min, max, max_sum, count;
float accuracy;
int min, max, max_sum, count, count_correct, count_not_correct;
float accuracy, frac_correct;
std::string haps;
get_vars(min, max, max_sum, count, accuracy, haps);
get_vars(min, max, max_sum, count, accuracy, haps, count_correct, count_not_correct, frac_correct);
float mean = (float)max_sum / (float)count;

// print count of correct and not correct recombination events
// std::cerr << " count_correct : " << count_correct << " count_not_correct : " << count_not_correct << " count : " << count << std::endl;
// assert(count_correct + count_not_correct == count);


float corr = (float)count_correct / (float)count;
float incorr = (1.0f - corr);
float incorr_corr = ((float)frac_correct / (float)count_not_correct) * incorr;
float incorr_incorr = incorr - incorr_corr;

// std::cerr << " accuracy: " << accuracy << " Benchmark: " << benchmark << std::endl;

if (mg_verbose >= 3) {
if (benchmark)
if (benchmark && (accuracy == -1.0f))
{
{fprintf(stderr, "[M::%s] R: %f, NR_R: %f, NR_NR: %f\n", __func__, corr, incorr_corr, incorr_incorr);};
} else if (benchmark)
{
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f]\n", __func__, min, max, mean, accuracy);};
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Haplotype paths: %s\n", __func__, haps.c_str());};
}else
{
if (min != INT_MAX){fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f]\n", __func__, min, max, mean);};
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f]\n", __func__, min, max, mean);};
}
fprintf(stderr, "[M::%s] Version: %s\n", __func__, MC_VERSION);
fprintf(stderr, "[M::%s] CMD:", __func__);
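Note: a worked check (illustrative numbers only) of how the three fractions reported above are derived from the benchmark counters, and why they sum to 1:

count = 100            # mapped reads
count_correct = 95     # reads whose best chain is fully on the reference walk
count_not_correct = 5
frac_correct = 3.2     # summed supported fraction over the 5 partially supported reads

corr = count_correct / count                                # 0.95  -> printed as R
incorr = 1.0 - corr                                         # 0.05
incorr_corr = (frac_correct / count_not_correct) * incorr   # 0.032 -> printed as NR_R
incorr_incorr = incorr - incorr_corr                        # 0.018 -> printed as NR_NR
assert abs(corr + incorr_corr + incorr_incorr - 1.0) < 1e-9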