diff --git a/src/kdtree/mod.rs b/src/kdtree/mod.rs index ee8cb8d..e82e147 100644 --- a/src/kdtree/mod.rs +++ b/src/kdtree/mod.rs @@ -19,6 +19,11 @@ //! slice. If you don't know the coordinate type used in the index, you can use //! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type. //! +//! ## Coordinate types +//! +//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float +//! `NaN` is not supported and may panic. +//! //! ## Example //! //! ``` diff --git a/src/rtree/mod.rs b/src/rtree/mod.rs index b5bae48..2a056f6 100644 --- a/src/rtree/mod.rs +++ b/src/rtree/mod.rs @@ -18,6 +18,11 @@ //! slice. If you don't know the coordinate type used in the index, you can use //! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type. //! +//! ## Coordinate types +//! +//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float +//! `NaN` is not supported and may panic. +//! //! ## Example //! //! ``` diff --git a/src/rtree/trait.rs b/src/rtree/trait.rs index ba9c96d..b23e7b5 100644 --- a/src/rtree/trait.rs +++ b/src/rtree/trait.rs @@ -1,3 +1,6 @@ +use std::cmp::Reverse; +use std::collections::BinaryHeap; + use geo_traits::{CoordTrait, RectTrait}; use crate::error::Result; @@ -124,39 +127,101 @@ pub trait RTreeIndex: Sized { ) } - // #[allow(unused_mut, unused_labels, unused_variables)] - // fn neighbors(&self, x: N, y: N, max_distance: Option) -> Vec { - // let boxes = self.boxes(); - // let indices = self.indices(); - // let max_distance = max_distance.unwrap_or(N::max_value()); + /// Search items in order of distance from the given point. + /// + /// ``` + /// use geo_index::rtree::{RTreeBuilder, RTreeIndex, RTreeRef}; + /// use geo_index::rtree::sort::HilbertSort; + /// + /// // Create an RTree + /// let mut builder = RTreeBuilder::::new(3); + /// builder.add(0., 0., 2., 2.); + /// builder.add(1., 1., 3., 3.); + /// builder.add(2., 2., 4., 4.); + /// let tree = builder.finish::(); + /// + /// let results = tree.neighbors(5., 5., None, None); + /// assert_eq!(results, vec![2, 1, 0]); + /// ``` + fn neighbors( + &self, + x: N, + y: N, + max_results: Option, + max_distance: Option, + ) -> Vec { + let boxes = self.boxes(); + let indices = self.indices(); + let max_distance = max_distance.unwrap_or(N::max_value()); - // let mut outer_node_index = Some(boxes.len() - 4); + let mut outer_node_index = Some(boxes.len() - 4); + let mut queue = BinaryHeap::new(); + let mut results = vec![]; + let max_dist_squared = max_distance * max_distance; - // let mut results = vec![]; - // let max_dist_squared = max_distance * max_distance; + 'outer: while let Some(node_index) = outer_node_index { + // find the end index of the node + let end = (node_index + self.node_size() as usize * 4) + .min(upper_bound(node_index, self.level_bounds())); - // 'outer: while let Some(node_index) = outer_node_index { - // // find the end index of the node - // let end = (node_index + self.node_size() * 4) - // .min(upper_bound(node_index, self.level_bounds())); + // add child nodes to the queue + for pos in (node_index..end).step_by(4) { + let index = indices.get(pos >> 2); - // // add child nodes to the queue - // for pos in (node_index..end).step_by(4) { - // let index = indices.get(pos >> 2); + let dx = axis_dist(x, boxes[pos], boxes[pos + 2]); + let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]); + let dist = dx * dx + dy * dy; + if dist > max_dist_squared { + continue; + } - // let dx = axis_dist(x, boxes[pos], boxes[pos + 2]); - // let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]); - // let dist = dx * dx + dy * dy; - // if dist > max_dist_squared { - // continue; - // } - // } + if node_index >= self.num_items() as usize * 4 { + // node (use even id) + queue.push(Reverse(NeighborNode { + id: index << 1, + dist, + })); + } else { + // leaf item (use odd id) + queue.push(Reverse(NeighborNode { + id: (index << 1) + 1, + dist, + })); + } + } - // // break 'outer; - // } + // pop items from the queue + while !queue.is_empty() && queue.peek().is_some_and(|val| (val.0.id & 1) != 0) { + let dist = queue.peek().unwrap().0.dist; + if dist > max_dist_squared { + break 'outer; + } + let item = queue.pop().unwrap(); + results.push(item.0.id >> 1); + if max_results.is_some_and(|max_results| results.len() == max_results) { + break 'outer; + } + } + + if let Some(item) = queue.pop() { + outer_node_index = Some(item.0.id >> 1); + } else { + outer_node_index = None; + } + } - // results - // } + results + } + + /// Search items in order of distance from the given coordinate. + fn neighbors_coord( + &self, + coord: &impl CoordTrait, + max_results: Option, + max_distance: Option, + ) -> Vec { + self.neighbors(coord.x(), coord.y(), max_results, max_distance) + } /// Returns an iterator over the indexes of objects in this and another tree that intersect. /// @@ -175,6 +240,28 @@ pub trait RTreeIndex: Sized { } } +/// A wrapper around a node and its distance for use in the priority queue. +#[derive(Debug, Clone, Copy, PartialEq)] +struct NeighborNode { + id: usize, + dist: N, +} + +impl Eq for NeighborNode {} + +impl Ord for NeighborNode { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // We don't allow NaN. This should only panic on NaN + self.dist.partial_cmp(&other.dist).unwrap() + } +} + +impl PartialOrd for NeighborNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + impl RTreeIndex for OwnedRTree { fn boxes(&self) -> &[N] { self.metadata.boxes_slice(&self.buffer) diff --git a/src/type.rs b/src/type.rs index f7a8bcc..b8be430 100644 --- a/src/type.rs +++ b/src/type.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use num_traits::{Bounded, Num, NumCast, ToPrimitive}; +use num_traits::{Bounded, Num, NumCast}; use crate::kdtree::constants::KDBUSH_MAGIC; use crate::GeoIndexError; @@ -12,16 +12,7 @@ use crate::GeoIndexError; /// JavaScript ([rtree](https://github.com/mourner/flatbush), /// [kdtree](https://github.com/mourner/kdbush)) pub trait IndexableNum: - private::Sealed - + Num - + NumCast - + ToPrimitive - + PartialOrd - + Debug - + Send - + Sync - + bytemuck::Pod - + Bounded + private::Sealed + Num + NumCast + PartialOrd + Debug + Send + Sync + bytemuck::Pod + Bounded { /// The type index to match the array order of `ARRAY_TYPES` in flatbush JS const TYPE_INDEX: u8;