From 81d5cda1be07ac034017eb910a909b424954d31c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Dec 2024 18:13:26 +0100 Subject: [PATCH 01/12] feat[next]: allocator in field --- src/gt4py/next/allocators.py | 51 ++++++++++++++++------ src/gt4py/next/common.py | 2 + src/gt4py/next/constructors.py | 52 ++++++++++++++++++++--- src/gt4py/next/embedded/nd_array_field.py | 11 +++-- 4 files changed, 94 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 864f8c1b09..63911d3308 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -60,7 +60,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: ... + ) -> core_defs.NDArrayObject: ... def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -160,7 +160,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_defs.NDArrayObject: shape = domain.shape layout_map = self.layout_mapper(domain.dims) # TODO(egparedes): add support for non-empty aligned index values @@ -168,7 +168,7 @@ def __gt_allocate__( return self.buffer_allocator.allocate( shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index - ) + ).ndarray if TYPE_CHECKING: @@ -242,7 +242,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_defs.NDArrayObject: raise self.exception @@ -292,16 +292,28 @@ def __init__(self) -> None: ) -def allocate( - domain: common.DomainLike, +class ConcreteAllocator(Protocol): + def __call__( + domain: common.DomainLike, + dtype: core_defs.DType[core_defs.ScalarT], + *, + aligned_index: Optional[Sequence[common.NamedIndex]], + allocator: FieldBufferAllocationUtil, + device: core_defs.Device, + ) -> core_defs.NDArrayObject: ... + + +def make_concrete_allocator( + domain: common.DomainLike, # TODO: there is an inconsistency between DomainLike and concrete DType, probably accept either (Domain, DType) or (DomainLike, DTypeLike). anyway this is not meant to be user-facing dtype: core_defs.DType[core_defs.ScalarT], *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[FieldBufferAllocationUtil] = None, device: Optional[core_defs.Device] = None, -) -> core_allocators.TensorBuffer: +) -> ConcreteAllocator: """ - Allocate a TensorBuffer for the given domain and device or allocator. + TODO: docstring + Allocate an NDArrayObject for the given domain and device or allocator. The arguments `device` and `allocator` are mutually exclusive. If `device` is specified, the corresponding default allocator @@ -334,9 +346,20 @@ def allocate( elif device.device_type != actual_allocator.__gt_device_type__: raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") - return actual_allocator.__gt_allocate__( - domain=common.domain(domain), - dtype=dtype, - device_id=device.device_id, - aligned_index=aligned_index, - ) + def allocate( + domain: common.DomainLike = domain, + dtype: core_defs.DType[core_defs.ScalarT] = dtype, + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = aligned_index, + allocator: FieldBufferAllocationUtil = actual_allocator, + device: core_defs.Device = device, + ) -> core_defs.NDArrayObject: + # TODO check how to get from FieldBufferAllocationUtil to FieldBufferAllocatorProtocol + return allocator.__gt_allocate__( + domain=common.domain(domain), + dtype=dtype, + device_id=device.device_id, + aligned_index=aligned_index, + ) + + return allocate diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 9b2870e1c0..23fe6c9e0c 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -950,6 +950,7 @@ def _field( /, *, domain: Optional[DomainLike] = None, + allocator: Optional[Any] = None, # TODO: resolve the type annotation dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -963,6 +964,7 @@ def _connectivity( codomain: Dimension, *, domain: Optional[DomainLike] = None, + allocator: Optional[Any] = None, # TODO: resolve the type annotation dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, ) -> Connectivity: diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 7b39511674..e1c063a30f 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -11,6 +11,8 @@ from collections.abc import Mapping, Sequence from typing import Optional, cast +from typing_extensions import NotRequired, TypedDict + import gt4py._core.definitions as core_defs import gt4py.eve as eve import gt4py.eve.extended_typing as xtyping @@ -77,10 +79,11 @@ def empty( dtype = core_defs.dtype(dtype) if allocator is None and device is None: device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) - buffer = next_allocators.allocate( + allocate = next_allocators.make_concrete_allocator( domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device ) - res = common._field(buffer.ndarray, domain=domain) + buffer = allocate() + res = common._field(buffer, domain=domain, allocator=allocate) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) return res @@ -349,12 +352,51 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) + allocate = next_allocators.make_concrete_allocator( + actual_domain, dtype, allocator=allocator, device=device + ) + buffer = allocate() # TODO(havogt): consider adding MutableNDArrayObject - buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] + buffer[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common._connectivity( - buffer.ndarray, codomain=codomain, domain=actual_domain, skip_value=skip_value + buffer, codomain=codomain, domain=actual_domain, skip_value=skip_value, allocator=allocate ) assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) return connectivity_field + + +_like_field = None # for more descriptive function signature in editors + + +class AllocatorParams(TypedDict): + domain: NotRequired[common.DomainLike] + dtype: NotRequired[core_defs.DType[core_defs.ScalarT],] + aligned_index: NotRequired[Sequence[common.NamedIndex]] + allocator: NotRequired[next_allocators.FieldBufferAllocatorProtocol] + device: NotRequired[core_defs.Device] + + +def empty_like( + field: nd_array_field.NdArrayField, + *, + domain: Optional[common.DomainLike] = _like_field, + dtype: Optional[core_defs.DTypeLike] = _like_field, + aligned_index: Optional[Sequence[common.NamedIndex]] = _like_field, + allocator: Optional[next_allocators.FieldBufferAllocationUtil] = _like_field, + device: Optional[core_defs.Device] = _like_field, +) -> nd_array_field.NdArrayField: + kwargs: AllocatorParams = {} + if domain is not None: + kwargs["domain"] = domain + if dtype is not None: + kwargs["dtype"] = core_defs.dtype(dtype) + if aligned_index is not None: + kwargs["aligned_index"] = aligned_index + if allocator is not eve.NOTHING: + kwargs["allocator"] = allocator + if device is not eve.NOTHING: + kwargs["device"] = device + if field._allocator is None: + raise ValueError("'Field' does not have an allocator.") # TODO discuss if this is possible + return field._allocator(**kwargs) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e15fb4266a..3181fe808f 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -29,7 +29,7 @@ TypeVar, cast, ) -from gt4py.next import common +from gt4py.next import allocators, common from gt4py.next.embedded import ( common as embedded_common, context as embedded_context, @@ -116,6 +116,7 @@ class NdArrayField( _domain: common.Domain _ndarray: core_defs.NDArrayObject + _allocator: Optional[allocators.ConcreteAllocator] array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol @@ -167,6 +168,9 @@ def from_array( /, *, domain: common.DomainLike, + allocator: Optional[ + allocators.ConcreteAllocator + ] = None, # TODO: maybe an NDArrayField always has an allocator? dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: domain = common.domain(domain) @@ -184,7 +188,7 @@ def from_array( assert len(domain) == array.ndim assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) - return cls(domain, array) + return cls(domain, array, allocator) def premap( self: NdArrayField, @@ -513,6 +517,7 @@ def from_array( # type: ignore[override] codomain: common.DimT, *, domain: common.DomainLike, + allocator: Optional[allocators.ConcreteAllocator] = None, dtype: Optional[core_defs.DTypeLike] = None, skip_value: Optional[core_defs.IntegralScalar] = None, ) -> NdArrayConnectivityField: @@ -533,7 +538,7 @@ def from_array( # type: ignore[override] assert isinstance(codomain, common.Dimension) - return cls(domain, array, codomain, _skip_value=skip_value) + return cls(domain, array, allocator, codomain, _skip_value=skip_value) def inverse_image(self, image_range: common.UnitRange | common.NamedRange) -> common.Domain: cache_key = hash((id(self.ndarray), self.domain, image_range)) From fe09c5f0e0bf5de1aa4debab21ad2e0ee81542ce Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 9 Dec 2024 08:58:58 +0100 Subject: [PATCH 02/12] continue refactoring --- src/gt4py/_core/definitions.py | 11 +++ src/gt4py/next/allocators.py | 2 +- src/gt4py/next/constructors.py | 111 ++++++++++++++++++++++- src/gt4py/storage/allocators.py | 111 +---------------------- src/gt4py/storage/cartesian/interface.py | 4 +- 5 files changed, 125 insertions(+), 114 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 8f62788b8f..238ec9d266 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -505,3 +505,14 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + + +class ArrayApiNamespace(Protocol): + @property + def __array_api_version__(self) -> str: ... + + # TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697 + + +def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]: + return hasattr(obj, "__array_api_version__") diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 63911d3308..abc9c5f4ee 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -168,7 +168,7 @@ def __gt_allocate__( return self.buffer_allocator.allocate( shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index - ).ndarray + ) if TYPE_CHECKING: diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index e1c063a30f..acf35c4ac2 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -9,7 +9,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Optional, cast +from typing import Any, Callable, Optional, Protocol, TypeGuard, cast from typing_extensions import NotRequired, TypedDict @@ -22,15 +22,67 @@ import gt4py.storage.cartesian.utils as storage_utils +class _HasArrayApiCreationFunctions(Protocol): + def empty(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... + def zeros(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... + def ones(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... + def full(self, shape: Sequence[int], fill_value, *, dtype=None, device=None) -> Any: ... + def asarray(self, obj, *, dtype=None, copy=None) -> Any: ... + + +def _has_array_api_creation_functions(obj: Any) -> TypeGuard[_HasArrayApiCreationFunctions]: + return core_defs.is_array_api_namespace(obj) or ( + hasattr(obj, "emtpy") + and hasattr(obj, "zeros") + and hasattr(obj, "ones") + and hasattr(obj, "full") + and hasattr(obj, "asarray") + ) + + +def _array_api_construction( + fun: Callable, + *args, + domain: common.Domain, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + **kwargs: Any, +): + buffer = fun(*args, **kwargs) # TODO add converted dtype and device + + # xp = array_api_compat.array_namespace( + # buffer + # ) # TODO(havogt): replace by buffer.__array_namespace__ once all libraries support that + xp = None # TODO + + def allocate( + domain: common.DomainLike = domain, + dtype: core_defs.DTypeLike = dtype, + *, + aligned_index: Optional[Sequence[common.NamedIndex]], + allocator: next_allocators.FieldBufferAllocationUtil | core_defs.ArrayApiNamespace = xp, + device: core_defs.Device = device, + ) -> core_defs.NDArrayObject: + # always returns an empty buffer by design + return empty( + domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device + ) + + return common._field(buffer, domain=domain, allocator=allocate) + + @eve.utils.with_fluid_partial def empty( domain: common.DomainLike, dtype: core_defs.DTypeLike = core_defs.Float64DType(()), # noqa: B008 [function-call-in-default-argument] *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + allocator: Optional[ + next_allocators.FieldBufferAllocationUtil | core_defs.ArrayApiNamespace + ] = None, # TODO make sure numpy/cupy etc namespaces are accepted by mypy (maybe we have to allow Any) device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: + # TODO: update doc """Create a `Field` of uninitialized (undefined) values using the given (or device-default) allocator. This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. @@ -76,11 +128,25 @@ def empty( >>> b.shape (3, 3) """ - dtype = core_defs.dtype(dtype) + if _has_array_api_creation_functions(allocator): + domain = common.domain(domain) + return _array_api_construction( + allocator.empty, + domain.shape, + domain=domain, + dtype=dtype, + device=device, + ) + if allocator is None and device is None: device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) + allocate = next_allocators.make_concrete_allocator( - domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device + domain, + dtype, # TODO resolve dtype-like inside! + aligned_index=aligned_index, + allocator=allocator, + device=device, ) buffer = allocate() res = common._field(buffer, domain=domain, allocator=allocate) @@ -109,6 +175,15 @@ def zeros( >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ + if _has_array_api_creation_functions(allocator): + domain = common.domain(domain) + return _array_api_construction( + allocator.zeros, + domain.shape, + domain=domain, + dtype=dtype, + device=device, + ) field = empty( domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) @@ -136,6 +211,15 @@ def ones( >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ + if _has_array_api_creation_functions(allocator): + domain = common.domain(domain) + return _array_api_construction( + allocator.ones, + domain.shape, + domain=domain, + dtype=dtype, + device=device, + ) field = empty( domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device ) @@ -169,6 +253,16 @@ def full( >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ + if _has_array_api_creation_functions(allocator): + domain = common.domain(domain) + return _array_api_construction( + allocator.full, + domain.shape, + fill_value, + domain=domain, + dtype=dtype, + device=device, + ) field = empty( domain=domain, dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), @@ -235,6 +329,15 @@ def as_field( >>> gtx.as_field({IDim: range(-1, 2)}, xdata).domain.ranges[0] UnitRange(-1, 2) """ + if _has_array_api_creation_functions(allocator): + return _array_api_construction( + allocator.asarray, + data, + domain=common.domain(domain), + dtype=dtype, + device=device, + # copy=copy + ) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 298b9c2e5a..9290402b2c 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -17,12 +17,10 @@ import types import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import ( - TYPE_CHECKING, Any, Callable, Generic, @@ -31,7 +29,6 @@ Protocol, Sequence, Tuple, - Type, TypeAlias, TypeGuard, Union, @@ -64,88 +61,6 @@ def is_valid_layout_map(value: Sequence[Any]) -> TypeGuard[BufferLayoutMap]: ) -@dataclasses.dataclass(frozen=True) -class TensorBuffer(Generic[core_defs.DeviceTypeT, core_defs.ScalarT]): - """ - N-dimensional (tensor-like) memory buffer. - - The actual class of the stored buffer and ndarray instances is - represented in the `NDBufferT` parameter and might be any n-dimensional - buffer-like class with a compatible buffer interface (e.g. NumPy - or CuPy `ndarray`.) - - Attributes: - buffer: Raw allocated buffer. - memory_address: Memory address of the buffer. - device: Device where the buffer is allocated. - dtype: Data type descriptor. - shape: Tuple with lengths of the corresponding tensor dimensions. - strides: Tuple with sizes (in bytes) of the steps in each dimension. - layout_map: Tuple with the order of the dimensions in the buffer - layout_map[i] = j means that the i-th dimension of the tensor - corresponds to the j-th dimension in the (C-layout) buffer. - byte_offset: Offset (in bytes) from the beginning of the buffer to - the first valid element. - byte_alignment: Alignment (in bytes) of the first valid element. - aligned_index: N-dimensional index of the first aligned element. - ndarray: N-dimensional tensor view of the allocated buffer. - """ - - buffer: _NDBuffer = dataclasses.field(hash=False) - memory_address: int - device: core_defs.Device[core_defs.DeviceTypeT] - dtype: core_defs.DType[core_defs.ScalarT] - shape: core_defs.TensorShape - strides: Tuple[int, ...] - layout_map: BufferLayoutMap - byte_offset: int - byte_alignment: int - aligned_index: Tuple[int, ...] - ndarray: core_defs.NDArrayObject = dataclasses.field(hash=False) - - @property - def ndim(self): - """Order of the tensor (`len(tensor_buffer.shape)`).""" - return len(self.shape) - - def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> np.ndarray: - if not xtyping.supports_array(self.ndarray): - raise TypeError("Cannot export tensor buffer as NumPy array.") - - return self.ndarray.__array__(dtype) - - @property - def __array_interface__(self) -> dict[str, Any]: - if not xtyping.supports_array_interface(self.ndarray): - raise TypeError("Cannot export tensor buffer to NumPy array interface.") - - return self.ndarray.__array_interface__ - - @property - def __cuda_array_interface__(self) -> dict[str, Any]: - if not xtyping.supports_cuda_array_interface(self.ndarray): - raise TypeError("Cannot export tensor buffer to CUDA array interface.") - - return self.ndarray.__cuda_array_interface__ - - def __dlpack__(self, *, stream: Optional[int] = None) -> Any: - if not hasattr(self.ndarray, "__dlpack__"): - raise TypeError("Cannot export tensor buffer to DLPack.") - return self.ndarray.__dlpack__(stream=stream) # type: ignore[call-arg,arg-type] # stream is not always supported - - def __dlpack_device__(self) -> xtyping.DLPackDevice: - if not hasattr(self.ndarray, "__dlpack_device__"): - raise TypeError("Cannot extract DLPack device from tensor buffer.") - return self.ndarray.__dlpack_device__() - - -if TYPE_CHECKING: - # TensorBuffer should be compatible with all the expected buffer interfaces - __TensorBufferAsArrayInterfaceT: Type[xtyping.ArrayInterface] = TensorBuffer - __TensorBufferAsCUDAArrayInterfaceT: Type[xtyping.CUDAArrayInterface] = TensorBuffer - __TensorBufferAsDLPackBufferT: Type[xtyping.DLPackBuffer] = TensorBuffer - - class BufferAllocator(Protocol[core_defs.DeviceTypeT]): """Protocol for buffer allocators.""" @@ -160,9 +75,9 @@ def allocate( layout_map: BufferLayoutMap, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_defs.NDArrayObject: """ - Allocate a TensorBuffer with the given shape, layout and alignment settings. + Allocate an NDArrayObject with the given shape, layout and alignment settings. Args: shape: Tensor dimensions. @@ -194,7 +109,7 @@ def allocate( layout_map: BufferLayoutMap, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_defs.NDArrayObject: if not core_defs.is_valid_tensor_shape(shape): raise ValueError(f"Invalid shape {shape}") ndim = len(shape) @@ -254,24 +169,7 @@ def allocate( ) % byte_alignment byte_offset = (aligned_index_offset + allocation_mismatch_offset) % byte_alignment - # Create shaped view from buffer - ndarray = self.tensorize( - buffer, dtype, shape, padded_shape, item_size, strides, byte_offset - ) - - return TensorBuffer( - buffer=buffer, - memory_address=memory_address, - device=core_defs.Device(self.device_type, device_id), - dtype=dtype, - shape=shape, - strides=strides, - layout_map=layout_map, - byte_offset=byte_offset, - byte_alignment=byte_alignment, - aligned_index=aligned_index, - ndarray=ndarray, - ) + return self.tensorize(buffer, dtype, shape, padded_shape, item_size, strides, byte_offset) @property @abc.abstractmethod @@ -293,6 +191,7 @@ def tensorize( strides: Sequence[int], byte_offset: int, ) -> core_defs.NDArrayObject: + """Create shaped view from buffer.""" pass diff --git a/src/gt4py/storage/cartesian/interface.py b/src/gt4py/storage/cartesian/interface.py index 8b38bcdd42..01f884a555 100644 --- a/src/gt4py/storage/cartesian/interface.py +++ b/src/gt4py/storage/cartesian/interface.py @@ -97,9 +97,7 @@ def empty( assert allocators.is_valid_layout_map(layout_map) dtype = np.dtype(dtype) - _, res = allocate_f(shape, layout_map, dtype, alignment * dtype.itemsize, aligned_index) - - return res + return allocate_f(shape, layout_map, dtype, alignment * dtype.itemsize, aligned_index) def ones( From e4dabc89c6747441837bdb7094e032419599df76 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 10 Dec 2024 17:50:21 +0100 Subject: [PATCH 03/12] save state --- pyproject.toml | 5 +- src/gt4py/next/constructors.py | 55 ++++---- src/gt4py/next/embedded/nd_array_field.py | 43 +++++- src/gt4py/storage/cartesian/utils.py | 14 +- tests/next_tests/definitions.py | 11 ++ .../ffront_tests/ffront_test_utils.py | 7 +- .../next_tests/unit_tests/test_allocators.py | 14 +- .../unit_tests/test_constructors.py | 124 ++++++++++++------ 8 files changed, 182 insertions(+), 91 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e859c9b4f7..3f80dac5d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,8 +237,9 @@ module = 'gt4py.next.iterator.runtime' [tool.pytest.ini_options] markers = [ 'all: special marker that skips all tests', - 'requires_atlas: tests that require `atlas4py` bindings package', - 'requires_dace: tests that require `dace` package', + 'requires_atlas: tests that require the `atlas4py` bindings package', + 'requires_dace: tests that require the `dace` package', + 'requires_jax: tests that require the `jax` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index acf35c4ac2..7ac800ab97 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -32,7 +32,7 @@ def asarray(self, obj, *, dtype=None, copy=None) -> Any: ... def _has_array_api_creation_functions(obj: Any) -> TypeGuard[_HasArrayApiCreationFunctions]: return core_defs.is_array_api_namespace(obj) or ( - hasattr(obj, "emtpy") + hasattr(obj, "empty") and hasattr(obj, "zeros") and hasattr(obj, "ones") and hasattr(obj, "full") @@ -40,20 +40,26 @@ def _has_array_api_creation_functions(obj: Any) -> TypeGuard[_HasArrayApiCreatio ) +def _convert_dtype( + xp: Any, dtype: core_defs.DTypeLike +) -> Any: # TODO move to core_defs as `to_array_api_dtype` + if dtype is None: + return None + return getattr(xp, core_defs.dtype(dtype).scalar_type.__name__) + + def _array_api_construction( - fun: Callable, + xp: _HasArrayApiCreationFunctions, + fun: str, *args, domain: common.Domain, dtype: Optional[core_defs.DTypeLike] = None, device: Optional[core_defs.Device] = None, **kwargs: Any, ): - buffer = fun(*args, **kwargs) # TODO add converted dtype and device - - # xp = array_api_compat.array_namespace( - # buffer - # ) # TODO(havogt): replace by buffer.__array_namespace__ once all libraries support that - xp = None # TODO + if device is not None: + raise NotImplementedError("Device specification is not yet supported.") + buffer = getattr(xp, fun)(*args, dtype=_convert_dtype(xp, dtype), **kwargs) def allocate( domain: common.DomainLike = domain, @@ -131,7 +137,8 @@ def empty( if _has_array_api_creation_functions(allocator): domain = common.domain(domain) return _array_api_construction( - allocator.empty, + allocator, + "empty", domain.shape, domain=domain, dtype=dtype, @@ -143,7 +150,7 @@ def empty( allocate = next_allocators.make_concrete_allocator( domain, - dtype, # TODO resolve dtype-like inside! + core_defs.dtype(dtype), aligned_index=aligned_index, allocator=allocator, device=device, @@ -178,7 +185,8 @@ def zeros( if _has_array_api_creation_functions(allocator): domain = common.domain(domain) return _array_api_construction( - allocator.zeros, + allocator, + "zeros", domain.shape, domain=domain, dtype=dtype, @@ -214,7 +222,8 @@ def ones( if _has_array_api_creation_functions(allocator): domain = common.domain(domain) return _array_api_construction( - allocator.ones, + allocator, + "ones", domain.shape, domain=domain, dtype=dtype, @@ -256,7 +265,8 @@ def full( if _has_array_api_creation_functions(allocator): domain = common.domain(domain) return _array_api_construction( - allocator.full, + allocator, + "full", domain.shape, fill_value, domain=domain, @@ -329,15 +339,6 @@ def as_field( >>> gtx.as_field({IDim: range(-1, 2)}, xdata).domain.ranges[0] UnitRange(-1, 2) """ - if _has_array_api_creation_functions(allocator): - return _array_api_construction( - allocator.asarray, - data, - domain=common.domain(domain), - dtype=dtype, - device=device, - # copy=copy - ) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: @@ -360,6 +361,16 @@ def as_field( if origin: raise ValueError(f"Cannot specify origin for domain {domain}") actual_domain = common.domain(cast(common.DomainLike, domain)) + if _has_array_api_creation_functions(allocator): + return _array_api_construction( + allocator, + "asarray", + data, + domain=actual_domain, + dtype=dtype, + device=device, + # copy=copy + ) # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has # already the correct layout and device. diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 3181fe808f..417cf1f50a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -334,7 +334,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayField: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] new_buffer = self.__class__.array_ns.asarray(new_buffer) - return self.__class__.from_array(new_buffer, domain=new_domain) + return self.__class__.from_array(new_buffer, domain=new_domain, allocator=self._allocator) __getitem__ = restrict @@ -432,6 +432,17 @@ def _slice( assert common.is_relative_index_sequence(slice_) return new_domain, slice_ + def __copy__(self) -> Never: + # TODO does this make sense? + raise NotImplementedError( + "`NdArrayField` is frozen, shallow copying is not a useful operation. Did you want to deepcopy?" + ) + + def __deepcopy__(self, _: Any) -> NdArrayField: + ndarray_copy = self._allocator() + ndarray_copy[:] = self.ndarray[:] + return self.__class__(self.domain, ndarray_copy, _allocator=self._allocator) + if dace: # Extension of NdArrayField adding SDFGConvertible support in GT4Py Programs def _dace_data_ptr(self) -> int: @@ -575,7 +586,13 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: xp = cls.array_ns new_domain, buffer_slice = self._slice(index) new_buffer = xp.asarray(self.ndarray[buffer_slice]) - restricted_connectivity = cls(new_domain, new_buffer, self.codomain, self.skip_value) + restricted_connectivity = cls( + new_domain, + new_buffer, + _allocator=self._allocator, + _codomain=self._codomain, + _skip_value=self._skip_value, + ) self._cache[cache_key] = restricted_connectivity return restricted_connectivity @@ -599,7 +616,9 @@ def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> new_ranges = connectivity.inverse_image(current_range) new_domain = new_domain.replace(dim_idx, *new_ranges) - return data.__class__.from_array(data._ndarray, domain=new_domain, dtype=data.dtype) + return data.__class__.from_array( + data._ndarray, domain=new_domain, dtype=data.dtype, allocator=data._allocator + ) def _reshuffling_premap( @@ -639,7 +658,7 @@ def _reshuffling_premap( conn_ndarray = xp.broadcast_to(conn_ndarray, data.domain.shape) if conn_ndarray is not conn.ndarray: conn = conn.__class__.from_array( - conn_ndarray, domain=data.domain, codomain=conn.codomain + conn_ndarray, domain=data.domain, codomain=conn.codomain, allocator=conn._allocator ) conn_map[conn.codomain] = conn dim_idx = data.domain.dim_index(conn.codomain, allow_missing=False) @@ -669,6 +688,7 @@ def _reshuffling_premap( new_buffer, domain=new_domain, dtype=data.dtype, + allocator=data._allocator, ) @@ -709,6 +729,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> new_buffer, domain=new_domain, dtype=data.dtype, + allocator=data._allocator, ) @@ -877,6 +898,7 @@ def _intersect_fields( nd_array_class.from_array( f.ndarray[_get_slices_from_domain_slice(f.domain, intersected_domain)], domain=intersected_domain, + # TODO allocator ) for f, intersected_domain in zip(broadcasted_fields, intersected_domains, strict=True) ) @@ -913,6 +935,7 @@ def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], axis=new_domain.dim_index(dim, allow_missing=False), ), + # TODO allocator domain=new_domain, ) @@ -1081,8 +1104,16 @@ def __setitem__( index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") + target_domain, target_slice = self._slice(index) + if isinstance(value, NdArrayField): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self._ndarray, "at") + object.__setattr__(self, "_ndarray", self._ndarray.at[target_slice].set(value)) common._field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 50500e536b..9e78ffd97a 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -248,7 +248,7 @@ def allocate_cpu( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], -) -> Tuple[allocators._NDBuffer, np.ndarray]: +) -> np.ndarray: device = core_defs.Device(core_defs.DeviceType.CPU, 0) buffer = _CPUBufferAllocator.allocate( shape, @@ -258,7 +258,7 @@ def allocate_cpu( byte_alignment=alignment_bytes, aligned_index=aligned_index, ) - return buffer.buffer, cast(np.ndarray, buffer.ndarray) + return cast(np.ndarray, buffer) def _allocate_gpu( @@ -283,9 +283,7 @@ def _allocate_gpu( aligned_index=aligned_index, ) - buffer_ndarray = cast("cp.ndarray", buffer.ndarray) - - return buffer.buffer, buffer_ndarray + return cast("cp.ndarray", buffer.ndarray) allocate_gpu = _allocate_gpu @@ -321,8 +319,8 @@ def _allocate_gpu_rocm( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], - ) -> Tuple["cp.ndarray", "cp.ndarray"]: - buffer, ndarray = _allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) - return buffer, CUDAArrayInterfaceNDArray(ndarray) + ) -> "cp.ndarray": + ndarray = _allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) + return CUDAArrayInterfaceNDArray(ndarray) allocate_gpu = _allocate_gpu_rocm diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index d7413f32d7..4ec5367522 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -21,6 +21,14 @@ XFAIL = pytest.xfail SKIP = pytest.skip +try: + import jax + import jax.numpy as jnp + + jax.config.update("jax_enable_x64", True) +except ImportError: + jnp = None + # Program processors class _PythonObjectIdMixin: @@ -58,11 +66,13 @@ class EmbeddedDummyBackend: numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator()) +jax_execution = EmbeddedDummyBackend(jnp) class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): NUMPY_EXECUTION = "next_tests.definitions.numpy_execution" CUPY_EXECUTION = "next_tests.definitions.cupy_execution" + JAX_EXECUTION = "next_tests.definitions.jax_execution" class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @@ -169,6 +179,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, + EmbeddedIds.JAX_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 1147f4bc3e..139df658a1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -7,14 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import types -from typing import Any, Protocol, TypeVar +from typing import Protocol, TypeVar import numpy as np import pytest import gt4py.next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import backend as next_backend, common, allocators as next_allocators +from gt4py.next import allocators as next_allocators, backend as next_backend, common from gt4py.next.ffront import decorator import next_tests @@ -57,6 +57,9 @@ def __gt_allocator__( pytest.param( next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu ), + pytest.param( + next_tests.definitions.EmbeddedIds.JAX_EXECUTION, marks=pytest.mark.requires_jax + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index d3001779e1..c8f395b29b 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -26,7 +26,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_defs.NDArrayObject: pass @@ -152,7 +152,7 @@ def test_allocate(self): def test_allocate(): - from gt4py.next.allocators import StandardCPUFieldBufferAllocator, allocate + from gt4py.next.allocators import StandardCPUFieldBufferAllocator, make_concrete_allocator I = common.Dimension("I") J = common.Dimension("J") @@ -161,21 +161,19 @@ def test_allocate(): # Test with a explicit field allocator allocator = StandardCPUFieldBufferAllocator() - tensor_buffer = allocate(domain, dtype, allocator=allocator) + tensor_buffer = make_concrete_allocator(domain, dtype, allocator=allocator)() assert tensor_buffer.shape == domain.shape assert tensor_buffer.dtype == dtype - assert tensor_buffer.device == core_defs.Device(core_defs.DeviceType.CPU, 0) # Test with a device device = core_defs.Device(core_defs.DeviceType.CPU, 0) - tensor_buffer = allocate(domain, dtype, device=device) + tensor_buffer = make_concrete_allocator(domain, dtype, device=device)() assert tensor_buffer.shape == domain.shape assert tensor_buffer.dtype == dtype - assert tensor_buffer.device == core_defs.Device(core_defs.DeviceType.CPU, 0) # Test with both allocator and device with pytest.raises(ValueError, match="are incompatible"): - allocate( + make_concrete_allocator( domain, dtype, allocator=allocator, @@ -184,4 +182,4 @@ def test_allocate(): # Test with no device or allocator with pytest.raises(ValueError, match="No 'device' or 'allocator' specified"): - allocate(domain, dtype) + make_concrete_allocator(domain, dtype) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 0998ab8eab..b6bfad17c6 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -6,7 +6,23 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy +from types import ModuleType +from typing import Any + import numpy as np + + +try: + import cupy as cp +except ImportError: + cp = None + +try: + import jax.numpy as jnp +except ImportError: + jnp = None + import pytest from gt4py import next as gtx @@ -21,16 +37,49 @@ sizes = {I: 10, J: 10, K: 10} -# TODO: parametrize with gpu backend and compare with cupy array -@pytest.mark.parametrize( - "allocator, device", - [ - [next_allocators.StandardCPUFieldBufferAllocator(), None], - [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], - ], -) -def test_empty(allocator, device): - ref = np.empty([sizes[I], sizes[J]]).astype(gtx.float32) +def _pretty_print(val): + if val is None: + return "None" + if isinstance(val, ModuleType): + return val.__name__ + return val.__class__.__name__ + + +def _pretty_print_allocator_device_namespace(val: tuple[Any, Any, Any]): + return f"allocator={_pretty_print(val[0])}-device={_pretty_print(val[1])}-ref_namespace={_pretty_print(val[2])}" + + +def allocator_device_refnamespace_params(): + for v in [ + [next_allocators.StandardCPUFieldBufferAllocator(), None, np], + [None, core_defs.Device(core_defs.DeviceType.CPU, 0), np], + [np, None, np], + ]: + yield pytest.param( + v, + id=_pretty_print_allocator_device_namespace(v), + ) + for v in [ + [next_allocators.StandardGPUFieldBufferAllocator(), None, cp], + [None, core_defs.Device(core_defs.DeviceType.CUDA, 0), cp], # TODO CUDA or HIP... + ]: + yield pytest.param( + v, id=_pretty_print_allocator_device_namespace(v), marks=pytest.mark.requires_gpu + ) + for v in [[jnp, None, jnp]]: + yield pytest.param( + v, id=_pretty_print_allocator_device_namespace(v), marks=pytest.mark.requires_jax + ) + + +@pytest.fixture(params=allocator_device_refnamespace_params()) +def allocator_device_refnamespace(request): + return request.param + + +def test_empty(allocator_device_refnamespace): + allocator, device, xp = allocator_device_refnamespace + ref = xp.empty([sizes[I], sizes[J]]).astype(gtx.float32) a = gtx.empty( domain={I: range(sizes[I]), J: range(sizes[J])}, dtype=core_defs.dtype(np.float32), @@ -40,15 +89,8 @@ def test_empty(allocator, device): assert a.shape == ref.shape -# TODO: parametrize with gpu backend and compare with cupy array -@pytest.mark.parametrize( - "allocator, device", - [ - [next_allocators.StandardCPUFieldBufferAllocator(), None], - [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], - ], -) -def test_zeros(allocator, device): +def test_zeros(allocator_device_refnamespace): + allocator, device, xp = allocator_device_refnamespace a = gtx.zeros( common.Domain( dims=(I, J), ranges=(common.UnitRange(0, sizes[I]), common.UnitRange(0, sizes[J])) @@ -57,40 +99,26 @@ def test_zeros(allocator, device): allocator=allocator, device=device, ) - ref = np.zeros((sizes[I], sizes[J])).astype(gtx.float32) + ref = xp.zeros((sizes[I], sizes[J])).astype(gtx.float32) - assert np.array_equal(a.ndarray, ref) + assert xp.array_equal(a.ndarray, ref) -# TODO: parametrize with gpu backend and compare with cupy array -@pytest.mark.parametrize( - "allocator, device", - [ - [next_allocators.StandardCPUFieldBufferAllocator(), None], - [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], - ], -) -def test_ones(allocator, device): +def test_ones(allocator_device_refnamespace): + allocator, device, xp = allocator_device_refnamespace a = gtx.ones( common.Domain(dims=(I, J), ranges=(common.UnitRange(0, 10), common.UnitRange(0, 10))), dtype=core_defs.dtype(np.float32), allocator=allocator, device=device, ) - ref = np.ones((sizes[I], sizes[J])).astype(gtx.float32) + ref = xp.ones((sizes[I], sizes[J])).astype(gtx.float32) - assert np.array_equal(a.ndarray, ref) + assert xp.array_equal(a.ndarray, ref) -# TODO: parametrize with gpu backend and compare with cupy array -@pytest.mark.parametrize( - "allocator, device", - [ - [next_allocators.StandardCPUFieldBufferAllocator(), None], - [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], - ], -) -def test_full(allocator, device): +def test_full(allocator_device_refnamespace): + allocator, device, xp = allocator_device_refnamespace a = gtx.full( domain={I: range(sizes[I] - 2), J: (sizes[J] - 2)}, fill_value=42.0, @@ -98,9 +126,19 @@ def test_full(allocator, device): allocator=allocator, device=device, ) - ref = np.full((sizes[I] - 2, sizes[J] - 2), 42.0).astype(gtx.float32) + ref = xp.full((sizes[I] - 2, sizes[J] - 2), 42.0).astype(gtx.float32) - assert np.array_equal(a.ndarray, ref) + assert xp.array_equal(a.ndarray, ref) + + +def test_deepcopy(): + testee = gtx.as_field([I, J], np.random.rand(sizes[I], sizes[J])) + result = copy.deepcopy(testee) + assert testee.ndarray.strides == result.ndarray.strides + assert ( + result.ndarray.strides != result.ndarray.copy().strides + ) # sanity check for this test, make sure our allocator don't have C-contiguous strides + assert np.array_equal(testee.ndarray, result.ndarray) def test_as_field(): From 1b3e0c5273305187db5025e7307f8620c796995f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 18 Dec 2024 21:46:02 +0100 Subject: [PATCH 04/12] cleanup copying --- src/gt4py/next/allocators.py | 40 ++++------------ src/gt4py/next/constructors.py | 47 ++++++++++--------- src/gt4py/next/embedded/nd_array_field.py | 12 ++--- .../embedded_tests/test_nd_array_field.py | 8 ++++ .../unit_tests/test_constructors.py | 4 +- 5 files changed, 51 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index abc9c5f4ee..5336a6fb1a 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -292,25 +292,14 @@ def __init__(self) -> None: ) -class ConcreteAllocator(Protocol): - def __call__( - domain: common.DomainLike, - dtype: core_defs.DType[core_defs.ScalarT], - *, - aligned_index: Optional[Sequence[common.NamedIndex]], - allocator: FieldBufferAllocationUtil, - device: core_defs.Device, - ) -> core_defs.NDArrayObject: ... - - -def make_concrete_allocator( +def allocate( + *, domain: common.DomainLike, # TODO: there is an inconsistency between DomainLike and concrete DType, probably accept either (Domain, DType) or (DomainLike, DTypeLike). anyway this is not meant to be user-facing dtype: core_defs.DType[core_defs.ScalarT], - *, aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[FieldBufferAllocationUtil] = None, device: Optional[core_defs.Device] = None, -) -> ConcreteAllocator: +) -> core_defs.NDArrayObject: """ TODO: docstring Allocate an NDArrayObject for the given domain and device or allocator. @@ -346,20 +335,9 @@ def make_concrete_allocator( elif device.device_type != actual_allocator.__gt_device_type__: raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") - def allocate( - domain: common.DomainLike = domain, - dtype: core_defs.DType[core_defs.ScalarT] = dtype, - *, - aligned_index: Optional[Sequence[common.NamedIndex]] = aligned_index, - allocator: FieldBufferAllocationUtil = actual_allocator, - device: core_defs.Device = device, - ) -> core_defs.NDArrayObject: - # TODO check how to get from FieldBufferAllocationUtil to FieldBufferAllocatorProtocol - return allocator.__gt_allocate__( - domain=common.domain(domain), - dtype=dtype, - device_id=device.device_id, - aligned_index=aligned_index, - ) - - return allocate + return actual_allocator.__gt_allocate__( + domain=common.domain(domain), + dtype=dtype, + device_id=device.device_id, + aligned_index=aligned_index, + ) diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 7ac800ab97..c4e066f669 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -8,8 +8,9 @@ from __future__ import annotations +import functools from collections.abc import Mapping, Sequence -from typing import Any, Callable, Optional, Protocol, TypeGuard, cast +from typing import Any, Optional, Protocol, TypeGuard, cast from typing_extensions import NotRequired, TypedDict @@ -148,9 +149,10 @@ def empty( if allocator is None and device is None: device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) - allocate = next_allocators.make_concrete_allocator( - domain, - core_defs.dtype(dtype), + allocate = functools.partial( + next_allocators.allocate, + domain=domain, + dtype=core_defs.dtype(dtype), aligned_index=aligned_index, allocator=allocator, device=device, @@ -294,7 +296,7 @@ def as_field( aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - # TODO: copy=False + # TODO(havogt): copy=False ) -> nd_array_field.NdArrayField: """Create a Field from an array-like object using the given (or device-default) allocator. @@ -369,7 +371,7 @@ def as_field( domain=actual_domain, dtype=dtype, device=device, - # copy=copy + # TODO(havogt): copy=copy ) # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has @@ -466,8 +468,12 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - allocate = next_allocators.make_concrete_allocator( - actual_domain, dtype, allocator=allocator, device=device + allocate = functools.partial( + next_allocators.allocate, + domain=actual_domain, + dtype=dtype, + allocator=allocator, + device=device, ) buffer = allocate() # TODO(havogt): consider adding MutableNDArrayObject @@ -480,10 +486,7 @@ def as_connectivity( return connectivity_field -_like_field = None # for more descriptive function signature in editors - - -class AllocatorParams(TypedDict): +class _AllocatorParams(TypedDict): domain: NotRequired[common.DomainLike] dtype: NotRequired[core_defs.DType[core_defs.ScalarT],] aligned_index: NotRequired[Sequence[common.NamedIndex]] @@ -494,23 +497,25 @@ class AllocatorParams(TypedDict): def empty_like( field: nd_array_field.NdArrayField, *, - domain: Optional[common.DomainLike] = _like_field, - dtype: Optional[core_defs.DTypeLike] = _like_field, - aligned_index: Optional[Sequence[common.NamedIndex]] = _like_field, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = _like_field, - device: Optional[core_defs.Device] = _like_field, + domain: Optional[common.DomainLike] = None, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + device: Optional[core_defs.Device] = None, ) -> nd_array_field.NdArrayField: - kwargs: AllocatorParams = {} + kwargs: _AllocatorParams = {} if domain is not None: kwargs["domain"] = domain if dtype is not None: kwargs["dtype"] = core_defs.dtype(dtype) if aligned_index is not None: kwargs["aligned_index"] = aligned_index - if allocator is not eve.NOTHING: + if allocator is not None: kwargs["allocator"] = allocator - if device is not eve.NOTHING: + if device is not None: kwargs["device"] = device if field._allocator is None: raise ValueError("'Field' does not have an allocator.") # TODO discuss if this is possible - return field._allocator(**kwargs) + + allocate = functools.partial(field._allocator, **kwargs) + return common._field(allocate(), domain=kwargs.get("domain", field.domain), allocator=allocate) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 417cf1f50a..c432dfbc45 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -432,17 +432,15 @@ def _slice( assert common.is_relative_index_sequence(slice_) return new_domain, slice_ - def __copy__(self) -> Never: - # TODO does this make sense? - raise NotImplementedError( - "`NdArrayField` is frozen, shallow copying is not a useful operation. Did you want to deepcopy?" - ) - - def __deepcopy__(self, _: Any) -> NdArrayField: + def __copy__(self) -> NdArrayField: + # Note: `copy` copies the data, following NumPy behavior ndarray_copy = self._allocator() ndarray_copy[:] = self.ndarray[:] return self.__class__(self.domain, ndarray_copy, _allocator=self._allocator) + def __deepcopy__(self, _: Any) -> NdArrayField: + return self.__copy__() + if dace: # Extension of NdArrayField adding SDFGConvertible support in GT4Py Programs def _dace_data_ptr(self) -> int: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9dde5bb40a..af5dd03102 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy import math import operator from typing import Callable, Iterable, Optional @@ -261,6 +262,13 @@ def test_as_scalar(nd_array_implementation): assert isinstance(result, np.float32) +@pytest.mark.parametrize("copy", [copy.copy, copy.deepcopy]) +def test_copy(copy, nd_array_implementation): + testee = _make_field_or_scalar([[0, 1], [2, 3]], nd_array_implementation) + result = copy(testee) + assert np.array_equal(testee.ndarray, result.ndarray) + + def product_nd_array_implementation_params(): for xp1 in nd_array_field._nd_array_implementations: for xp2 in nd_array_field._nd_array_implementations: diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index b6bfad17c6..cd0c0014ce 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -131,7 +131,9 @@ def test_full(allocator_device_refnamespace): assert xp.array_equal(a.ndarray, ref) -def test_deepcopy(): +def test_copy(): + """Ensure data AND layout is preserved.""" + testee = gtx.as_field([I, J], np.random.rand(sizes[I], sizes[J])) result = copy.deepcopy(testee) assert testee.ndarray.strides == result.ndarray.strides From afbbc8e136ce14e7f9a7f735f1d05b3deaeb2441 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 11:06:12 +0100 Subject: [PATCH 05/12] fix storage inferface test --- .../unit_tests/test_interface.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/storage_tests/unit_tests/test_interface.py b/tests/storage_tests/unit_tests/test_interface.py index ba7bc2aaef..038b09aa4c 100644 --- a/tests/storage_tests/unit_tests/test_interface.py +++ b/tests/storage_tests/unit_tests/test_interface.py @@ -121,16 +121,7 @@ def test_allocate_cpu(param_dict): shape = param_dict["shape"] layout_map = param_dict["layout_order"] - raw_buffer, field = allocate_cpu(shape, layout_map, dtype, alignment_bytes, aligned_index) - - # check that memory of field is contained in raw_buffer - np_byte_bounds = ( - np.byte_bounds if hasattr(np, "byte_bounds") else np.lib.array_utils.byte_bounds - ) - assert ( - np_byte_bounds(field)[0] >= np_byte_bounds(raw_buffer)[0] - and np_byte_bounds(field)[1] <= np_byte_bounds(raw_buffer)[1] - ) + field = allocate_cpu(shape, layout_map, dtype, alignment_bytes, aligned_index) # check if the first compute-domain point in the last dimension is aligned for 100 random "columns" import random @@ -185,17 +176,7 @@ def test_allocate_gpu(param_dict): aligned_index = param_dict["aligned_index"] shape = param_dict["shape"] layout_map = param_dict["layout_order"] - device_raw_buffer, device_field = allocate_gpu( - shape, layout_map, dtype, alignment_bytes, aligned_index - ) - - # Would like to check device_field.base against device_raw_buffer but - # as_strided returns an ndarray where device_field.base is set to None. - # Instead, check that the memory of field is contained in raws buffer - assert ( - device_field.data.ptr >= device_raw_buffer.data.ptr - and device_field[-1:].data.ptr <= device_raw_buffer[-1:].data.ptr - ) + device_field = allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) # check if the first compute-domain point in the last dimension is aligned for 100 random "columns" import random From 3bb206f584b44cc0b72fe0ce34d1377f062bd197 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 11:07:00 +0100 Subject: [PATCH 06/12] move to_array_api_dtype --- src/gt4py/_core/definitions.py | 10 ++++++++++ src/gt4py/next/constructors.py | 10 +--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 238ec9d266..821f0707e8 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -14,6 +14,7 @@ import functools import math import numbers +from types import ModuleType import numpy as np import numpy.typing as npt @@ -338,6 +339,15 @@ def dtype(dtype_like: DTypeLike) -> DType: return dtype_like if isinstance(dtype_like, DType) else DType(np.dtype(dtype_like).type) +def to_array_api_dtype(xp: ModuleType, dtype_: DTypeLike | None) -> Any: + """ + Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace. + + Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations. + """ + return None if dtype_ is None else getattr(xp, dtype(dtype_).scalar_type.__name__) + + # -- Custom protocols -- class GTDimsInterface(Protocol): """ diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index c4e066f669..9238e7b363 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -41,14 +41,6 @@ def _has_array_api_creation_functions(obj: Any) -> TypeGuard[_HasArrayApiCreatio ) -def _convert_dtype( - xp: Any, dtype: core_defs.DTypeLike -) -> Any: # TODO move to core_defs as `to_array_api_dtype` - if dtype is None: - return None - return getattr(xp, core_defs.dtype(dtype).scalar_type.__name__) - - def _array_api_construction( xp: _HasArrayApiCreationFunctions, fun: str, @@ -60,7 +52,7 @@ def _array_api_construction( ): if device is not None: raise NotImplementedError("Device specification is not yet supported.") - buffer = getattr(xp, fun)(*args, dtype=_convert_dtype(xp, dtype), **kwargs) + buffer = getattr(xp, fun)(*args, dtype=core_defs.to_array_api_dtype(xp, dtype), **kwargs) def allocate( domain: common.DomainLike = domain, From 1c5acd6053cd5cb9f76349a831958fd9068cb330 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 11:08:38 +0100 Subject: [PATCH 07/12] fix variable names --- src/gt4py/storage/cartesian/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 9e78ffd97a..ac12ba17c1 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -250,7 +250,7 @@ def allocate_cpu( aligned_index: Optional[Sequence[int]], ) -> np.ndarray: device = core_defs.Device(core_defs.DeviceType.CPU, 0) - buffer = _CPUBufferAllocator.allocate( + ndarray = _CPUBufferAllocator.allocate( shape, core_defs.dtype(dtype), device_id=device.device_id, @@ -258,7 +258,7 @@ def allocate_cpu( byte_alignment=alignment_bytes, aligned_index=aligned_index, ) - return cast(np.ndarray, buffer) + return cast(np.ndarray, ndarray) def _allocate_gpu( @@ -274,7 +274,7 @@ def _allocate_gpu( (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), 0, ) - buffer = _GPUBufferAllocator.allocate( + ndarray = _GPUBufferAllocator.allocate( shape, core_defs.dtype(dtype), device_id=device.device_id, @@ -283,7 +283,7 @@ def _allocate_gpu( aligned_index=aligned_index, ) - return cast("cp.ndarray", buffer.ndarray) + return cast("cp.ndarray", ndarray) allocate_gpu = _allocate_gpu From 92b8ea62b8960edd72761255679e3a51758a81ce Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 20:42:30 +0100 Subject: [PATCH 08/12] refactor to an allocation namespace --- pyproject.toml | 2 + src/gt4py/_core/definitions.py | 55 +++- src/gt4py/_core/gt_array_namespace.py | 8 + .../next/{allocators.py => _allocators.py} | 283 ++++++++++++++++- src/gt4py/next/backend.py | 2 +- src/gt4py/next/common.py | 8 +- src/gt4py/next/constructors.py | 291 +++++------------- src/gt4py/next/embedded/nd_array_field.py | 44 ++- src/gt4py/next/ffront/decorator.py | 2 +- .../next/program_processors/runners/dace.py | 2 +- .../runners/dace_fieldview/workflow.py | 2 +- .../next/program_processors/runners/gtfn.py | 2 +- .../program_processors/runners/roundtrip.py | 2 +- src/gt4py/storage/allocators.py | 10 +- src/gt4py/storage/cartesian/utils.py | 2 +- tests/next_tests/definitions.py | 14 +- tests/next_tests/integration_tests/cases.py | 4 +- .../feature_tests/dace/test_orchestration.py | 4 +- .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_program.py | 3 + .../ffront_tests/test_ffront_fvm_nabla.py | 3 +- .../runners_tests/test_gtfn.py | 6 +- .../next_tests/unit_tests/test_allocators.py | 6 +- .../unit_tests/test_constructors.py | 2 +- 24 files changed, 486 insertions(+), 273 deletions(-) create mode 100644 src/gt4py/_core/gt_array_namespace.py rename src/gt4py/next/{allocators.py => _allocators.py} (50%) diff --git a/pyproject.toml b/pyproject.toml index 3f80dac5d5..da33899965 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ 'Topic :: Scientific/Engineering :: Physics' ] dependencies = [ + "array-api-compat>=1.9.1", "astunparse>=1.6.3;python_version<'3.9'", 'attrs>=21.3', 'black>=22.3', @@ -265,6 +266,7 @@ markers = [ 'uses_unstructured_shift: tests that use a unstructured connectivity', 'uses_max_over: tests that use the max_over builtin', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', + 'slices_out_argument: tests that slice the out argument in a field_operator call', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 821f0707e8..f6250b1b60 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -14,7 +14,6 @@ import functools import math import numbers -from types import ModuleType import numpy as np import numpy.typing as npt @@ -28,6 +27,7 @@ Iterator, Literal, Protocol, + Self, Sequence, Tuple, Type, @@ -339,15 +339,6 @@ def dtype(dtype_like: DTypeLike) -> DType: return dtype_like if isinstance(dtype_like, DType) else DType(np.dtype(dtype_like).type) -def to_array_api_dtype(xp: ModuleType, dtype_: DTypeLike | None) -> Any: - """ - Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace. - - Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations. - """ - return None if dtype_ is None else getattr(xp, dtype(dtype_).scalar_type.__name__) - - # -- Custom protocols -- class GTDimsInterface(Protocol): """ @@ -415,6 +406,7 @@ class DeviceType(enum.IntEnum): MetalDeviceTyping, VPIDeviceTyping, ROCMDeviceTyping, + covariant=True, ) @@ -464,7 +456,7 @@ def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... def any(self) -> bool: ... - def __getitem__(self, item: Any) -> NDArrayObject: ... + def __getitem__(self, item: Any) -> Self: ... def __abs__(self) -> NDArrayObject: ... @@ -517,12 +509,47 @@ def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... +class MutableNDArrayObject(NDArrayObject, Protocol): + def __setitem__(self, index: Any, value: Any) -> None: ... + + class ArrayApiNamespace(Protocol): - @property - def __array_api_version__(self) -> str: ... + def empty(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... + def zeros(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... + def ones(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ... + def full( + self, shape: Sequence[int], fill_value: Scalar, *, dtype: Any = None, device: Any = None + ) -> Any: ... + def asarray(self, obj: Any, *, dtype: Any = None, copy: Any = None) -> Any: ... + + # @property # once all relevant implementations have this attribute + # def __array_api_version__(self) -> str: ... # noqa: ERA001 # TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697 def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]: - return hasattr(obj, "__array_api_version__") + # return hasattr(obj, "__array_api_version__") # noqa: ERA001 # once all relevant implementations have this attribute + return ( + hasattr(obj, "empty") + and hasattr(obj, "zeros") + and hasattr(obj, "ones") + and hasattr(obj, "full") + and hasattr(obj, "asarray") + ) + + +def to_array_api_dtype(xp: ArrayApiNamespace, dtype_: DTypeLike | None) -> Any: + """ + Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace. + + Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations. + """ + if dtype_ is None: + return None + else: + dtype_ = dtype(dtype_) + assert ( + dtype_.tensor_shape == () + ) # TODO(havogt): support tensor shapes (or remove from our DType) + return getattr(xp, dtype_.scalar_type.__name__) diff --git a/src/gt4py/_core/gt_array_namespace.py b/src/gt4py/_core/gt_array_namespace.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/src/gt4py/_core/gt_array_namespace.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/_allocators.py similarity index 50% rename from src/gt4py/next/allocators.py rename to src/gt4py/next/_allocators.py index 5336a6fb1a..1f2e03f9fe 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/_allocators.py @@ -6,10 +6,14 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +# TODO: make this module private + import abc import dataclasses import functools +import array_api_compat + import gt4py._core.definitions as core_defs import gt4py.next.common as common import gt4py.storage.allocators as core_allocators @@ -41,6 +45,68 @@ ) +class GTArrayAllocationNamespace(Protocol): + """ + Standard Array API-like construction functions based on `domain` instead of `shape`. + + The reason to use `domain` is: + - we need the `Dimension`s for getting the desired ordering of strides + - `aligned_index` refers to a point in the domain (absolute position), not relative to the array shape + """ + + # Notes: + # - this concept could be evolved to a more general `GTArrayNamespace` that adds all array functions that we use in embedded + # - maybe for advanced indexing use-case: extend the namespace with standard compliant fallback functions. + + def empty( + self, + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: ... + + def zeros( + self, + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: ... + + def ones( + self, + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: ... + + def full( + self, + domain: common.DomainLike, + fill_value: core_defs.Scalar, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: ... + + def asarray( + self, + data: core_defs.NDArrayObject, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + copy: Optional[bool] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: ... + + FieldLayoutMapper: TypeAlias = Callable[ [Sequence[common.Dimension]], core_allocators.BufferLayoutMap ] @@ -60,7 +126,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_defs.NDArrayObject: ... + ) -> core_defs.MutableNDArrayObject: ... def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -160,7 +226,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: shape = domain.shape layout_map = self.layout_mapper(domain.dims) # TODO(egparedes): add support for non-empty aligned index values @@ -242,7 +308,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: raise self.exception @@ -294,12 +360,12 @@ def __init__(self) -> None: def allocate( *, - domain: common.DomainLike, # TODO: there is an inconsistency between DomainLike and concrete DType, probably accept either (Domain, DType) or (DomainLike, DTypeLike). anyway this is not meant to be user-facing + domain: common.Domain, dtype: core_defs.DType[core_defs.ScalarT], aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[FieldBufferAllocationUtil] = None, device: Optional[core_defs.Device] = None, -) -> core_defs.NDArrayObject: +) -> core_defs.MutableNDArrayObject: """ TODO: docstring Allocate an NDArrayObject for the given domain and device or allocator. @@ -336,8 +402,213 @@ def allocate( raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") return actual_allocator.__gt_allocate__( - domain=common.domain(domain), + domain=domain, dtype=dtype, device_id=device.device_id, aligned_index=aligned_index, ) + + +def _check_unsupported_device_and_aligned_index( + device: Optional[core_defs.Device], aligned_index: Optional[Sequence[common.NamedIndex]] +) -> None: + if aligned_index is not None: + raise NotImplementedError("Aligned index is not support for Array API namespaces.") + if device is not None: + # TODO(havogt): this requires to translate our device object to the concrete Array API implementation's device object + raise NotImplementedError("Device specification is not yet supported.") + + +def get_array_allocation_namespace( + allocator: FieldBufferAllocationUtil | core_defs.ArrayApiNamespace | None, +) -> GTArrayAllocationNamespace: + if allocator is None: + allocator = StandardCPUFieldBufferAllocator() + if core_defs.is_array_api_namespace(allocator): + assert core_defs.is_array_api_namespace(allocator) + array_ns = array_api_compat.array_namespace(allocator.empty([0])) + + class _ArrayNamespaceWrapper: + @staticmethod + def empty( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + _check_unsupported_device_and_aligned_index(device, aligned_index) + return array_ns.empty( + shape=common.domain(domain).shape, + dtype=core_defs.to_array_api_dtype(array_ns, dtype), + ) + + @staticmethod + def zeros( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + _check_unsupported_device_and_aligned_index(device, aligned_index) + return array_ns.zeros( + shape=common.domain(domain).shape, + dtype=core_defs.to_array_api_dtype(array_ns, dtype), + ) + + @staticmethod + def ones( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + _check_unsupported_device_and_aligned_index(device, aligned_index) + return array_ns.ones( + shape=common.domain(domain).shape, + dtype=core_defs.to_array_api_dtype(array_ns, dtype), + ) + + @staticmethod + def full( + domain: common.DomainLike, + fill_value: core_defs.Scalar, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + _check_unsupported_device_and_aligned_index(device, aligned_index) + return array_ns.full( + shape=common.domain(domain).shape, + fill_value=fill_value, + dtype=core_defs.to_array_api_dtype(array_ns, dtype), + ) + + @staticmethod + def asarray( + data: core_defs.NDArrayObject, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + copy: Optional[bool] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + _check_unsupported_device_and_aligned_index(device, aligned_index) + if not data.shape == common.domain(domain).shape: + raise ValueError( + f"Array of shape '{data.shape}' is incompatible with domain '{domain}'." + ) + + return array_ns.asarray( + data, dtype=core_defs.to_array_api_dtype(array_ns, dtype), copy=copy + ) + + return _ArrayNamespaceWrapper + + else: + + class _CustomAllocationArrayNamespace: + @staticmethod + def empty( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + assert is_field_allocator(allocator) + return allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + + @staticmethod + def zeros( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + assert is_field_allocator(allocator) + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + buffer[...] = 0 + return buffer + + @staticmethod + def ones( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + assert is_field_allocator(allocator) + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + buffer[...] = 1 + return buffer + + @staticmethod + def full( + domain: common.DomainLike, + fill_value: core_defs.Scalar, + *, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + assert is_field_allocator(allocator) + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), # TODO check all dtypes + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + buffer[...] = fill_value + return buffer + + @staticmethod + def asarray( + data: core_defs.NDArrayObject, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + device: Optional[core_defs.Device] = None, + copy: Optional[bool] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + assert is_field_allocator(allocator) + if not copy: + raise NotImplementedError("Zero-copy construction is not yet supported.") + dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) + buffer = allocate( + domain=common.domain(domain), + dtype=dtype, + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + buffer[...] = array_api_compat.array_namespace(buffer).asarray(data) + return buffer + + return _CustomAllocationArrayNamespace diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e223d7771c..5a02628012 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -13,7 +13,7 @@ from typing import Any, Generic from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators +from gt4py.next import _allocators as next_allocators from gt4py.next.ffront import ( foast_to_gtir, foast_to_itir, diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 23fe6c9e0c..8d3bc832d2 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -950,7 +950,9 @@ def _field( /, *, domain: Optional[DomainLike] = None, - allocator: Optional[Any] = None, # TODO: resolve the type annotation + allocation_ns: Optional[ + Any + ] = None, # TODO: should be `next_allocators.GTArrayAllocationNamespace` dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -964,7 +966,9 @@ def _connectivity( codomain: Dimension, *, domain: Optional[DomainLike] = None, - allocator: Optional[Any] = None, # TODO: resolve the type annotation + allocation_ns: Optional[ + Any + ] = None, # TODO: should be `next_allocators.GTArrayAllocationNamespace` dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, ) -> Connectivity: diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 9238e7b363..bc3dce3acb 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -8,68 +8,18 @@ from __future__ import annotations -import functools from collections.abc import Mapping, Sequence -from typing import Any, Optional, Protocol, TypeGuard, cast - -from typing_extensions import NotRequired, TypedDict +from typing import Optional, cast import gt4py._core.definitions as core_defs import gt4py.eve as eve import gt4py.eve.extended_typing as xtyping -import gt4py.next.allocators as next_allocators +import gt4py.next._allocators as next_allocators import gt4py.next.common as common import gt4py.next.embedded.nd_array_field as nd_array_field import gt4py.storage.cartesian.utils as storage_utils -class _HasArrayApiCreationFunctions(Protocol): - def empty(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... - def zeros(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... - def ones(self, shape: Sequence[int], *, dtype=None, device=None) -> Any: ... - def full(self, shape: Sequence[int], fill_value, *, dtype=None, device=None) -> Any: ... - def asarray(self, obj, *, dtype=None, copy=None) -> Any: ... - - -def _has_array_api_creation_functions(obj: Any) -> TypeGuard[_HasArrayApiCreationFunctions]: - return core_defs.is_array_api_namespace(obj) or ( - hasattr(obj, "empty") - and hasattr(obj, "zeros") - and hasattr(obj, "ones") - and hasattr(obj, "full") - and hasattr(obj, "asarray") - ) - - -def _array_api_construction( - xp: _HasArrayApiCreationFunctions, - fun: str, - *args, - domain: common.Domain, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - **kwargs: Any, -): - if device is not None: - raise NotImplementedError("Device specification is not yet supported.") - buffer = getattr(xp, fun)(*args, dtype=core_defs.to_array_api_dtype(xp, dtype), **kwargs) - - def allocate( - domain: common.DomainLike = domain, - dtype: core_defs.DTypeLike = dtype, - *, - aligned_index: Optional[Sequence[common.NamedIndex]], - allocator: next_allocators.FieldBufferAllocationUtil | core_defs.ArrayApiNamespace = xp, - device: core_defs.Device = device, - ) -> core_defs.NDArrayObject: - # always returns an empty buffer by design - return empty( - domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device - ) - - return common._field(buffer, domain=domain, allocator=allocate) - - @eve.utils.with_fluid_partial def empty( domain: common.DomainLike, @@ -127,30 +77,11 @@ def empty( >>> b.shape (3, 3) """ - if _has_array_api_creation_functions(allocator): - domain = common.domain(domain) - return _array_api_construction( - allocator, - "empty", - domain.shape, - domain=domain, - dtype=dtype, - device=device, - ) - - if allocator is None and device is None: - device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) - - allocate = functools.partial( - next_allocators.allocate, - domain=domain, - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.empty( + domain, device=device, dtype=dtype, aligned_index=aligned_index ) - buffer = allocate() - res = common._field(buffer, domain=domain, allocator=allocate) + res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) return res @@ -176,21 +107,14 @@ def zeros( >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ - if _has_array_api_creation_functions(allocator): - domain = common.domain(domain) - return _array_api_construction( - allocator, - "zeros", - domain.shape, - domain=domain, - dtype=dtype, - device=device, - ) - field = empty( - domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.zeros( + domain, device=device, dtype=dtype, aligned_index=aligned_index ) - field[...] = field.dtype.scalar_type(0) - return field + res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) + assert isinstance(res, common.MutableField) + assert isinstance(res, nd_array_field.NdArrayField) + return res @eve.utils.with_fluid_partial @@ -213,21 +137,12 @@ def ones( >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ - if _has_array_api_creation_functions(allocator): - domain = common.domain(domain) - return _array_api_construction( - allocator, - "ones", - domain.shape, - domain=domain, - dtype=dtype, - device=device, - ) - field = empty( - domain=domain, dtype=dtype, aligned_index=aligned_index, allocator=allocator, device=device - ) - field[...] = field.dtype.scalar_type(1) - return field + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.ones(domain, device=device, dtype=dtype, aligned_index=aligned_index) + res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) + assert isinstance(res, common.MutableField) + assert isinstance(res, nd_array_field.NdArrayField) + return res @eve.utils.with_fluid_partial @@ -256,26 +171,54 @@ def full( >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ - if _has_array_api_creation_functions(allocator): - domain = common.domain(domain) - return _array_api_construction( - allocator, - "full", - domain.shape, - fill_value, - domain=domain, - dtype=dtype, - device=device, - ) - field = empty( - domain=domain, + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.full( + domain, + fill_value, + device=device, dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), aligned_index=aligned_index, - allocator=allocator, - device=device, ) - field[...] = field.dtype.scalar_type(fill_value) - return field + res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) + assert isinstance(res, common.MutableField) + assert isinstance(res, nd_array_field.NdArrayField) + return res + + +def _actual_domain( + dims_or_domain: common.DomainLike | Sequence[common.Dimension], + shape: Sequence[int], + origin: Optional[Mapping[common.Dimension, int]] = None, +) -> common.Domain: + if isinstance(dims_or_domain, Sequence) and all( + isinstance(dim, common.Dimension) for dim in dims_or_domain + ): + dims = cast(Sequence[common.Dimension], dims_or_domain) + if len(dims) != len(shape): + raise ValueError( + f"Cannot construct 'Field' from array of shape '{shape}' and domain '{dims}'." + ) + if origin: + domain_dims = set(dims) + if unknown_dims := set(origin.keys()) - domain_dims: + raise ValueError(f"Origin keys {unknown_dims} not in domain {dims}.") + else: + origin = {} + return common.domain( + [ + (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) + for d, s in zip(dims, shape) + ] + ) + else: + domain = common.domain(cast(common.DomainLike, dims_or_domain)) + if origin: + raise ValueError(f"Cannot specify origin for domain {domain}") + if domain.shape != shape: + raise ValueError( + f"Cannot construct 'Field' from array of shape '{shape}' and domain '{domain}'." + ) + return domain @eve.utils.with_fluid_partial @@ -333,63 +276,26 @@ def as_field( >>> gtx.as_field({IDim: range(-1, 2)}, xdata).domain.ranges[0] UnitRange(-1, 2) """ - if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): - domain = cast(Sequence[common.Dimension], domain) - if len(domain) != data.ndim: - raise ValueError( - f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." - ) - if origin: - domain_dims = set(domain) - if unknown_dims := set(origin.keys()) - domain_dims: - raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}.") - else: - origin = {} - actual_domain = common.domain( - [ - (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) - for d, s in zip(domain, data.shape) - ] - ) - else: - if origin: - raise ValueError(f"Cannot specify origin for domain {domain}") - actual_domain = common.domain(cast(common.DomainLike, domain)) - if _has_array_api_creation_functions(allocator): - return _array_api_construction( - allocator, - "asarray", - data, - domain=actual_domain, - dtype=dtype, - device=device, - # TODO(havogt): copy=copy - ) + actual_domain = _actual_domain(dims_or_domain=domain, shape=data.shape, origin=origin) # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has # already the correct layout and device. - shape = storage_utils.asarray(data).shape - if shape != actual_domain.shape: - raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") - if dtype is None: - dtype = storage_utils.asarray(data).dtype - dtype = core_defs.dtype(dtype) - assert dtype.tensor_shape == () # TODO if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - field = empty( + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.asarray( + data, domain=actual_domain, dtype=dtype, - aligned_index=aligned_index, - allocator=allocator, device=device, + copy=True, # TODO(havogt) add support for zero-copy construction + aligned_index=aligned_index, ) + res = common._field(buffer, domain=actual_domain, allocation_ns=gtarray_namespace) - field[...] = field.array_ns.asarray(data) - - return field + return res # type: ignore[return-value] # it is an NDArrayField @eve.utils.with_fluid_partial @@ -460,54 +366,23 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - allocate = functools.partial( - next_allocators.allocate, + + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + buffer = gtarray_namespace.asarray( + data, domain=actual_domain, dtype=dtype, - allocator=allocator, device=device, + copy=True, # TODO(havogt) add support for zero-copy construction ) - buffer = allocate() - # TODO(havogt): consider adding MutableNDArrayObject - buffer[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common._connectivity( - buffer, codomain=codomain, domain=actual_domain, skip_value=skip_value, allocator=allocate + buffer, + codomain=codomain, + domain=actual_domain, + skip_value=skip_value, + allocation_ns=gtarray_namespace, ) + assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) return connectivity_field - - -class _AllocatorParams(TypedDict): - domain: NotRequired[common.DomainLike] - dtype: NotRequired[core_defs.DType[core_defs.ScalarT],] - aligned_index: NotRequired[Sequence[common.NamedIndex]] - allocator: NotRequired[next_allocators.FieldBufferAllocatorProtocol] - device: NotRequired[core_defs.Device] - - -def empty_like( - field: nd_array_field.NdArrayField, - *, - domain: Optional[common.DomainLike] = None, - dtype: Optional[core_defs.DTypeLike] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, - device: Optional[core_defs.Device] = None, -) -> nd_array_field.NdArrayField: - kwargs: _AllocatorParams = {} - if domain is not None: - kwargs["domain"] = domain - if dtype is not None: - kwargs["dtype"] = core_defs.dtype(dtype) - if aligned_index is not None: - kwargs["aligned_index"] = aligned_index - if allocator is not None: - kwargs["allocator"] = allocator - if device is not None: - kwargs["device"] = device - if field._allocator is None: - raise ValueError("'Field' does not have an allocator.") # TODO discuss if this is possible - - allocate = functools.partial(field._allocator, **kwargs) - return common._field(allocate(), domain=kwargs.get("domain", field.domain), allocator=allocate) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index c432dfbc45..182a1af1e2 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -29,7 +29,7 @@ TypeVar, cast, ) -from gt4py.next import allocators, common +from gt4py.next import _allocators, common from gt4py.next.embedded import ( common as embedded_common, context as embedded_context, @@ -116,7 +116,7 @@ class NdArrayField( _domain: common.Domain _ndarray: core_defs.NDArrayObject - _allocator: Optional[allocators.ConcreteAllocator] + _allocation_ns: Optional[_allocators.GTArrayAllocationNamespace] array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol @@ -168,8 +168,8 @@ def from_array( /, *, domain: common.DomainLike, - allocator: Optional[ - allocators.ConcreteAllocator + allocation_ns: Optional[ + _allocators.GTArrayAllocationNamespace ] = None, # TODO: maybe an NDArrayField always has an allocator? dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: @@ -188,7 +188,7 @@ def from_array( assert len(domain) == array.ndim assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) - return cls(domain, array, allocator) + return cls(domain, array, allocation_ns) def premap( self: NdArrayField, @@ -334,7 +334,9 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayField: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] new_buffer = self.__class__.array_ns.asarray(new_buffer) - return self.__class__.from_array(new_buffer, domain=new_domain, allocator=self._allocator) + return self.__class__.from_array( + new_buffer, domain=new_domain, allocation_ns=self._allocation_ns + ) __getitem__ = restrict @@ -434,9 +436,16 @@ def _slice( def __copy__(self) -> NdArrayField: # Note: `copy` copies the data, following NumPy behavior - ndarray_copy = self._allocator() - ndarray_copy[:] = self.ndarray[:] - return self.__class__(self.domain, ndarray_copy, _allocator=self._allocator) + allocation_ns = self._allocation_ns or _allocators.get_array_allocation_namespace( + self.array_ns + ) + ndarray_copy = allocation_ns.asarray( + self.ndarray, + domain=self.domain, + dtype=self.dtype, + copy=True, # aligned_index??? + ) + return self.__class__(self.domain, ndarray_copy, _allocation_ns=self._allocation_ns) def __deepcopy__(self, _: Any) -> NdArrayField: return self.__copy__() @@ -526,7 +535,7 @@ def from_array( # type: ignore[override] codomain: common.DimT, *, domain: common.DomainLike, - allocator: Optional[allocators.ConcreteAllocator] = None, + allocation_ns: Optional[_allocators.GTArrayAllocationNamespace] = None, dtype: Optional[core_defs.DTypeLike] = None, skip_value: Optional[core_defs.IntegralScalar] = None, ) -> NdArrayConnectivityField: @@ -547,7 +556,7 @@ def from_array( # type: ignore[override] assert isinstance(codomain, common.Dimension) - return cls(domain, array, allocator, codomain, _skip_value=skip_value) + return cls(domain, array, allocation_ns, codomain, _skip_value=skip_value) def inverse_image(self, image_range: common.UnitRange | common.NamedRange) -> common.Domain: cache_key = hash((id(self.ndarray), self.domain, image_range)) @@ -587,7 +596,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: restricted_connectivity = cls( new_domain, new_buffer, - _allocator=self._allocator, + _allocation_ns=self._allocation_ns, _codomain=self._codomain, _skip_value=self._skip_value, ) @@ -615,7 +624,7 @@ def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> new_domain = new_domain.replace(dim_idx, *new_ranges) return data.__class__.from_array( - data._ndarray, domain=new_domain, dtype=data.dtype, allocator=data._allocator + data._ndarray, domain=new_domain, dtype=data.dtype, allocation_ns=data._allocation_ns ) @@ -656,7 +665,10 @@ def _reshuffling_premap( conn_ndarray = xp.broadcast_to(conn_ndarray, data.domain.shape) if conn_ndarray is not conn.ndarray: conn = conn.__class__.from_array( - conn_ndarray, domain=data.domain, codomain=conn.codomain, allocator=conn._allocator + conn_ndarray, + domain=data.domain, + codomain=conn.codomain, + allocation_ns=conn._allocation_ns, ) conn_map[conn.codomain] = conn dim_idx = data.domain.dim_index(conn.codomain, allow_missing=False) @@ -686,7 +698,7 @@ def _reshuffling_premap( new_buffer, domain=new_domain, dtype=data.dtype, - allocator=data._allocator, + allocation_ns=data._allocation_ns, ) @@ -727,7 +739,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> new_buffer, domain=new_domain, dtype=data.dtype, - allocator=data._allocator, + allocation_ns=data._allocation_ns, ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 61756f30c9..a8a7c4e730 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -24,7 +24,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.next import ( - allocators as next_allocators, + _allocators as next_allocators, backend as next_backend, common, embedded as next_embedded, diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 1b3b930818..45bc9908ef 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -9,7 +9,7 @@ import factory import gt4py._core.definitions as core_defs -import gt4py.next.allocators as next_allocators +import gt4py.next._allocators as next_allocators from gt4py.next import backend from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..edf510ca2e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -16,7 +16,7 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import allocators as gtx_allocators, common, config +from gt4py.next import _allocators as gtx_allocators, common, config from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 55f479c665..2847540e6c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -17,7 +17,7 @@ import filelock import gt4py._core.definitions as core_defs -import gt4py.next.allocators as next_allocators +import gt4py.next._allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 1dd568b95a..57df1fe228 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -20,7 +20,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import allocators as next_allocators, backend as next_backend, common, config +from gt4py.next import _allocators as next_allocators, backend as next_backend, common, config from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 9290402b2c..dbd3113a77 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -75,7 +75,7 @@ def allocate( layout_map: BufferLayoutMap, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: """ Allocate an NDArrayObject with the given shape, layout and alignment settings. @@ -109,7 +109,7 @@ def allocate( layout_map: BufferLayoutMap, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: if not core_defs.is_valid_tensor_shape(shape): raise ValueError(f"Invalid shape {shape}") ndim = len(shape) @@ -190,7 +190,7 @@ def tensorize( item_size: int, strides: Sequence[int], byte_offset: int, - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: """Create shaped view from buffer.""" pass @@ -200,7 +200,7 @@ class ArrayUtils: array_ns: types.ModuleType empty: Callable[..., _NDBuffer] byte_bounds: Callable[[_NDBuffer], Tuple[int, int]] - as_strided: Callable[..., core_defs.NDArrayObject] + as_strided: Callable[..., core_defs.MutableNDArrayObject] numpy_array_utils = ArrayUtils( @@ -258,7 +258,7 @@ def tensorize( item_size: int, strides: Sequence[int], byte_offset: int, - ) -> core_defs.NDArrayObject: + ) -> core_defs.MutableNDArrayObject: aligned_buffer = buffer[byte_offset : byte_offset + math.prod(allocated_shape) * item_size] # type: ignore[index] # TODO(egparedes): should we extend `_NDBuffer`s to cover __getitem__? flat_ndarray = aligned_buffer.view(dtype=np.dtype(dtype)) tensor_view = self._array_utils.as_strided( diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index ac12ba17c1..b94abb915c 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -267,7 +267,7 @@ def _allocate_gpu( dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], -) -> Tuple["cp.ndarray", "cp.ndarray"]: +) -> "cp.ndarray": assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" device = core_defs.Device( # type: ignore[type-var] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 4ec5367522..85acf23300 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next import allocators as next_allocators +from gt4py.next import _allocators as next_allocators # Skip definitions @@ -64,7 +64,11 @@ class EmbeddedDummyBackend: allocator: next_allocators.FieldBufferAllocatorProtocol -numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) +import numpy as np + + +# numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) +numpy_execution = EmbeddedDummyBackend(np) cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator()) jax_execution = EmbeddedDummyBackend(jnp) @@ -122,6 +126,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MAX_OVER = "uses_max_over" USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" +SLICES_OUT_ARGUMENT = "slices_out_argument" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -157,6 +162,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] +JAX_SKIP_LIST = EMBEDDED_SKIP_LIST + [ + (SLICES_OUT_ARGUMENT, XFAIL, UNSUPPORTED_MESSAGE), +] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] @@ -179,7 +187,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - EmbeddedIds.JAX_EXECUTION: EMBEDDED_SKIP_LIST, + EmbeddedIds.JAX_EXECUTION: JAX_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 759cd1cf1f..9faff154da 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -23,11 +23,12 @@ from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self from gt4py.next import ( - allocators as next_allocators, + _allocators as next_allocators, backend as next_backend, common, constructors, field_utils, + utils as gt_utils, ) from gt4py.next.ffront import decorator from gt4py.next.type_system import type_specifications as ts, type_translation @@ -55,7 +56,6 @@ mesh_descriptor, ) -from gt4py.next import utils as gt_utils # mypy does not accept [IDim, ...] as a type diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 08904c06f3..a5ce949e04 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,11 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import allocators as gtx_allocators, common as gtx_common +from gt4py.next import _allocators as gtx_allocators, common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 139df658a1..c83c433545 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -14,7 +14,7 @@ import gt4py.next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next import _allocators as next_allocators, backend as next_backend, common from gt4py.next.ffront import decorator import next_tests diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f1cb8ffb17..57607fefc3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -53,6 +53,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): @pytest.mark.uses_cartesian_shift +@pytest.mark.slices_out_argument def test_shift_by_one_execution(cartesian_case): @gtx.field_operator def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: @@ -95,6 +96,7 @@ def test_double_copy_execution(cartesian_case, double_copy_program_def): ) +@pytest.mark.slices_out_argument def test_copy_restricted_execution(cartesian_case, copy_restrict_program_def): copy_restrict_program = gtx.program(copy_restrict_program_def, backend=cartesian_case.backend) @@ -154,6 +156,7 @@ def prog( assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy())) +@pytest.mark.slices_out_argument def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @gtx.field_operator def pack_tuple( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index 6c6ca7e4bc..5c914ecede 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,10 +11,11 @@ import numpy as np import pytest + pytest.importorskip("atlas4py") from gt4py import next as gtx -from gt4py.next import allocators, neighbor_sum +from gt4py.next import _allocators, neighbor_sum from gt4py.next.iterator import atlas_utils from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 3d82dd8ee5..3df9d11015 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -20,7 +20,7 @@ """ import gt4py._core.definitions as core_defs -from gt4py.next import allocators, config +from gt4py.next import _allocators, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import workflow @@ -40,8 +40,8 @@ def test_backend_factory_trait_device(): assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA - assert allocators.is_field_allocator_for(cpu_version.allocator, core_defs.DeviceType.CPU) - assert allocators.is_field_allocator_for(gpu_version.allocator, core_defs.DeviceType.CUDA) + assert _allocators.is_field_allocator_for(cpu_version.allocator, core_defs.DeviceType.CPU) + assert _allocators.is_field_allocator_for(gpu_version.allocator, core_defs.DeviceType.CUDA) def test_backend_factory_trait_cached(): diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index c8f395b29b..0d91620105 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -12,7 +12,7 @@ import pytest import gt4py._core.definitions as core_defs -import gt4py.next.allocators as next_allocators +import gt4py.next._allocators as next_allocators import gt4py.next.common as common import gt4py.storage.allocators as core_allocators @@ -108,7 +108,7 @@ def test_get_allocator(): def test_horizontal_first_layout_mapper(): - from gt4py.next.allocators import horizontal_first_layout_mapper + from gt4py.next._allocators import horizontal_first_layout_mapper # Test with only horizontal dimensions dims = [ @@ -152,7 +152,7 @@ def test_allocate(self): def test_allocate(): - from gt4py.next.allocators import StandardCPUFieldBufferAllocator, make_concrete_allocator + from gt4py.next._allocators import StandardCPUFieldBufferAllocator, make_concrete_allocator I = common.Dimension("I") J = common.Dimension("J") diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index cd0c0014ce..984d87fd00 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -27,7 +27,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common +from gt4py.next import _allocators as next_allocators, common I = gtx.Dimension("I") From 7817ec302b0390961639d055b6dfc073ebe953ca Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 20:47:51 +0100 Subject: [PATCH 09/12] fix compat --- pyproject.toml | 2 +- src/gt4py/_core/gt_array_namespace.py | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) delete mode 100644 src/gt4py/_core/gt_array_namespace.py diff --git a/pyproject.toml b/pyproject.toml index da33899965..8b91890d9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ 'Topic :: Scientific/Engineering :: Physics' ] dependencies = [ - "array-api-compat>=1.9.1", + "array-api-compat>=1.9.1;python_version>='3.10'", "astunparse>=1.6.3;python_version<'3.9'", 'attrs>=21.3', 'black>=22.3', diff --git a/src/gt4py/_core/gt_array_namespace.py b/src/gt4py/_core/gt_array_namespace.py deleted file mode 100644 index abf4c3e24c..0000000000 --- a/src/gt4py/_core/gt_array_namespace.py +++ /dev/null @@ -1,8 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - From 6709b04bad1201294a307158c6ff26dbadf34d1f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 19 Dec 2024 21:16:09 +0100 Subject: [PATCH 10/12] fix tests --- src/gt4py/next/_allocators.py | 10 +++++----- .../runners_tests/dace_tests/test_dace.py | 6 +++--- tests/next_tests/unit_tests/test_allocators.py | 14 +++++++------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/_allocators.py b/src/gt4py/next/_allocators.py index 1f2e03f9fe..440aa6c0de 100644 --- a/src/gt4py/next/_allocators.py +++ b/src/gt4py/next/_allocators.py @@ -520,7 +520,7 @@ def empty( device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: - assert is_field_allocator(allocator) + assert is_field_allocation_tool(allocator) return allocate( domain=common.domain(domain), dtype=core_defs.dtype(dtype), @@ -537,7 +537,7 @@ def zeros( device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: - assert is_field_allocator(allocator) + assert is_field_allocation_tool(allocator) buffer = allocate( domain=common.domain(domain), dtype=core_defs.dtype(dtype), @@ -556,7 +556,7 @@ def ones( device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: - assert is_field_allocator(allocator) + assert is_field_allocation_tool(allocator) buffer = allocate( domain=common.domain(domain), dtype=core_defs.dtype(dtype), @@ -576,7 +576,7 @@ def full( device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: - assert is_field_allocator(allocator) + assert is_field_allocation_tool(allocator) buffer = allocate( domain=common.domain(domain), dtype=core_defs.dtype(dtype), # TODO check all dtypes @@ -597,7 +597,7 @@ def asarray( copy: Optional[bool] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: - assert is_field_allocator(allocator) + assert is_field_allocation_tool(allocator) if not copy: raise NotImplementedError("Zero-copy construction is not yet supported.") dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 62d88d9f0a..4ef69ec429 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -173,13 +173,13 @@ def verify_testee(offset_provider): ) mock_fast_call.assert_called_once() - if gtx.allocators.is_field_allocator_for( + if gtx._allocators.is_field_allocator_for( unstructured_case.backend.allocator, core_defs.DeviceType.CPU ): offset_provider = unstructured_case.offset_provider else: - assert gtx.allocators.is_field_allocator_for( - unstructured_case.backend.allocator, gtx.allocators.CUPY_DEVICE + assert gtx._allocators.is_field_allocator_for( + unstructured_case.backend.allocator, gtx._allocators.CUPY_DEVICE ) import cupy as cp diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index 0d91620105..1aeff1d3ec 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -152,7 +152,7 @@ def test_allocate(self): def test_allocate(): - from gt4py.next._allocators import StandardCPUFieldBufferAllocator, make_concrete_allocator + from gt4py.next._allocators import StandardCPUFieldBufferAllocator, allocate I = common.Dimension("I") J = common.Dimension("J") @@ -161,25 +161,25 @@ def test_allocate(): # Test with a explicit field allocator allocator = StandardCPUFieldBufferAllocator() - tensor_buffer = make_concrete_allocator(domain, dtype, allocator=allocator)() + tensor_buffer = allocate(domain=domain, dtype=dtype, allocator=allocator) assert tensor_buffer.shape == domain.shape assert tensor_buffer.dtype == dtype # Test with a device device = core_defs.Device(core_defs.DeviceType.CPU, 0) - tensor_buffer = make_concrete_allocator(domain, dtype, device=device)() + tensor_buffer = allocate(domain=domain, dtype=dtype, device=device) assert tensor_buffer.shape == domain.shape assert tensor_buffer.dtype == dtype # Test with both allocator and device with pytest.raises(ValueError, match="are incompatible"): - make_concrete_allocator( - domain, - dtype, + allocate( + domain=domain, + dtype=dtype, allocator=allocator, device=core_defs.Device(core_defs.DeviceType.CUDA, 0), ) # Test with no device or allocator with pytest.raises(ValueError, match="No 'device' or 'allocator' specified"): - make_concrete_allocator(domain, dtype) + allocate(domain=domain, dtype=dtype) From 3ee60b8dbe1961ca200daa3555b29c10f9547cff Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 20 Dec 2024 09:50:08 +0100 Subject: [PATCH 11/12] move device to allocation namespace construction --- src/gt4py/next/_allocators.py | 240 +++++++++--------- src/gt4py/next/constructors.py | 25 +- src/gt4py/next/embedded/nd_array_field.py | 6 +- .../unit_tests/test_constructors.py | 3 - 4 files changed, 126 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/_allocators.py b/src/gt4py/next/_allocators.py index 440aa6c0de..3b08056653 100644 --- a/src/gt4py/next/_allocators.py +++ b/src/gt4py/next/_allocators.py @@ -63,7 +63,6 @@ def empty( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -72,7 +71,6 @@ def zeros( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -81,7 +79,6 @@ def ones( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -91,7 +88,6 @@ def full( fill_value: core_defs.Scalar, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -101,7 +97,6 @@ def asarray( *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, copy: Optional[bool] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: ... @@ -362,9 +357,9 @@ def allocate( *, domain: common.Domain, dtype: core_defs.DType[core_defs.ScalarT], + allocator: FieldBufferAllocatorProtocol, + device: core_defs.Device, aligned_index: Optional[Sequence[common.NamedIndex]] = None, - allocator: Optional[FieldBufferAllocationUtil] = None, - device: Optional[core_defs.Device] = None, ) -> core_defs.MutableNDArrayObject: """ TODO: docstring @@ -390,18 +385,8 @@ def allocate( If illegal or inconsistent arguments are specified. """ - if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified.") - actual_allocator = get_allocator(allocator) - if actual_allocator is None: - assert device is not None # for mypy - actual_allocator = device_allocators[device.device_type] - elif device is None: - device = core_defs.Device(actual_allocator.__gt_device_type__, 0) - elif device.device_type != actual_allocator.__gt_device_type__: - raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") - return actual_allocator.__gt_allocate__( + return allocator.__gt_allocate__( domain=domain, dtype=dtype, device_id=device.device_id, @@ -419,11 +404,26 @@ def _check_unsupported_device_and_aligned_index( raise NotImplementedError("Device specification is not yet supported.") +def _get_actual_allocator_and_device( + allocator: Optional[FieldBufferAllocationUtil], device: Optional[core_defs.Device] +) -> tuple[FieldBufferAllocatorProtocol, core_defs.Device]: + if device is None and allocator is None: + raise ValueError("No 'device' or 'allocator' specified.") + actual_allocator = get_allocator(allocator) + if actual_allocator is None: + assert device is not None # for mypy + actual_allocator = device_allocators[device.device_type] + elif device is None: + device = core_defs.Device(actual_allocator.__gt_device_type__, 0) + elif device.device_type != actual_allocator.__gt_device_type__: + raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") + return actual_allocator, device + + def get_array_allocation_namespace( - allocator: FieldBufferAllocationUtil | core_defs.ArrayApiNamespace | None, + allocator: Optional[FieldBufferAllocationUtil | core_defs.ArrayApiNamespace], + device: Optional[core_defs.Device] = None, ) -> GTArrayAllocationNamespace: - if allocator is None: - allocator = StandardCPUFieldBufferAllocator() if core_defs.is_array_api_namespace(allocator): assert core_defs.is_array_api_namespace(allocator) array_ns = array_api_compat.array_namespace(allocator.empty([0])) @@ -434,7 +434,6 @@ def empty( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -448,7 +447,6 @@ def zeros( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -462,7 +460,6 @@ def ones( domain: common.DomainLike, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -477,7 +474,6 @@ def full( fill_value: core_defs.Scalar, *, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: _check_unsupported_device_and_aligned_index(device, aligned_index) @@ -493,7 +489,6 @@ def asarray( *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, copy: Optional[bool] = None, aligned_index: Optional[Sequence[common.NamedIndex]] = None, ) -> core_defs.NDArrayObject: @@ -509,106 +504,97 @@ def asarray( return _ArrayNamespaceWrapper - else: - - class _CustomAllocationArrayNamespace: - @staticmethod - def empty( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - return allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - - @staticmethod - def zeros( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = 0 - return buffer - - @staticmethod - def ones( - domain: common.DomainLike, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = 1 - return buffer - - @staticmethod - def full( - domain: common.DomainLike, - fill_value: core_defs.Scalar, - *, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - buffer = allocate( - domain=common.domain(domain), - dtype=core_defs.dtype(dtype), # TODO check all dtypes - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = fill_value - return buffer + assert is_field_allocation_tool(allocator) or allocator is None + actual_allocator, actual_device = _get_actual_allocator_and_device(allocator, device) + + class _CustomAllocationArrayNamespace: + @staticmethod + def empty( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + return allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) - @staticmethod - def asarray( - data: core_defs.NDArrayObject, - *, - domain: common.DomainLike, - dtype: Optional[core_defs.DTypeLike] = None, - device: Optional[core_defs.Device] = None, - copy: Optional[bool] = None, - aligned_index: Optional[Sequence[common.NamedIndex]] = None, - ) -> core_defs.NDArrayObject: - assert is_field_allocation_tool(allocator) - if not copy: - raise NotImplementedError("Zero-copy construction is not yet supported.") - dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) - buffer = allocate( - domain=common.domain(domain), - dtype=dtype, - aligned_index=aligned_index, - allocator=allocator, - device=device, - ) - buffer[...] = array_api_compat.array_namespace(buffer).asarray(data) - return buffer + @staticmethod + def zeros( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = 0 + return buffer + + @staticmethod + def ones( + domain: common.DomainLike, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = 1 + return buffer + + @staticmethod + def full( + domain: common.DomainLike, + fill_value: core_defs.Scalar, + *, + dtype: Optional[core_defs.DTypeLike] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + buffer = allocate( + domain=common.domain(domain), + dtype=core_defs.dtype(dtype), # TODO check all dtypes + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = fill_value + return buffer + + @staticmethod + def asarray( + data: core_defs.NDArrayObject, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + copy: Optional[bool] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_defs.NDArrayObject: + if not copy: + raise NotImplementedError("Zero-copy construction is not yet supported.") + dtype = core_defs.dtype(data.dtype) if dtype is None else core_defs.dtype(dtype) + buffer = allocate( + domain=common.domain(domain), + dtype=dtype, + aligned_index=aligned_index, + allocator=actual_allocator, + device=actual_device, + ) + buffer[...] = array_api_compat.array_namespace(buffer).asarray(data) + return buffer - return _CustomAllocationArrayNamespace + return _CustomAllocationArrayNamespace diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index bc3dce3acb..21b4b63636 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -77,10 +77,8 @@ def empty( >>> b.shape (3, 3) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.empty( - domain, device=device, dtype=dtype, aligned_index=aligned_index - ) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.empty(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -107,10 +105,8 @@ def zeros( >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.zeros( - domain, device=device, dtype=dtype, aligned_index=aligned_index - ) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.zeros(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -137,8 +133,8 @@ def ones( >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) - buffer = gtarray_namespace.ones(domain, device=device, dtype=dtype, aligned_index=aligned_index) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) + buffer = gtarray_namespace.ones(domain, dtype=dtype, aligned_index=aligned_index) res = common._field(buffer, domain=domain, allocation_ns=gtarray_namespace) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) @@ -171,11 +167,10 @@ def full( >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.full( domain, fill_value, - device=device, dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), aligned_index=aligned_index, ) @@ -284,12 +279,11 @@ def as_field( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.asarray( data, domain=actual_domain, dtype=dtype, - device=device, copy=True, # TODO(havogt) add support for zero-copy construction aligned_index=aligned_index, ) @@ -367,12 +361,11 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) - gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator) + gtarray_namespace = next_allocators.get_array_allocation_namespace(allocator, device) buffer = gtarray_namespace.asarray( data, domain=actual_domain, dtype=dtype, - device=device, copy=True, # TODO(havogt) add support for zero-copy construction ) connectivity_field = common._connectivity( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 182a1af1e2..06eeecb66a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -436,8 +436,10 @@ def _slice( def __copy__(self) -> NdArrayField: # Note: `copy` copies the data, following NumPy behavior - allocation_ns = self._allocation_ns or _allocators.get_array_allocation_namespace( - self.array_ns + allocation_ns = ( + self._allocation_ns + if self._allocation_ns is not None + else _allocators.get_array_allocation_namespace(self.array_ns) ) ndarray_copy = allocation_ns.asarray( self.ndarray, diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 984d87fd00..7e592a687d 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -184,9 +184,6 @@ def test_field_wrong_origin(): with pytest.raises(ValueError, match=(r"Origin keys {'J'} not in domain")): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), origin={"J": 0}) - with pytest.raises(ValueError, match=(r"Cannot specify origin for domain I")): - gtx.as_field("I", np.random.rand(sizes[J]).astype(gtx.float32), origin={"J": 0}) - @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): From 4fb8bab14abc8010f313d5af03b2cef5e5cd86c4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 20 Dec 2024 10:06:34 +0100 Subject: [PATCH 12/12] fix get_actual_allocator --- src/gt4py/next/_allocators.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/_allocators.py b/src/gt4py/next/_allocators.py index 3b08056653..184785246a 100644 --- a/src/gt4py/next/_allocators.py +++ b/src/gt4py/next/_allocators.py @@ -6,8 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# TODO: make this module private - import abc import dataclasses import functools @@ -407,13 +405,12 @@ def _check_unsupported_device_and_aligned_index( def _get_actual_allocator_and_device( allocator: Optional[FieldBufferAllocationUtil], device: Optional[core_defs.Device] ) -> tuple[FieldBufferAllocatorProtocol, core_defs.Device]: - if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified.") - actual_allocator = get_allocator(allocator) - if actual_allocator is None: - assert device is not None # for mypy - actual_allocator = device_allocators[device.device_type] - elif device is None: + if allocator is None and device is not None: + return device_allocators[device.device_type], device + + actual_allocator = get_allocator(allocator, default=device_allocators[core_defs.DeviceType.CPU]) + assert actual_allocator is not None + if device is None: device = core_defs.Device(actual_allocator.__gt_device_type__, 0) elif device.device_type != actual_allocator.__gt_device_type__: raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.")