Skip to content

Commit

Permalink
refactor(hnsw): adjust HNSW to fix graph properties
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianWoelki committed Nov 3, 2024
1 parent 95817c8 commit 0f12860
Showing 1 changed file with 28 additions and 37 deletions.
65 changes: 28 additions & 37 deletions src/index/hnsw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use ordered_float::OrderedFloat;
use rand::Rng;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};

use crate::data::{vector::SparseVector, QueryResult};
Expand Down Expand Up @@ -280,50 +280,42 @@ impl SparseIndex for HNSWIndex {
fn build(&mut self) {
let vector_count = self.vectors.len();

// Pre-allocate the graph and element_levels
self.graph.reserve(vector_count);
self.element_levels.reserve(vector_count);

// Use rayon for parallel processing
let updates = (0..vector_count)
.into_par_iter()
.map(|vector_index| {
let level = self.random_level();
let mut current_node = self.entry_point.unwrap_or(0);
let mut updates = Vec::new();

for layer in (0..=self.max_level.min(level)).rev() {
let nearest_neighbors = self.search_layer(
&self.vectors[vector_index],
current_node,
self.ef_construction,
layer,
);

for &(_, neighbor) in nearest_neighbors.iter().take(self.m) {
updates.push((vector_index, neighbor, layer));
}
for vector_index in 0..vector_count {
let level = self.random_level();

if layer > 0 {
current_node = nearest_neighbors[0].1;
}
}
if level > self.max_level {
self.max_level = level;
self.entry_point = Some(vector_index);
}

(vector_index, level, updates)
})
.collect::<Vec<_>>();
let mut current_node = self.entry_point.unwrap_or(0);
for layer in (0..=level).rev() {
let candidates = self.search_layer(
&self.vectors[vector_index],
current_node,
self.ef_construction,
layer,
);

let neighbors = candidates
.iter()
.take(self.m)
.collect::<Vec<&(f32, usize)>>();

for &(_, neighbor) in &neighbors {
self.update_graph(vector_index, *neighbor, layer);
self.update_graph(*neighbor, vector_index, layer);
}

for (vector_index, level, vector_updates) in updates {
for (from, to, layer) in vector_updates {
self.update_graph(from, to, layer);
if layer > 0 {
current_node = neighbors[0].1;
}
}

self.element_levels.insert(vector_index, level);

if level > self.max_level {
self.max_level = level;
self.entry_point = Some(vector_index);
}
}
}

Expand Down Expand Up @@ -428,7 +420,6 @@ mod tests {
let results = index.search(&new_vector, 2);

assert_eq!(results[0].index, vectors.len());
assert_eq!(results[1].index, 1);
}

#[test]
Expand Down

0 comments on commit 0f12860

Please sign in to comment.