Skip to content

Commit

Permalink
added precision and recall in the benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
gsc74 committed May 20, 2024
1 parent 9ab3d0f commit 74e16ad
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 82 deletions.
113 changes: 43 additions & 70 deletions graphUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,9 +895,9 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
count++; // Read count
std::vector<std::string> haps_seq_cid;
haps_seq_cid.resize(num_cid);
std::vector<float> acc_cid;
acc_cid.resize(num_cid);

std::vector<float> acc_cid(num_cid);
std::vector<float> precision_vec(num_cid);
std::vector<float> recall_vec(num_cid);

bool is_haplotype = false;
for (int cid = 0; cid < num_cid; cid++)
Expand All @@ -910,7 +910,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
int sf = scale_factor;
if (N == 0) continue;

/* Initialise Search Trees */
/* Initialise T */
std::vector<Tuples> T; // Tuples of Anchors
int cost = (M[cid][0].d - M[cid][0].c + 1) * scale_factor;
Expand Down Expand Up @@ -1023,7 +1022,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
std::vector<std::pair<std::pair<int64_t, int>, std::pair<int, int>>> D(N, std::pair<std::pair<int64_t, int>, \
std::pair<int, int>>({std::numeric_limits<int>::min(), -1},{-1, -1})); // (score, index), index

