Skip to content

Commit

Permalink
Fix python test and buffer protocol type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Dec 30, 2024
1 parent 219f7bd commit 6c67ee2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
11 changes: 9 additions & 2 deletions python/python/geoindex_rs/kdtree.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations

import sys
from typing import Literal, Union

import numpy as np
from arro3.core.types import ArrowArrayExportable
from arro3.core import Array
from arro3.core.types import ArrowArrayExportable

if sys.version_info > (3, 12):
from collections.abc import Buffer
else:
from typing_extensions import Buffer

ArrayLike = Union[np.ndarray, ArrowArrayExportable, memoryview, bytes]
IndexLike = Union[np.ndarray, ArrowArrayExportable, memoryview, bytes, KDTree]
Expand Down Expand Up @@ -32,4 +39,4 @@ class KDTreeBuilder:
def add(self, x: ArrayLike, y: ArrayLike | None = None) -> Array: ...
def finish(self) -> KDTree: ...

class KDTree: ...
class KDTree(Buffer): ...
11 changes: 9 additions & 2 deletions python/python/geoindex_rs/rtree.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations

import sys
from typing import Literal, Union

import numpy as np
from arro3.core.types import ArrowArrayExportable
from arro3.core import Array
from arro3.core.types import ArrowArrayExportable

if sys.version_info > (3, 12):
from collections.abc import Buffer
else:
from typing_extensions import Buffer

ArrayLike = Union[np.ndarray, ArrowArrayExportable, memoryview, bytes]
IndexLike = Union[np.ndarray, ArrowArrayExportable, memoryview, bytes, RTree]
Expand Down Expand Up @@ -36,7 +43,7 @@ class RTreeBuilder:
) -> Array: ...
def finish(self, method: Literal["hilbert", "str", None] = None) -> RTree: ...

class RTree:
class RTree(Buffer):
@property
def num_items(self) -> int: ...
@property
Expand Down
7 changes: 4 additions & 3 deletions python/tests/test_buffers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

# from .. import RTree
from geoindex_rs import rtree


def generate_random_boxes():
Expand All @@ -11,7 +10,9 @@ def generate_random_boxes():

def test_buffer_protocol():
boxes = generate_random_boxes()
initial = RTree.from_interleaved(boxes)
builder = rtree.RTreeBuilder(len(boxes))
builder.add(boxes)
initial = builder.finish()
# construct a memoryview transparently
view = memoryview(initial)
assert initial.num_bytes == view.nbytes

0 comments on commit 6c67ee2

Please sign in to comment.