diff --git a/graphUtils.cpp b/graphUtils.cpp index 557d0ec..75fcda9 100644 --- a/graphUtils.cpp +++ b/graphUtils.cpp @@ -895,9 +895,9 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str count++; // Read count std::vector haps_seq_cid; haps_seq_cid.resize(num_cid); - std::vector acc_cid; - acc_cid.resize(num_cid); - + std::vector acc_cid(num_cid); + std::vector precision_vec(num_cid); + std::vector recall_vec(num_cid); bool is_haplotype = false; for (int cid = 0; cid < num_cid; cid++) @@ -910,7 +910,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str int sf = scale_factor; if (N == 0) continue; - /* Initialise Search Trees */ /* Initialise T */ std::vector T; // Tuples of Anchors int cost = (M[cid][0].d - M[cid][0].c + 1) * scale_factor; @@ -1023,7 +1022,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str std::vector, std::pair>> D(N, std::pair, \ std::pair>({std::numeric_limits::min(), -1},{-1, -1})); // (score, index), index - // For dyanamic range tree // Initialize a pointer and a array. std::vector x(K,0); // pointer for the path std::vector rmq_coor(K,0); // current index of the anchor which lies outside the window of G @@ -1039,7 +1037,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str if (t.task == 0) { int64_t p = M[cid][i].x + dist2begin[cid][j][t.v] + M[cid][i].c + Distance[cid][j][w]- 2; - // int range = dist2begin[cid][t.path][t.v] + Distance[cid][t.path][t.w] + M[cid][t.anchor].x - G - 1; if (index[cid][j][w] != -1) // j \in paths(M[cid][i].v) { // Query after anchor deletion @@ -1159,9 +1156,10 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str std::pair, int> anchor_id; if (param_z) std::cerr << "Backtracking started for cid : " << cid << "\n"; - accuracy = -1.0f; float sum_acc = 0.0f; + float precision_ = 0.0f; + float recall_ = 0.0f; int count_acc = 0; while (max_score >= threshold_score && count_anchors < N && max_score > min_score) { @@ -1173,7 +1171,9 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str int prev_hap = path; int prev_i = idx; std::vector chained_haps; - for (anchor_id = {{C[idx][path].s, idx}, path}; anchor_id.first.second != -1 && anchor_id.second != -1; anchor_id = {{C[anchor_id.first.second][anchor_id.second].s, C[anchor_id.first.second][anchor_id.second].i}, C[anchor_id.first.second][anchor_id.second].j}) // backtracking + for (anchor_id = {{C[idx][path].s, idx}, path}; anchor_id.first.second != -1 && \ + anchor_id.second != -1; anchor_id = {{C[anchor_id.first.second][anchor_id.second].s, \ + C[anchor_id.first.second][anchor_id.second].i}, C[anchor_id.first.second][anchor_id.second].j}) // backtracking { flag = true; // chain is disjoint count_anchors++; // count anchors @@ -1245,8 +1245,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str if (qname_.size() == 3) { is_haplotype = true; - float precision = 0.0f; - float recall = 0.0f; float f1_score = 0.0f; // Split the string into vector of strings by > std::string delimiter2 = ">"; @@ -1265,7 +1263,7 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str true_rec = true_haplotypes.size() - 1; assert(true_rec == atoi(qname_[1].c_str())); - if (recomb == 0) + if (recomb == 0 || true) { chained_haps.clear(); std::vector chained_anchors; @@ -1274,7 +1272,22 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str chained_anchors.push_back(idx_Anchor[cid][idx]); } // reverse the chain - // std::reverse(chained_anchors.begin(), chained_anchors.end()); + std::reverse(chained_anchors.begin(), chained_anchors.end()); + + std::set set_chained_vtxs; + std::vector chained_vtx; + for (auto idx:chained_anchors) + { + int32_t v = idx.x>>32; + set_chained_vtxs.insert(v); + } + // fill the chained_vtx + for (auto v:set_chained_vtxs){ chained_vtx.push_back(v); } + set_chained_vtxs.clear(); + + // sort chained vtx by topological order + std::sort(chained_vtx.begin(), chained_vtx.end(), [&](int32_t a, int32_t b) -> bool \ + { return map_top_sort[cid][idx_component[cid][a]] < map_top_sort[cid][idx_component[cid][b]]; }); // Do DP to find minimum recombination int min_recomb = std::numeric_limits::max(); @@ -1282,26 +1295,26 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str // Inialize C as \inf for (auto hap:walk_ids) { - for (int i = 0; i < chained_anchors.size(); i++) + for (int i = 0; i < chained_vtx.size(); i++) { C_[hap].push_back({std::numeric_limits::max(), {"-1", -1}}); } } // Inialize C - int node = (int)(chained_anchors[0].x >> 32); + int node = chained_vtx[0]; for (auto hap:haps[node]) { C_[hap][0] = {0, {"-1", -1}}; } // Do DP - for (int i = 1; i < chained_anchors.size(); i++) + for (int i = 1; i < chained_vtx.size(); i++) { - int node_1 = (int)(chained_anchors[i].x >> 32); + int node_1 = chained_vtx[i]; for (auto hap:haps[node_1]) { - int node_2 = (int)(chained_anchors[i - 1].x >> 32); + int node_2 = chained_vtx[i-1]; for (auto hap2:haps[node_2]) { if (hap == hap2) @@ -1319,9 +1332,9 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str // Find the mimum recombination for (auto hap:walk_ids) { - if (min_dp_recomb > C_[hap][chained_anchors.size() - 1]) + if (min_dp_recomb > C_[hap][chained_vtx.size() - 1]) { - min_dp_recomb = C_[hap][chained_anchors.size() - 1]; + min_dp_recomb = C_[hap][chained_vtx.size() - 1]; } } @@ -1455,28 +1468,14 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str } } - // // print true_hap_pair - // for (auto hap_pair:true_hap_pair) - // { - // std::cerr << "True Hap Pair : " << hap_pair.first << " " << hap_pair.second << "\n"; - // } - - // // print chain_hap_pair - // for (auto hap_pair:chain_hap_pair) - // { - // std::cerr << "Chain Hap Pair : " << hap_pair.first << " " << hap_pair.second << "\n"; - // } - - precision = (float)true_pos/(float)(true_pos + false_pos); - recall = (float)true_pos/(float)(true_pos + false_neg); - // std::cerr << "Precision : " << precision << "\n"; - // std::cerr << "Recall : " << recall << "\n"; + precision_ = (float)true_pos/(float)(true_pos + false_pos); + recall_ = (float)true_pos/(float)(true_pos + false_neg); // compute F1 float loc_acc = 0.0f; - f1_score = 2.0f * ((precision * recall)/(precision + recall)); - if (precision != 0.0f && recall != 0.0f) + f1_score = 2.0f * ((precision_ * recall_)/(precision_ + recall_)); + if (precision_ != 0.0f && recall_ != 0.0f) { sum_acc = f1_score; loc_acc = f1_score; @@ -1486,15 +1485,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str loc_acc = 0.0f; } count_acc++; - - // std::cerr << "f1_score: " << f1_score << std::endl; - // if (precision != 0.0f && recall != 0.0f) { - // accuracy = f1_score; - // std::cerr << "Accuracy updated to: " << accuracy << std::endl; - // } else { - // accuracy = 0.0f; - // std::cerr << "Accuracy updated to: " << accuracy << std::endl; - // } // add chained_haps to hap_seqs with S and E and > sign haps_seq_cid[cid] += ">S"; @@ -1522,18 +1512,6 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str } } - // // count recombinations - // if (recombination != -1) - // { - // if (loc_max < recombination) - // { - // loc_max = recombination; - // loc_min = recombination; - // } - // } - - // Compute max score of C again considering the visited anchors - // if(param_z) std::cerr << " Second pass for max_score computation started ... " << std::endl; max_score = std::numeric_limits::min(); //{{std::numeric_limits::min(), -1}, -1}; for (int i = prev_idx; i < D.size(); i++) // O(N) time maximum score search { @@ -1546,22 +1524,13 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str break; } } - - // // For each chain add min/max recombination - // if (loc_max > 0) - // { - // max_loc[cid] += loc_max; - // min_loc[cid] += loc_min; - // } - if(param_z) std::cerr << " Second pass for max_score computation finished ... " << std::endl; - } acc_cid[cid] = sum_acc/(float)count_acc; - - + precision_vec[cid] = precision_; + recall_vec[cid] = recall_; for (int i = 0; i < chain.size(); i++) { @@ -1578,6 +1547,8 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str } } int max_f1_cid_ = 0; + float max_precision = 0.0f; + float max_recall = 0.0f; float max_f1 = std::numeric_limits::min(); for (int cid = 0; cid < num_cid; cid++) { @@ -1588,6 +1559,8 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str } } accuracy = max_f1; + precision = precision_vec[max_f1_cid_]; + recall = recall_vec[max_f1_cid_]; hap_seqs = haps_seq_cid[max_f1_cid_]; max = std::max(max, min_loc[max_f1_cid_]); min = std::min(min, min_loc[max_f1_cid_]); @@ -1943,7 +1916,7 @@ std::vector graphUtils::Chaining(std::vector anchors, std::str if(!is_hprc_gfa && !is_hap) { for (int i = 0; i < best.size(); i++) { - int node = (int)(best[i].x >> 32); + int node = best[i].x >> 32; if (count[node] <= 5) { red_idx.push_back(i); } diff --git a/graphUtils.h b/graphUtils.h index 5b3073e..122c4ac 100644 --- a/graphUtils.h +++ b/graphUtils.h @@ -124,6 +124,7 @@ class graphUtils float div; int max_itr; std::string hap_seqs; + float precision = 0.0f, recall = 0.0f; // for recombinations count int min = std::numeric_limits::max(), max = std::numeric_limits::min(), max_sum = 0, count = 0; diff --git a/main.c b/main.c index 11c2b0d..7fbff4e 100644 --- a/main.c +++ b/main.c @@ -315,30 +315,25 @@ int main(int argc, char *argv[]) int min, max, max_sum, count, count_correct, count_not_correct; float accuracy, frac_correct; + float precision, recall; std::string haps; - get_vars(min, max, max_sum, count, accuracy, haps, count_correct, count_not_correct, frac_correct); + get_vars(min, max, max_sum, count, accuracy, haps, count_correct, count_not_correct, frac_correct, precision, recall); float mean = (float)max_sum / (float)count; - // print count of correct and not correct recombination events - // std::cerr << " count_correct : " << count_correct << " count_not_correct : " << count_not_correct << " count : " << count << std::endl; - // assert(count_correct + count_not_correct == count); - - float corr = (float)count_correct / (float)count; float incorr = (1.0f - corr); float incorr_corr = ((float)frac_correct / (float)count_not_correct) * incorr; float incorr_incorr = incorr - incorr_corr; - // std::cerr << " accuracy: " << accuracy << " Benchmark: " << benchmark << std::endl; - if (mg_verbose >= 3) { if (benchmark && (accuracy == -1.0f)) { - // {fprintf(stderr, "[M::%s] R: %f, NR_R: %f, NR_NR: %f\n", __func__, corr, incorr_corr, incorr_incorr);}; {fprintf(stderr, "[M::%s] R: %f\n", __func__, corr);}; } else if (benchmark) { - if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f]\n", __func__, min, max, mean, accuracy);}; + // if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f]\n", __func__, min, max, mean, accuracy);}; + if (min != INT_MAX) {fprintf(stderr, "[M::%s] Recombinations [Min: %d, Max: %d, Mean: %f, Accuracy: %f, precision: %f, recall: %f]\n", __func__, min, max, \ + mean, accuracy, precision, recall);}; if (min != INT_MAX) {fprintf(stderr, "[M::%s] Haplotype paths: %s\n", __func__, haps.c_str());}; }else { diff --git a/map-algo.c b/map-algo.c index c331d1c..5c41b0d 100644 --- a/map-algo.c +++ b/map-algo.c @@ -15,7 +15,8 @@ void get_Op(graphUtils *graph_Op) graphOp = graph_Op; } -void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, int32_t &count_not_correct, float &frac_correct){ +void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, \ + int32_t &count_not_correct, float &frac_correct, float &precision, float &recall){ min = graphOp->min; max = graphOp->max; max_sum = graphOp->max_sum; @@ -25,6 +26,8 @@ void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std count_correct = graphOp->count_correct; count_not_correct = graphOp->count_not_correct; frac_correct = graphOp->frac_correct; + precision = graphOp->precision; + recall = graphOp->recall; } struct mg_tbuf_s { diff --git a/mgpriv.h b/mgpriv.h index 31b0b6d..8de353c 100644 --- a/mgpriv.h +++ b/mgpriv.h @@ -134,7 +134,8 @@ struct params }; -void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, int32_t &count_not_correct, float &frac_correct); +void get_vars(int &min, int &max, int &max_sum, int &count, float &accuracy, std::string &hap_seqs, int32_t &count_correct, \ + int32_t &count_not_correct, float &frac_correct, float &precison, float &recall); #ifdef __cplusplus }