Skip to content

Commit

Permalink
Implement nearest neighbor searches on RTree (#79)
Browse files Browse the repository at this point in the history
* Implement nearest neighbor searches on RTree

* Add doctest

* Remove Ord bound and document panic for NaN

* remove ordered_float

* reorder

* comment
  • Loading branch information
kylebarron authored Dec 30, 2024
1 parent 1285b9e commit 59fdaa9
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 37 deletions.
5 changes: 5 additions & 0 deletions src/kdtree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
//! slice. If you don't know the coordinate type used in the index, you can use
//! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type.
//!
//! ## Coordinate types
//!
//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float
//! `NaN` is not supported and may panic.
//!
//! ## Example
//!
//! ```
Expand Down
5 changes: 5 additions & 0 deletions src/rtree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
//! slice. If you don't know the coordinate type used in the index, you can use
//! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type.
//!
//! ## Coordinate types
//!
//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float
//! `NaN` is not supported and may panic.
//!
//! ## Example
//!
//! ```
Expand Down
139 changes: 113 additions & 26 deletions src/rtree/trait.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::cmp::Reverse;
use std::collections::BinaryHeap;

use geo_traits::{CoordTrait, RectTrait};

use crate::error::Result;
Expand Down Expand Up @@ -124,39 +127,101 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
)
}

// #[allow(unused_mut, unused_labels, unused_variables)]
// fn neighbors(&self, x: N, y: N, max_distance: Option<N>) -> Vec<usize> {
// let boxes = self.boxes();
// let indices = self.indices();
// let max_distance = max_distance.unwrap_or(N::max_value());
/// Search items in order of distance from the given point.
///
/// ```
/// use geo_index::rtree::{RTreeBuilder, RTreeIndex, RTreeRef};
/// use geo_index::rtree::sort::HilbertSort;
///
/// // Create an RTree
/// let mut builder = RTreeBuilder::<f64>::new(3);
/// builder.add(0., 0., 2., 2.);
/// builder.add(1., 1., 3., 3.);
/// builder.add(2., 2., 4., 4.);
/// let tree = builder.finish::<HilbertSort>();
///
/// let results = tree.neighbors(5., 5., None, None);
/// assert_eq!(results, vec![2, 1, 0]);
/// ```
fn neighbors(
&self,
x: N,
y: N,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<usize> {
let boxes = self.boxes();
let indices = self.indices();
let max_distance = max_distance.unwrap_or(N::max_value());

// let mut outer_node_index = Some(boxes.len() - 4);
let mut outer_node_index = Some(boxes.len() - 4);
let mut queue = BinaryHeap::new();
let mut results = vec![];
let max_dist_squared = max_distance * max_distance;

// let mut results = vec![];
// let max_dist_squared = max_distance * max_distance;
'outer: while let Some(node_index) = outer_node_index {
// find the end index of the node
let end = (node_index + self.node_size() as usize * 4)
.min(upper_bound(node_index, self.level_bounds()));

// 'outer: while let Some(node_index) = outer_node_index {
// // find the end index of the node
// let end = (node_index + self.node_size() * 4)
// .min(upper_bound(node_index, self.level_bounds()));
// add child nodes to the queue
for pos in (node_index..end).step_by(4) {
let index = indices.get(pos >> 2);

// // add child nodes to the queue
// for pos in (node_index..end).step_by(4) {
// let index = indices.get(pos >> 2);
let dx = axis_dist(x, boxes[pos], boxes[pos + 2]);
let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]);
let dist = dx * dx + dy * dy;
if dist > max_dist_squared {
continue;
}

// let dx = axis_dist(x, boxes[pos], boxes[pos + 2]);
// let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]);
// let dist = dx * dx + dy * dy;
// if dist > max_dist_squared {
// continue;
// }
// }
if node_index >= self.num_items() as usize * 4 {
// node (use even id)
queue.push(Reverse(NeighborNode {
id: index << 1,
dist,
}));
} else {
// leaf item (use odd id)
queue.push(Reverse(NeighborNode {
id: (index << 1) + 1,
dist,
}));
}
}

// // break 'outer;
// }
// pop items from the queue
while !queue.is_empty() && queue.peek().is_some_and(|val| (val.0.id & 1) != 0) {
let dist = queue.peek().unwrap().0.dist;
if dist > max_dist_squared {
break 'outer;
}
let item = queue.pop().unwrap();
results.push(item.0.id >> 1);
if max_results.is_some_and(|max_results| results.len() == max_results) {
break 'outer;
}
}

if let Some(item) = queue.pop() {
outer_node_index = Some(item.0.id >> 1);
} else {
outer_node_index = None;
}
}

// results
// }
results
}

/// Search items in order of distance from the given coordinate.
fn neighbors_coord(
&self,
coord: &impl CoordTrait<T = N>,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<usize> {
self.neighbors(coord.x(), coord.y(), max_results, max_distance)
}

/// Returns an iterator over the indexes of objects in this and another tree that intersect.
///
Expand All @@ -175,6 +240,28 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
}
}

/// A wrapper around a node and its distance for use in the priority queue.
#[derive(Debug, Clone, Copy, PartialEq)]
struct NeighborNode<N: IndexableNum> {
id: usize,
dist: N,
}

impl<N: IndexableNum> Eq for NeighborNode<N> {}

impl<N: IndexableNum> Ord for NeighborNode<N> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We don't allow NaN. This should only panic on NaN
self.dist.partial_cmp(&other.dist).unwrap()
}
}

impl<N: IndexableNum> PartialOrd for NeighborNode<N> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<N: IndexableNum> RTreeIndex<N> for OwnedRTree<N> {
fn boxes(&self) -> &[N] {
self.metadata.boxes_slice(&self.buffer)
Expand Down
13 changes: 2 additions & 11 deletions src/type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Debug;

use num_traits::{Bounded, Num, NumCast, ToPrimitive};
use num_traits::{Bounded, Num, NumCast};

use crate::kdtree::constants::KDBUSH_MAGIC;
use crate::GeoIndexError;
Expand All @@ -12,16 +12,7 @@ use crate::GeoIndexError;
/// JavaScript ([rtree](https://github.com/mourner/flatbush),
/// [kdtree](https://github.com/mourner/kdbush))
pub trait IndexableNum:
private::Sealed
+ Num
+ NumCast
+ ToPrimitive
+ PartialOrd
+ Debug
+ Send
+ Sync
+ bytemuck::Pod
+ Bounded
private::Sealed + Num + NumCast + PartialOrd + Debug + Send + Sync + bytemuck::Pod + Bounded
{
/// The type index to match the array order of `ARRAY_TYPES` in flatbush JS
const TYPE_INDEX: u8;
Expand Down

0 comments on commit 59fdaa9

Please sign in to comment.