Return partitions as arrow
kylebarron committed Dec 30, 2024
1 parent: cfece7a, commit: 2566fe1
Showing 6 changed files with 108 additions and 67 deletions.
python/python/geoindex_rs/rtree.pyi (3 changes: 2 additions & 1 deletion)
@@ -4,7 +4,7 @@ import sys
from typing import Literal, Union

import numpy as np
-from arro3.core import Array
+from arro3.core import Array, RecordBatch
from arro3.core.types import ArrowArrayExportable

if sys.version_info > (3, 12):
@@ -26,6 +26,7 @@ def intersection_candidates(
left: IndexLike,
right: IndexLike,
) -> Array: ...
+def partitions(index: IndexLike) -> RecordBatch: ...

class RTreeMetadata:
def __repr__(self) -> str: ...
python/src/lib.rs (1 change: 1 addition & 0 deletions)
@@ -3,6 +3,7 @@
mod coord_type;
mod kdtree;
mod rtree;
+pub(crate) mod util;

use pyo3::exceptions::PyRuntimeWarning;
use pyo3::intern;
python/src/rtree/builder.rs (26 changes: 3 additions & 23 deletions)
@@ -1,9 +1,6 @@
use arrow_array::builder::UInt32Builder;
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, Float64Type};
-use arrow_array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray};
-use arrow_buffer::alloc::Allocation;
-use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
use arrow_cast::cast;
use arrow_schema::DataType;
use geo_index::rtree::sort::{HilbertSort, STRSort};
@@ -15,10 +12,10 @@ use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3_arrow::PyArray;
use std::os::raw::c_int;
-use std::ptr::NonNull;
use std::sync::Arc;

use crate::coord_type::CoordType;
+use crate::util::slice_to_arrow;

#[allow(clippy::upper_case_acronyms)]
pub enum RTreeMethod {
@@ -386,14 +383,14 @@ impl PyRTreeInner {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
-PyArray::from_array_ref(boxes_at_level::<Float32Type>(boxes, index.clone()))
+PyArray::from_array_ref(slice_to_arrow::<Float32Type>(boxes, index.clone()))
.to_arro3(py)
}
Self::Float64(index) => {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
-PyArray::from_array_ref(boxes_at_level::<Float64Type>(boxes, index.clone()))
+PyArray::from_array_ref(slice_to_arrow::<Float64Type>(boxes, index.clone()))
.to_arro3(py)
}
}
@@ -473,20 +470,3 @@ impl PyRTree {
self.0.boxes_at_level(py, level)
}
}
-
-fn boxes_at_level<T: ArrowPrimitiveType>(
-boxes: &[T::Native],
-owner: Arc<dyn Allocation>,
-) -> ArrayRef {
-let ptr = NonNull::new(boxes.as_ptr() as *mut _).unwrap();
-let len = boxes.len();
-let bytes_len = len * T::Native::get_byte_width();
-
-// Safety:
-// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
-let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
-Arc::new(PrimitiveArray::<T>::new(
-ScalarBuffer::new(buffer, 0, len),
-None,
-))
-}
python/src/rtree/partitions.rs (105 changes: 62 additions & 43 deletions)
@@ -1,69 +1,88 @@
use std::sync::Arc;

-use arrow_array::{ArrayRef, UInt32Array};
-use arrow_buffer::ScalarBuffer;
-use arrow_schema::{DataType, Field};
+use arrow_array::builder::{UInt16Builder, UInt32Builder};
+use arrow_array::types::{UInt16Type, UInt32Type};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_buffer::alloc::Allocation;
+use arrow_schema::{Field, Schema};
use geo_index::indices::Indices;
use geo_index::rtree::{RTreeIndex, RTreeRef};
use geo_index::CoordType;
use pyo3::prelude::*;
use pyo3_arrow::buffer::PyArrowBuffer;
-use pyo3_arrow::PyChunkedArray;
+use pyo3_arrow::PyRecordBatch;

+use crate::util::slice_to_arrow;
+
#[pyfunction]
pub fn partitions(py: Python, index: PyArrowBuffer) -> PyResult<PyObject> {
let buffer = index.into_inner();
+let owner = Arc::new(buffer.clone());
let slice = buffer.as_slice();
let coord_type = CoordType::from_buffer(&slice).unwrap();
-let result = match coord_type {
+let (indices, partition_ids) = match coord_type {
CoordType::Float32 => {
let tree = RTreeRef::<f32>::try_new(&slice).unwrap();
-let node_size = tree.node_size();
-match tree.indices() {
-Indices::U16(indices) => indices_to_chunked_array(indices, node_size),
-Indices::U32(indices) => indices_to_chunked_array_u32(indices, node_size),
-}
+let indices = indices_to_arrow(tree.indices(), tree.num_items(), owner);
+let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
+(indices, partition_ids)
}
CoordType::Float64 => {
let tree = RTreeRef::<f64>::try_new(&slice).unwrap();
-let node_size = tree.node_size();
-match tree.indices() {
-Indices::U16(indices) => indices_to_chunked_array(indices, node_size),
-Indices::U32(indices) => indices_to_chunked_array_u32(indices, node_size),
-}
+let indices = indices_to_arrow(tree.indices(), tree.num_items(), owner);
+let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
+(indices, partition_ids)
}
_ => todo!("Only f32 and f64 implemented so far"),
};
-result.to_arro3(py)
+
+let fields = vec![
+Field::new("indices", indices.data_type().clone(), false),
+Field::new("partition_id", partition_ids.data_type().clone(), false),
+];
+let schema = Schema::new(fields);
+PyRecordBatch::new(RecordBatch::try_new(schema.into(), vec![indices, partition_ids]).unwrap())
+.to_arro3(py)
}

-fn indices_to_chunked_array(indices: &[u16], node_size: u16) -> PyChunkedArray {
-let array_chunks = indices
-.chunks(node_size as usize)
-.map(|chunk| {
-Arc::new(UInt32Array::new(
-ScalarBuffer::from(Vec::from_iter(chunk.iter().map(|x| *x as u32))),
-None,
-)) as ArrayRef
-})
-.collect::<Vec<_>>();
-PyChunkedArray::try_new(
-array_chunks,
-Arc::new(Field::new("indices", DataType::UInt32, false)),
-)
-.unwrap()
+fn indices_to_arrow(indices: Indices, num_items: u32, owner: Arc<dyn Allocation>) -> ArrayRef {
+match indices {
+Indices::U16(slice) => slice_to_arrow::<UInt16Type>(&slice[0..num_items as usize], owner),
+Indices::U32(slice) => slice_to_arrow::<UInt32Type>(&slice[0..num_items as usize], owner),
+}
}

-fn indices_to_chunked_array_u32(indices: &[u32], node_size: u16) -> PyChunkedArray {
-let array_chunks = indices
-.chunks(node_size as usize)
-.map(|chunk| {
-Arc::new(UInt32Array::new(ScalarBuffer::from(chunk.to_vec()), None)) as ArrayRef
-})
-.collect::<Vec<_>>();
-PyChunkedArray::try_new(
-array_chunks,
-Arc::new(Field::new("indices", DataType::UInt32, false)),
-)
-.unwrap()
+fn partition_id_array(num_items: u32, node_size: u16) -> ArrayRef {
+let num_full_nodes = num_items / node_size as u32;
+let remainder = num_items % node_size as u32;
+
+// Check if the partition ids fit inside a u16
+// We add 1 to cover the remainder
+if num_full_nodes + 1 < u16::MAX as _ {
+let mut output_array = UInt16Builder::with_capacity(num_items as _);
+
+let mut partition_id = 0;
+for _ in 0..num_full_nodes {
+output_array.append_value_n(partition_id, node_size as usize);
+partition_id += 1;
+}
+
+// The loop omits the last node
+output_array.append_value_n(partition_id, remainder as usize);
+
+Arc::new(output_array.finish())
+} else {
+let mut output_array = UInt32Builder::with_capacity(num_items as _);
+
+let mut partition_id = 0;
+for _ in 0..num_full_nodes {
+output_array.append_value_n(partition_id, node_size as usize);
+partition_id += 1;
+}
+
+// The loop omits the last node
+output_array.append_value_n(partition_id, remainder as usize);
+
+Arc::new(output_array.finish())
+}
}
python/src/util.rs (23 changes: 23 additions & 0 deletions)
@@ -0,0 +1,23 @@
+use std::ptr::NonNull;
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray};
+use arrow_buffer::alloc::Allocation;
+use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
+
+pub(crate) fn slice_to_arrow<T: ArrowPrimitiveType>(
+slice: &[T::Native],
+owner: Arc<dyn Allocation>,
+) -> ArrayRef {
+let ptr = NonNull::new(slice.as_ptr() as *mut _).unwrap();
+let len = slice.len();
+let bytes_len = len * T::Native::get_byte_width();
+
+// Safety:
+// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
+let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
+Arc::new(PrimitiveArray::<T>::new(
+ScalarBuffer::new(buffer, 0, len),
+None,
+))
+}
python/tests/test_rtree.py (17 changes: 17 additions & 0 deletions)
@@ -35,3 +35,20 @@ def test_rtree():
assert np.all(min_y == np_arr[:, 1])
assert np.all(max_x == np_arr[:, 2])
assert np.all(max_y == np_arr[:, 3])
+
+
+def test_partitions():
+builder = rtree.RTreeBuilder(5, 2)
+min_x = np.arange(5)
+min_y = np.arange(5)
+max_x = np.arange(5, 10)
+max_y = np.arange(5, 10)
+builder.add(min_x, min_y, max_x, max_y)
+tree = builder.finish()
+
+partitions = rtree.partitions(tree)
+indices = partitions["indices"]
+partition_id = partitions["partition_id"]
+
+assert np.all(np.asarray(indices) == np.arange(5))
+assert len(np.unique(np.asarray(partition_id))) == 3
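For reference, a minimal Python usage sketch of the API this commit exposes (not part of the diff above), mirroring test_partitions. The "from geoindex_rs import rtree" import path is assumed from the stub location, and the values shown in the comments assume the five-item, node_size=2 tree built in the test.

import numpy as np
from geoindex_rs import rtree

# Build a small tree: five boxes, node size 2, same data as test_partitions.
builder = rtree.RTreeBuilder(5, 2)
min_x = min_y = np.arange(5)
max_x = max_y = np.arange(5, 10)
builder.add(min_x, min_y, max_x, max_y)
tree = builder.finish()

# partitions() now returns an Arrow RecordBatch with two columns:
# "indices" holds each item's original insertion index in the tree's sorted
# order, and "partition_id" labels the chunk of node_size items forming each
# leaf partition.
batch = rtree.partitions(tree)
print(np.asarray(batch["indices"]))       # [0 1 2 3 4]
print(np.asarray(batch["partition_id"]))  # [0 0 1 1 2], i.e. three partitions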
