refactor: more parallelization for data generator
FlorianWoelki committed Aug 22, 2024
1 parent 21a5fd6 commit 71891ca
Showing 2 changed files with 31 additions and 25 deletions.
44 changes: 24 additions & 20 deletions src/data/generator_sparse.rs
@@ -1,11 +1,12 @@
-use std::collections::BinaryHeap;
+use std::{collections::BinaryHeap, sync::Mutex};
 
 use ordered_float::OrderedFloat;
 use rand::{
     distributions::{Bernoulli, Distribution, Uniform},
     rngs::StdRng,
     SeedableRng,
 };
+use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
 
 use crate::index::DistanceMetric;
 
@@ -119,25 +120,25 @@ impl SparseDataGenerator {
         range: (f32, f32),
         sparsity: f32,
     ) -> Vec<SparseVector> {
-        let mut rng = StdRng::seed_from_u64(self.seed.wrapping_add(1000));
         let uniform_dist = Uniform::from(range.0..range.1);
         let bernoulli_dist = Bernoulli::new(sparsity as f64).unwrap();
-        let mut vectors = Vec::with_capacity(count);
 
-        for _ in 0..count {
-            let mut indices = Vec::new();
-            let mut values = Vec::new();
-            for i in 0..dim {
-                if !bernoulli_dist.sample(&mut rng) {
-                    indices.push(i);
-                    values.push(OrderedFloat(uniform_dist.sample(&mut rng)));
-                }
-            }
-
-            vectors.push(SparseVector { indices, values });
-        }
+        (0..count)
+            .into_par_iter()
+            .map(|i| {
+                let mut rng = StdRng::seed_from_u64(self.seed.wrapping_add(1000 + i as u64));
+                let mut indices = Vec::new();
+                let mut values = Vec::new();
+                for i in 0..dim {
+                    if !bernoulli_dist.sample(&mut rng) {
+                        indices.push(i);
+                        values.push(OrderedFloat(uniform_dist.sample(&mut rng)));
+                    }
+                }
 
-        vectors
+                SparseVector { indices, values }
+            })
+            .collect()
     }
 
     fn find_nearest_neighbors(
@@ -146,10 +147,11 @@ impl SparseDataGenerator {
         query: &SparseVector,
         k: usize,
     ) -> Vec<SparseVector> {
-        let mut heap = BinaryHeap::new();
+        let heap = Mutex::new(BinaryHeap::new());
 
-        for vector in data {
+        data.par_iter().for_each(|vector| {
             let distance = query.distance(&vector, &self.metric);
+            let mut heap = heap.lock().unwrap();
             if heap.len() < k {
                 heap.push((OrderedFloat(distance), vector.clone()));
             } else if let Some((OrderedFloat(max_distance), _)) = heap.peek() {
@@ -158,9 +160,11 @@ impl SparseDataGenerator {
                     heap.push((OrderedFloat(distance), vector.clone()));
                 }
             }
-        }
+        });
 
-        heap.into_sorted_vec()
+        heap.into_inner()
+            .unwrap()
+            .into_sorted_vec()
             .into_iter()
             .map(|(_, v)| v)
             .collect::<Vec<_>>()
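The generation hunk above follows a common rayon pattern: fan the work out with into_par_iter and give every work item its own StdRng seeded from the base seed plus the item index, so the output stays deterministic regardless of how the threads are scheduled. Below is a small self-contained sketch of that pattern; generate_rows, its parameter list, and the (index, value) row representation are invented for the example, while the rand/rayon calls mirror the diff.

use rand::{
    distributions::{Bernoulli, Distribution, Uniform},
    rngs::StdRng,
    SeedableRng,
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};

// Hypothetical standalone version of a parallel sparse-data generator.
fn generate_rows(
    count: usize,
    dim: usize,
    range: (f32, f32),
    sparsity: f64,
    base_seed: u64,
) -> Vec<Vec<(usize, f32)>> {
    let uniform = Uniform::from(range.0..range.1);
    let bernoulli = Bernoulli::new(sparsity).unwrap();

    (0..count)
        .into_par_iter()
        .map(|i| {
            // One RNG per row, derived from the row index: reproducible and lock-free.
            let mut rng = StdRng::seed_from_u64(base_seed.wrapping_add(i as u64));
            (0..dim)
                .filter_map(|j| {
                    // Keep dimension j only when the Bernoulli(sparsity) draw is false,
                    // so a sparsity of 0.9 yields roughly 10% non-zero entries.
                    if bernoulli.sample(&mut rng) {
                        None
                    } else {
                        Some((j, uniform.sample(&mut rng)))
                    }
                })
                .collect::<Vec<(usize, f32)>>()
        })
        .collect()
}

Deriving each seed from the item index is what replaces the single shared StdRng of the old loop; sharing one RNG behind a lock would both serialize the hot path and make the generated data depend on thread scheduling.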
12 changes: 7 additions & 5 deletions src/main.rs
@@ -21,7 +21,7 @@ use rand::{thread_rng, Rng};
 use sysinfo::{Pid, System};
 
 use clap::Parser;
-use index::{annoy::AnnoyIndex, DistanceMetric, IndexType, SparseIndex};
+use index::{annoy::AnnoyIndex, linscan::LinScanIndex, DistanceMetric, IndexType, SparseIndex};
 use ordered_float::OrderedFloat;
 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 
@@ -149,13 +149,13 @@ async fn main() {
     println!("Executing in serial? {}", !is_parallel);
 
     let args = Args::parse();
-    let dimensions = args.dimensions.unwrap_or(100);
+    let dimensions = args.dimensions.unwrap_or(10000);
     let amount = args.features.unwrap_or(1000);
 
     let mut rng = thread_rng();
     let distance_metric = DistanceMetric::Cosine;
     let benchmark_config = BenchmarkConfig::new(
-        (dimensions, 100, dimensions),
+        (dimensions, 10000, dimensions),
         (amount, 1000, amount),
         (0.0, 1.0),
         0.90,
@@ -177,7 +177,8 @@ async fn main() {
     println!("...finished generating data");
 
     let total_index_start = Instant::now();
-    let mut index = AnnoyIndex::new(20, 20, 40, distance_metric);
+    let mut index = LinScanIndex::new(distance_metric);
+    // let mut index = AnnoyIndex::new(20, 20, 40, distance_metric);
 
     for vector in &vectors {
         index.add_vector_before_build(&vector);
@@ -220,6 +221,7 @@ async fn main() {
             .collect::<Vec<_>>();
 
         let iter_recall = calculate_recall(&search_results, groundtruth_vectors, k);
+        println!("{}", iter_recall);
         accumulated_recall += iter_recall;
     }
 
@@ -274,7 +276,7 @@ async fn main() {
         save_index(
             &dir_path,
             format!("annoy_serial_{}", amount), // TODO: Modify to support parallel
-            IndexType::Annoy(index),
+            IndexType::LinScan(index),
         )
     });
 
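The find_nearest_neighbors change in generator_sparse.rs parallelizes the brute-force ground-truth search while keeping a single top-k max-heap behind a Mutex. The sketch below illustrates the same pattern outside the crate; the k_nearest name, plain Vec<f32> points, the Euclidean distance closure, and the index-based return value are stand-ins for the crate's SparseVector and DistanceMetric.

use std::{collections::BinaryHeap, sync::Mutex};

use ordered_float::OrderedFloat;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

// Hypothetical brute-force k-nearest search over dense points.
fn k_nearest(data: &[Vec<f32>], query: &[f32], k: usize) -> Vec<usize> {
    let euclidean = |a: &[f32], b: &[f32]| -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
    };

    // Max-heap of (distance, index): the root is the current worst of the k best.
    let heap = Mutex::new(BinaryHeap::new());

    data.par_iter().enumerate().for_each(|(i, v)| {
        // The distance is computed outside the lock; the mutex only guards the
        // short compare-and-push on the shared heap.
        let d = OrderedFloat(euclidean(v, query));
        let mut heap = heap.lock().unwrap();
        if heap.len() < k {
            heap.push((d, i));
        } else if heap.peek().map_or(false, |(worst, _)| d < *worst) {
            heap.pop();
            heap.push((d, i));
        }
    });

    // Sorted ascending by distance: the indices of the k nearest points.
    heap.into_inner()
        .unwrap()
        .into_sorted_vec()
        .into_iter()
        .map(|(_, i)| i)
        .collect()
}

Compared with the old serial loop, the only shared state is the heap itself, and each worker holds the lock just long enough to compare against the current worst candidate and possibly replace it.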