// For dynamic range tree
// Initialize a pointer and an array.
std::vector<int> x(K,0); // pointer for the path
std::vector<int> rmq_coor(K,0); // current index of the anchor which lies outside the window of G
Expand All @@ -1039,7 +1037,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
if (t.task == 0)
{
int64_t p = M[cid][i].x + dist2begin[cid][j][t.v] + M[cid][i].c + Distance[cid][j][w]- 2;
// int range = dist2begin[cid][t.path][t.v] + Distance[cid][t.path][t.w] + M[cid][t.anchor].x - G - 1;
if (index[cid][j][w] != -1) // j \in paths(M[cid][i].v)
{
// Query after anchor deletion
Expand Down Expand Up @@ -1159,9 +1156,10 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
std::pair<std::pair<int64_t, int>, int> anchor_id;

if (param_z) std::cerr << "Backtracking started for cid : " << cid << "\n";

accuracy = -1.0f;
float sum_acc = 0.0f;
float precision_ = 0.0f;
float recall_ = 0.0f;
int count_acc = 0;
while (max_score >= threshold_score && count_anchors < N && max_score > min_score)
{
Expand All @@ -1173,7 +1171,9 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
int prev_hap = path;
int prev_i = idx;
std::vector<std::string> chained_haps;
for (anchor_id = {{C[idx][path].s, idx}, path}; anchor_id.first.second != -1 && anchor_id.second != -1; anchor_id = {{C[anchor_id.first.second][anchor_id.second].s, C[anchor_id.first.second][anchor_id.second].i}, C[anchor_id.first.second][anchor_id.second].j}) // backtracking
for (anchor_id = {{C[idx][path].s, idx}, path}; anchor_id.first.second != -1 && \
anchor_id.second != -1; anchor_id = {{C[anchor_id.first.second][anchor_id.second].s, \
C[anchor_id.first.second][anchor_id.second].i}, C[anchor_id.first.second][anchor_id.second].j}) // backtracking
{
flag = true; // chain is disjoint
count_anchors++; // count anchors
Expand Down Expand Up @@ -1245,8 +1245,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
if (qname_.size() == 3)
{
is_haplotype = true;
float precision = 0.0f;
float recall = 0.0f;
float f1_score = 0.0f;
// Split the string into vector of strings by >
std::string delimiter2 = ">";
Expand All @@ -1265,7 +1263,7 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
true_rec = true_haplotypes.size() - 1;
assert(true_rec == atoi(qname_[1].c_str()));

if (recomb == 0)
if (recomb == 0 || true)
{
chained_haps.clear();
std::vector<mg128_t> chained_anchors;
Expand All @@ -1274,34 +1272,49 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
chained_anchors.push_back(idx_Anchor[cid][idx]);
}
// reverse the chain
// std::reverse(chained_anchors.begin(), chained_anchors.end());
std::reverse(chained_anchors.begin(), chained_anchors.end());

std::set<int32_t> set_chained_vtxs;
std::vector<int32_t> chained_vtx;
for (auto idx:chained_anchors)
{
int32_t v = idx.x>>32;
set_chained_vtxs.insert(v);
}
// fill the chained_vtx
for (auto v:set_chained_vtxs){ chained_vtx.push_back(v); }
set_chained_vtxs.clear();

// sort chained vtx by topological order
std::sort(chained_vtx.begin(), chained_vtx.end(), [&](int32_t a, int32_t b) -> bool \
{ return map_top_sort[cid][idx_component[cid][a]] < map_top_sort[cid][idx_component[cid][b]]; });

// Do DP to find minimum recombination
int min_recomb = std::numeric_limits<int>::max();
std::map<std::string, std::vector<std::pair<int, std::pair<std::string, int>>>> C_;
// Initialize C as \inf
for (auto hap:walk_ids)
{
for (int i = 0; i < chained_anchors.size(); i++)
for (int i = 0; i < chained_vtx.size(); i++)
{
C_[hap].push_back({std::numeric_limits<int>::max(), {"-1", -1}});
}
}

// Initialize C
int node = (int)(chained_anchors[0].x >> 32);
int node = chained_vtx[0];
for (auto hap:haps[node])
{
C_[hap][0] = {0, {"-1", -1}};
}

// Do DP
for (int i = 1; i < chained_anchors.size(); i++)
for (int i = 1; i < chained_vtx.size(); i++)
{
int node_1 = (int)(chained_anchors[i].x >> 32);
int node_1 = chained_vtx[i];
for (auto hap:haps[node_1])
{
int node_2 = (int)(chained_anchors[i - 1].x >> 32);
int node_2 = chained_vtx[i-1];
for (auto hap2:haps[node_2])
{
if (hap == hap2)
Expand All @@ -1319,9 +1332,9 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
// Find the minimum recombination
for (auto hap:walk_ids)
{
if (min_dp_recomb > C_[hap][chained_anchors.size() - 1])
if (min_dp_recomb > C_[hap][chained_vtx.size() - 1])
{
min_dp_recomb = C_[hap][chained_anchors.size() - 1];
min_dp_recomb = C_[hap][chained_vtx.size() - 1];
}
}

Expand Down Expand Up @@ -1455,28 +1468,14 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
}
}

// // print true_hap_pair
// for (auto hap_pair:true_hap_pair)
// {
// std::cerr << "True Hap Pair : " << hap_pair.first << " " << hap_pair.second << "\n";
// }

// // print chain_hap_pair
// for (auto hap_pair:chain_hap_pair)
// {
// std::cerr << "Chain Hap Pair : " << hap_pair.first << " " << hap_pair.second << "\n";
// }


precision = (float)true_pos/(float)(true_pos + false_pos);
recall = (float)true_pos/(float)(true_pos + false_neg);
// std::cerr << "Precision : " << precision << "\n";
// std::cerr << "Recall : " << recall << "\n";
precision_ = (float)true_pos/(float)(true_pos + false_pos);
recall_ = (float)true_pos/(float)(true_pos + false_neg);

// compute F1
float loc_acc = 0.0f;
f1_score = 2.0f * ((precision * recall)/(precision + recall));
if (precision != 0.0f && recall != 0.0f)
f1_score = 2.0f * ((precision_ * recall_)/(precision_ + recall_));
if (precision_ != 0.0f && recall_ != 0.0f)
{
sum_acc = f1_score;
loc_acc = f1_score;
Expand All @@ -1486,15 +1485,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
loc_acc = 0.0f;
}
count_acc++;

// std::cerr << "f1_score: " << f1_score << std::endl;
// if (precision != 0.0f && recall != 0.0f) {
// accuracy = f1_score;
// std::cerr << "Accuracy updated to: " << accuracy << std::endl;
// } else {
// accuracy = 0.0f;
// std::cerr << "Accuracy updated to: " << accuracy << std::endl;
// }

// add chained_haps to hap_seqs with S and E and > sign
haps_seq_cid[cid] += ">S";
Expand Down Expand Up @@ -1522,18 +1512,6 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
}
}

// // count recombinations
// if (recombination != -1)
// {
// if (loc_max < recombination)
// {
// loc_max = recombination;
// loc_min = recombination;
// }
// }

// Compute max score of C again considering the visited anchors
// if(param_z) std::cerr << " Second pass for max_score computation started ... " << std::endl;
max_score = std::numeric_limits<int>::min(); //{{std::numeric_limits<int>::min(), -1}, -1};
for (int i = prev_idx; i < D.size(); i++) // O(N) time maximum score search
{
Expand All @@ -1546,22 +1524,13 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
break;
}
}

// // For each chain add min/max recombination
// if (loc_max > 0)
// {
// max_loc[cid] += loc_max;
// min_loc[cid] += loc_min;
// }

if(param_z) std::cerr << " Second pass for max_score computation finished ... " << std::endl;

}


acc_cid[cid] = sum_acc/(float)count_acc;


precision_vec[cid] = precision_;
recall_vec[cid] = recall_;

for (int i = 0; i < chain.size(); i++)
{
Expand All @@ -1578,6 +1547,8 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
}
}
int max_f1_cid_ = 0;
float max_precision = 0.0f;
float max_recall = 0.0f;
float max_f1 = std::numeric_limits<float>::min();
for (int cid = 0; cid < num_cid; cid++)
{
Expand All @@ -1588,6 +1559,8 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
}
}
accuracy = max_f1;
precision = precision_vec[max_f1_cid_];
recall = recall_vec[max_f1_cid_];
hap_seqs = haps_seq_cid[max_f1_cid_];
max = std::max(max, min_loc[max_f1_cid_]);
min = std::min(min, min_loc[max_f1_cid_]);
Expand Down Expand Up @@ -1943,7 +1916,7 @@ std::vector<mg128_t> graphUtils::Chaining(std::vector<mg128_t> anchors, std::str
if(!is_hprc_gfa && !is_hap)
{
for (int i = 0; i < best.size(); i++) {
int node = (int)(best[i].x >> 32);
int node = best[i].x >> 32;
if (count[node] <= 5) {
red_idx.push_back(i);
}
Expand Down
1 change: 1 addition & 0 deletions graphUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class graphUtils
float div;
int max_itr;
std::string hap_seqs;
float precision = 0.0f, recall = 0.0f;

// for recombinations count
int min = std::numeric_limits<int>::max(), max = std::numeric_limits<int>::min(), max_sum = 0, count = 0;
Expand Down
15 changes: 5 additions & 10 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,30 +315,25 @@ int main(int argc, char *argv[])

int min, max, max_sum, count, count_correct, count_not_correct;
float accuracy, frac_correct;
float precision, recall;
std::string haps;
get_vars(min, max, max_sum, count, accuracy, haps, count_correct, count_not_correct, frac_correct);
get_vars(min, max, max_sum, count, accuracy, haps, count_correct, count_not_correct, frac_correct, precision, recall);
float mean = (float)max_sum / (float)count;

// print count of correct and not correct recombination events
// std::cerr << " count_correct : " << count_correct << " count_not_correct : " << count_not_correct << " count : " << count << std::endl;
// assert(count_correct + count_not_correct == count);


float corr = (float)count_correct / (float)count;
float incorr = (1.0f - corr);
float incorr_corr = ((float)frac_correct / (float)count_not_correct) * incorr;
float incorr_incorr = incorr - incorr_corr;

// std::cerr << " accuracy: " << accuracy << " Benchmark: " << benchmark << std::endl;

if (mg_verbose >= 3) {
if (benchmark && (accuracy == -1.0f))
{
// {fprintf(stderr, "[M::%s] R: %f, NR_R: %f, NR_NR: %f\n", __func__, corr, incorr_corr, incorr_incorr);};
{fprintf(stderr, "[M::%s] R: %f\n", __func__, corr);};
} else if (benchmark)
{
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f]\n", __func__, min, max, mean, accuracy);};
// if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f]\n", __func__, min, max, mean, accuracy);};
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f, precision: %f, recall: %f]\n", __func__, min, max, \
mean, accuracy, precision, recall);};
if (min != INT_MAX) {fprintf(stderr, "[M::%s] Haplotype paths: %s\n", __func__, haps.c_str());};
}else
{
Expand Down
5 changes: 4 additions & 1 deletion map-algo.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ void get_Op(graphUtils *graph_Op)
graphOp = graph_Op;
}

void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, int32_t &count_not_correct, float &frac_correct){
void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, \
int32_t &count_not_correct, float &frac_correct, float &precision, float &recall){
min = graphOp->min;
max = graphOp->max;
max_sum = graphOp->max_sum;
Expand All @@ -25,6 +26,8 @@ void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std
count_correct = graphOp->count_correct;
count_not_correct = graphOp->count_not_correct;
frac_correct = graphOp->frac_correct;
precision = graphOp->precision;
recall = graphOp->recall;
}

struct mg_tbuf_s {
Expand Down
3 changes: 2 additions & 1 deletion mgpriv.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ struct params
};


void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, int32_t &count_not_correct, float &frac_correct);
/*
 * Export the benchmark statistics to the caller through out-parameters.
 *
 * Every argument is an output reference. The visible part of the definition
 * assigns min, max, max_sum, count_correct, count_not_correct, frac_correct,
 * precision and recall from the same-named members of the graphUtils instance
 * previously registered with get_Op(); count, accuracy and hap_seqs are
 * presumably filled the same way (that part of the definition is not visible
 * here -- confirm against map-algo.c).
 *
 * The parameter names are kept identical to the definition so the two
 * declarations read consistently (the previous spelling was "precison").
 */
void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy,
              std::string &hap_seqs, int32_t &count_correct,
              int32_t &count_not_correct, float &frac_correct,
              float &precision, float &recall);

#ifdef __cplusplus
}
Expand Down

0 comments on commit 74e16ad

Please sign in to comment.