Skip to content

Commit

Permalink
feat: add precision and f1 score support
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianWoelki committed Jan 31, 2025
1 parent fc947fc commit 4f3b15e
Showing 1 changed file with 87 additions and 5 deletions.
92 changes: 87 additions & 5 deletions src/benchmark/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,38 @@ pub fn calculate_recall(
k: usize,
) -> f32 {
let mut correct_results = 0;
for result in search_results.iter().take(k) {
if groundtruth.iter().any(|gt| gt.indices == result.indices) {
correct_results += 1;
}
}

correct_results as f32 / groundtruth.len() as f32
}

for result in search_results {
if groundtruth
.iter()
.any(|gt_vector| gt_vector.indices == result.indices)
{
pub fn calculate_precision(
search_results: &[SparseVector],
groundtruth: &[SparseVector],
k: usize,
) -> f32 {
let mut correct_results = 0;
for result in search_results.iter().take(k) {
if groundtruth.iter().any(|gt| gt.indices == result.indices) {
correct_results += 1;
}
}

correct_results as f32 / k as f32
}

pub fn calculate_f1_score(precision: f32, recall: f32) -> f32 {
if precision + recall > 0.0 {
(2.0 * precision * recall) / (precision + recall)
} else {
0.0
}
}

/// A scalability factor greater than one indicates that the algorithm is scaling
/// better than linear expectations, maintaining or improving its relative performance
/// despite increases in data size and dimensionality.
Expand Down Expand Up @@ -57,6 +76,69 @@ mod tests {
use super::*;
use std::time::Duration;

#[test]
fn test_perfect_precision() {
let search_results = vec![
SparseVector {
indices: vec![1, 2, 3],
values: vec![OrderedFloat(0.1), OrderedFloat(0.2), OrderedFloat(0.3)],
},
SparseVector {
indices: vec![2, 3, 4],
values: vec![OrderedFloat(0.2), OrderedFloat(0.3), OrderedFloat(0.4)],
},
];
let groundtruth = vec![
SparseVector {
indices: vec![1, 2, 3],
values: vec![OrderedFloat(0.1), OrderedFloat(0.2), OrderedFloat(0.3)],
},
SparseVector {
indices: vec![2, 3, 4],
values: vec![OrderedFloat(0.2), OrderedFloat(0.3), OrderedFloat(0.4)],
},
];

assert_eq!(calculate_precision(&search_results, &groundtruth, 2), 1.0);
}

#[test]
fn test_partial_precision() {
let search_results = vec![
SparseVector {
indices: vec![1, 2, 3],
values: vec![OrderedFloat(0.1), OrderedFloat(0.2), OrderedFloat(0.3)],
},
SparseVector {
indices: vec![2, 3, 4],
values: vec![OrderedFloat(0.2), OrderedFloat(0.3), OrderedFloat(0.4)],
},
];
let groundtruth = vec![
SparseVector {
indices: vec![1, 2, 3],
values: vec![OrderedFloat(0.1), OrderedFloat(0.2), OrderedFloat(0.3)],
},
SparseVector {
indices: vec![3, 4, 5],
values: vec![OrderedFloat(0.3), OrderedFloat(0.4), OrderedFloat(0.5)],
},
];

assert_eq!(calculate_precision(&search_results, &groundtruth, 2), 0.5);
}

#[test]
fn test_f1_score() {
assert_eq!(calculate_f1_score(1.0, 1.0), 1.0);
assert_eq!(calculate_f1_score(0.5, 0.5), 0.5);
let f1 = calculate_f1_score(0.5, 1.0);
assert!(f1 > 0.66 && f1 < 0.67);
assert_eq!(calculate_f1_score(0.0, 0.0), 0.0);
assert_eq!(calculate_f1_score(0.0, 0.5), 0.0);
assert_eq!(calculate_f1_score(0.5, 0.0), 0.0);
}

#[test]
fn test_perfect_recall() {
let search_results = vec![
Expand Down

0 comments on commit 4f3b15e

Please sign in to comment.