Skip to content

Commit

Permalink
Add copy parameter to functions returning Arrow views (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jan 6, 2025
1 parent 29684ce commit 9350af4
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 33 deletions.
40 changes: 29 additions & 11 deletions python/python/geoindex_rs/rtree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@ IndexLike = Union[np.ndarray, ArrowArrayExportable, Buffer, RTree]
"""A type alias for accepted input as an RTree.
"""

def boxes_at_level(index: IndexLike, level: int) -> Array:
def boxes_at_level(index: IndexLike, level: int, *, copy: bool = False) -> Array:
"""Access the raw bounding box data contained in the RTree at a given tree level.
Args:
index: the RTree to search.
level: The level of the tree to read from. Level 0 is the _base_ of the tree. Each integer higher is one level higher of the tree.
Other Args:
copy: if True, make a _copy_ of the data from the underlying RTree instead of
viewing it directly. Making a copy can be preferred if you'd like to delete
the index itself to save memory.
Returns:
An Arrow FixedSizeListArray containing the bounding box coordinates.
The returned array is a zero-copy view from Rust. Note that it will keep
the entire index memory alive until the returned array is garbage collected.
If `copy` is `False`, the returned array is a zero-copy view from Rust.
Note that it will keep the entire index memory alive until the returned
array is garbage collected.
"""

def tree_join(
Expand Down Expand Up @@ -106,7 +112,7 @@ def neighbors(
An Arrow array with the insertion indexes of query results.
"""

def partitions(index: IndexLike) -> RecordBatch:
def partitions(index: IndexLike, *, copy=False) -> RecordBatch:
"""Extract the spatial partitions from an RTree.
This can be used to find the sorted groups for spatially partitioning the original
Expand Down Expand Up @@ -144,15 +150,21 @@ def partitions(index: IndexLike) -> RecordBatch:
Args:
index: the RTree to use.
Other Args:
copy: if True, make a _copy_ of the data from the underlying RTree instead of
viewing it directly. Making a copy can be preferred if you'd like to delete
the index itself to save memory.
Returns:
An Arrow `RecordBatch` with two columns: `indices` and `partition_ids`. `indices` refers to the insertion index of each row and `partition_ids` refers to the partition each row belongs to.
The `indices` column is constructed as a zero-copy view on the provided
index. Therefore, the `indices` array will have type `uint16` if the tree
has fewer than 16,384 items; otherwise it will have type `uint32`.
If `copy` is `False`, the `indices` column is constructed as a zero-copy
view on the provided index. Therefore, the `indices` array will have type
`uint16` if the tree has fewer than 16,384 items; otherwise it will have
type `uint32`.
"""

def partition_boxes(index: IndexLike) -> RecordBatch:
def partition_boxes(index: IndexLike, *, copy: bool = False) -> RecordBatch:
"""Extract the geometries of the spatial partitions from an RTree.
In order for these boxes to be zero-copy from Rust, they are returned as a
Expand All @@ -169,12 +181,18 @@ def partition_boxes(index: IndexLike) -> RecordBatch:
Args:
index: the RTree to use.
Other Args:
copy: if True, make a _copy_ of the data from the underlying RTree instead of
viewing it directly. Making a copy can be preferred if you'd like to delete
the index itself to save memory.
Returns:
An Arrow `RecordBatch` with two columns: `boxes` and `partition_ids`. `boxes` stores the box geometry of each partition and `partition_ids` refers to the partition each row belongs to.
The `boxes` column is constructed as a zero-copy view on the internal boxes
data. The `partition_ids` column will be `uint16` type if there are fewer than
65,536 partitions; otherwise it will be `uint32` type.
If `copy` is `False`, the `boxes` column is constructed as a zero-copy view
on the internal boxes data. The `partition_ids` column will be `uint16` type
if there are fewer than 65,536 partitions; otherwise it will be `uint32`
type.
"""

def search(
Expand Down
12 changes: 9 additions & 3 deletions python/src/rtree/boxes_at_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,25 @@ use crate::rtree::input::PyRTreeRef;
use crate::util::boxes_to_arrow;

#[pyfunction]
pub fn boxes_at_level(py: Python, index: PyRTreeRef, level: usize) -> PyResult<PyObject> {
#[pyo3(signature = (index, level, *, copy = false))]
pub fn boxes_at_level(
py: Python,
index: PyRTreeRef,
level: usize,
copy: bool,
) -> PyResult<PyObject> {
let array = match index {
PyRTreeRef::Float32(tree) => {
let boxes = tree
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
boxes_to_arrow::<Float32Type>(boxes, tree.buffer().clone())
boxes_to_arrow::<Float32Type>(boxes, tree.buffer().clone(), copy)
}
PyRTreeRef::Float64(tree) => {
let boxes = tree
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
boxes_to_arrow::<Float64Type>(boxes, tree.buffer().clone())
boxes_to_arrow::<Float64Type>(boxes, tree.buffer().clone(), copy)
}
};
PyArray::from_array_ref(array).to_arro3(py)
Expand Down
37 changes: 29 additions & 8 deletions python/src/rtree/partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@ use crate::rtree::input::PyRTreeRef;
use crate::util::slice_to_arrow;

#[pyfunction]
pub fn partitions(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
#[pyo3(signature = (index, *, copy = false))]
pub fn partitions(py: Python, index: PyRTreeRef, copy: bool) -> PyResult<PyObject> {
let (indices, partition_ids) = match index {
PyRTreeRef::Float32(tree) => {
let indices = indices_to_arrow(tree.indices(), tree.num_items(), tree.buffer().clone());
let indices = indices_to_arrow(
tree.indices(),
tree.num_items(),
tree.buffer().clone(),
copy,
);
let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
(indices, partition_ids)
}
PyRTreeRef::Float64(tree) => {
let indices = indices_to_arrow(tree.indices(), tree.num_items(), tree.buffer().clone());
let indices = indices_to_arrow(
tree.indices(),
tree.num_items(),
tree.buffer().clone(),
copy,
);
let partition_ids = partition_id_array(tree.num_items(), tree.node_size());
(indices, partition_ids)
}
Expand All @@ -38,10 +49,19 @@ pub fn partitions(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
.to_arro3(py)
}

fn indices_to_arrow(indices: Indices, num_items: u32, owner: Arc<dyn Allocation>) -> ArrayRef {
fn indices_to_arrow(
indices: Indices,
num_items: u32,
owner: Arc<dyn Allocation>,
copy: bool,
) -> ArrayRef {
match indices {
Indices::U16(slice) => slice_to_arrow::<UInt16Type>(&slice[0..num_items as usize], owner),
Indices::U32(slice) => slice_to_arrow::<UInt32Type>(&slice[0..num_items as usize], owner),
Indices::U16(slice) => {
slice_to_arrow::<UInt16Type>(&slice[0..num_items as usize], owner, copy)
}
Indices::U32(slice) => {
slice_to_arrow::<UInt32Type>(&slice[0..num_items as usize], owner, copy)
}
}
}

Expand Down Expand Up @@ -83,8 +103,9 @@ fn partition_id_array(num_items: u32, node_size: u16) -> ArrayRef {
// Since for now we assume that the partition level is the node level, we select the boxes at level
// 1.
#[pyfunction]
pub fn partition_boxes(py: Python, index: PyRTreeRef) -> PyResult<PyObject> {
let array = boxes_at_level(py, index, 1)?.extract::<PyArray>(py)?;
#[pyo3(signature = (index, *, copy = false))]
pub fn partition_boxes(py: Python, index: PyRTreeRef, copy: bool) -> PyResult<PyObject> {
let array = boxes_at_level(py, index, 1, copy)?.extract::<PyArray>(py)?;
let (array, _field) = array.into_inner();

let partition_ids: ArrayRef = if array.len() < u16::MAX as _ {
Expand Down
31 changes: 20 additions & 11 deletions python/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,34 @@ use arrow_schema::Field;
pub(crate) fn slice_to_arrow<T: ArrowPrimitiveType>(
slice: &[T::Native],
owner: Arc<dyn Allocation>,
copy: bool,
) -> ArrayRef {
let ptr = NonNull::new(slice.as_ptr() as *mut _).unwrap();
let len = slice.len();
let bytes_len = len * T::Native::get_byte_width();
if copy {
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(slice.to_vec()),
None,
))
} else {
let ptr = NonNull::new(slice.as_ptr() as *mut _).unwrap();
let len = slice.len();
let bytes_len = len * T::Native::get_byte_width();

// Safety:
// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::new(buffer, 0, len),
None,
))
// Safety:
// ptr is a non-null pointer owned by the RTree, which is passed in as the Allocation
let buffer = unsafe { Buffer::from_custom_allocation(ptr, bytes_len, owner) };
Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::new(buffer, 0, len),
None,
))
}
}

pub(crate) fn boxes_to_arrow<T: ArrowPrimitiveType>(
slice: &[T::Native],
owner: Arc<dyn Allocation>,
copy: bool,
) -> ArrayRef {
let values_array = slice_to_arrow::<T>(slice, owner);
let values_array = slice_to_arrow::<T>(slice, owner, copy);
Arc::new(FixedSizeListArray::new(
Arc::new(Field::new("item", values_array.data_type().clone(), false)),
4,
Expand Down

0 comments on commit 9350af4

Please sign in to comment.