diff --git a/.gitpod.yml b/.gitpod.yml index 1d579d88eb..802d87796a 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -5,7 +5,7 @@ image: tasks: - name: Setup venv and dev tools init: | - ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode + ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode python -m venv .venv source .venv/bin/activate pip install --upgrade pip setuptools wheel diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d796189ab3..558730cb82 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,6 +14,10 @@ from __future__ import annotations +import functools +import itertools +import operator + from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -90,6 +94,19 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) +def intersect_domains(*domains: common.Domain) -> common.Domain: + return functools.reduce( + operator.and_, + domains, + common.Domain(dims=tuple(), ranges=tuple()), + ) + + +def iterate_domain(domain: common.Domain): + for i in itertools.product(*[list(r) for r in domain.ranges]): + yield tuple(zip(domain.dims, i)) + + def _expand_ellipsis( indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 5fbdbc6f25..93942a5959 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -24,7 +24,7 @@ #: Column range used in column mode (`column_axis != None`) in the current embedded iterator #: closure execution context. -closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") +closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range") _undefined_offset_provider: common.OffsetProvider = {} @@ -37,7 +37,7 @@ @contextlib.contextmanager def new_context( *, - closure_column_range: range | eve.NothingType = eve.NOTHING, + closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING, offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, ): import gt4py.next.embedded.context as this_module diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ff6a2ceac7..6b69e8f8cc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,7 +16,6 @@ import dataclasses import functools -import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import ClassVar @@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = functools.reduce( - operator.and_, - [f.domain for f in fields if common.is_field(f)], - common.Domain(dims=tuple(), ranges=tuple()), + domain_intersection = embedded_common.intersect_domains( + *[f.domain for f in fields if common.is_field(f)] ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] for f in fields: if common.is_field(f): diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py new file mode 100644 index 0000000000..f50ace7687 --- /dev/null +++ b/src/gt4py/next/embedded/operators.py @@ -0,0 +1,168 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar + +from gt4py import eve +from gt4py._core import definitions as core_defs +from gt4py.next import common, constructors, utils +from gt4py.next.embedded import common as embedded_common, context as embedded_context + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@dataclasses.dataclass(frozen=True) +class EmbeddedOperator(Generic[_R, _P]): + fun: Callable[_P, _R] + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + return self.fun(*args, **kwargs) + + +@dataclasses.dataclass(frozen=True) +class ScanOperator(EmbeddedOperator[_R, _P]): + forward: bool + init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] + axis: common.Dimension + + def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + scan_range = embedded_context.closure_column_range.get() + assert self.axis == scan_range[0] + scan_axis = scan_range[0] + domain_intersection = _intersect_scan_args(*args, *kwargs.values()) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + + out_domain = common.Domain( + *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] + ) + if scan_axis not in out_domain.dims: + # even if the scan dimension is not in the input, we can scan over it + out_domain = common.Domain(*out_domain, (scan_range)) + + res = _construct_scan_array(out_domain)(self.init) + + def scan_loop(hpos): + acc = self.init + for k in scan_range[1] if self.forward else reversed(scan_range[1]): + pos = (*hpos, (scan_axis, k)) + new_args = [_tuple_at(pos, arg) for arg in args] + new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} + acc = self.fun(acc, *new_args, **new_kwargs) + _tuple_assign_value(pos, res, acc) + + if len(non_scan_domain) == 0: + # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop + scan_loop(()) + else: + for hpos in embedded_common.iterate_domain(non_scan_domain): + scan_loop(hpos) + + return res + + +def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): + if "out" in kwargs: + # called from program or direct field_operator as program + offset_provider = kwargs.pop("offset_provider", None) + + new_context_kwargs = {} + if embedded_context.within_context(): + # called from program + assert offset_provider is None + else: + # field_operator as program + new_context_kwargs["offset_provider"] = offset_provider + + out = kwargs.pop("out") + domain = kwargs.pop("domain", None) + + flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) + assert all(f.domain == flattened_out[0].domain for f in flattened_out) + + out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + + new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) + + with embedded_context.new_context(**new_context_kwargs) as ctx: + res = ctx.run(op, *args, **kwargs) + _tuple_assign_field( + out, + res, + domain=out_domain, + ) + else: + # called from other field_operator + return op(*args, **kwargs) + + +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING + + +def _tuple_assign_field( + target: tuple[common.MutableField | tuple, ...] | common.MutableField, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: common.Domain, +): + @utils.tree_map + def impl(target: common.MutableField, source: common.Field): + target[domain] = source[domain] + + impl(target, source) + + +def _intersect_scan_args( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] +) -> common.Domain: + return embedded_common.intersect_domains( + *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] + ) + + +def _construct_scan_array(domain: common.Domain): + @utils.tree_map + def impl(init: core_defs.Scalar) -> common.Field: + return constructors.empty(domain, dtype=type(init)) + + return impl + + +def _tuple_assign_value( + pos: Sequence[common.NamedIndex], + target: common.MutableField | tuple[common.MutableField | tuple, ...], + source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], +) -> None: + @utils.tree_map + def impl(target: common.MutableField, source: core_defs.Scalar): + target[pos] = source + + impl(target, source) + + +def _tuple_at( + pos: Sequence[common.NamedIndex], + field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...], +) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: + @utils.tree_map + def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: + res = field[pos] if common.is_field(field) else field + assert core_defs.is_scalar_type(res) + return res + + return impl(field) # type: ignore[return-value] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 67272f88b8..a8956716ec 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,8 +32,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -545,6 +546,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod @@ -581,6 +583,7 @@ def from_function( definition=definition, backend=backend, grid_type=grid_type, + operator_attributes=operator_attributes, ) def __gt_type__(self) -> ts.CallableType: @@ -687,68 +690,38 @@ def __call__( *args, **kwargs, ) -> None: - # TODO(havogt): Don't select mode based on existence of kwargs, - # because now we cannot provide nice error messages. E.g. set context var - # if we are reaching this from a program call. - if "out" in kwargs: - out = kwargs.pop("out") + if not next_embedded.context.within_context() and self.backend is not None: + # non embedded execution offset_provider = kwargs.pop("offset_provider", None) - if self.backend is not None: - # "out" and "offset_provider" -> field_operator as program - # When backend is None, we are in embedded execution and for now - # we disable the program generation since it would involve generating - # Python source code from a PAST node. - args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs - ) - else: - # "out" -> field_operator called from program in embedded execution or - # field_operator called directly from Python in embedded execution - domain = kwargs.pop("domain", None) - if not next_embedded.context.within_context(): - # field_operator from Python in embedded execution - with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: - res = ctx.run(self.definition, *args, **kwargs) - else: - # field_operator from program in embedded execution (offset_provicer is already set) - assert ( - offset_provider is None - or next_embedded.context.offset_provider.get() is offset_provider - ) - res = self.definition(*args, **kwargs) - _tuple_assign_field( - out, res, domain=None if domain is None else common.domain(domain) - ) - return + out = kwargs.pop("out") + args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + return self.as_program(arg_types, kwarg_types)( + *args, out, offset_provider=offset_provider, **kwargs + ) else: - # field_operator called from other field_operator in embedded execution - assert self.backend is None - return self.definition(*args, **kwargs) - - -def _tuple_assign_field( - target: tuple[common.Field | tuple, ...] | common.Field, - source: tuple[common.Field | tuple, ...] | common.Field, - domain: Optional[common.Domain], -): - if isinstance(target, tuple): - if not isinstance(source, tuple): - raise RuntimeError(f"Cannot assign {source} to {target}.") - for t, s in zip(target, source): - _tuple_assign_field(t, s, domain) - else: - domain = domain or target.domain - target[domain] = source[domain] + if self.operator_attributes is not None and any( + has_scan_op_attribute := [ + attribute in self.operator_attributes + for attribute in ["init", "axis", "forward"] + ] + ): + assert all(has_scan_op_attribute) + forward = self.operator_attributes["forward"] + init = self.operator_attributes["init"] + axis = self.operator_attributes["axis"] + op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + else: + op = embedded_operators.EmbeddedOperator(self.definition) + return embedded_operators.field_operator_call(op, args, kwargs) @typing.overload diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py new file mode 100644 index 0000000000..14b7c3838c --- /dev/null +++ b/src/gt4py/next/field_utils.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +from gt4py.next import common, utils + + +@utils.tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b02d6c8d72..b00e53bfd9 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -196,7 +196,7 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range +column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider @@ -211,8 +211,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 - column_range = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data) + column_range: common.NamedRange = column_range_cvar.get() + self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +746,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] assert column_range is not None col: list[ @@ -823,7 +823,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +864,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1479,7 +1479,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1532,7 +1532,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = column.col_range + column_range = ( + column_axis, + common.UnitRange(column.col_range.start, column.col_range.stop), + ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index baae8361c5..ec459906e0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -15,10 +15,6 @@ import functools from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast -import numpy as np - -from gt4py.next import common - class RecursionGuard: """ @@ -57,7 +53,6 @@ def __exit__(self, *exc): _T = TypeVar("_T") - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -66,8 +61,17 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) +# TODO(havogt): remove flatten duplications in the whole codebase +def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: + if isinstance(value, tuple): + return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting + else: + return (value,) + + def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: - """Apply `fun` to each entry of (possibly nested) tuples. + """ + Apply `fun` to each entry of (possibly nested) tuples. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -88,9 +92,3 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - - -# TODO(havogt): consider moving to module like `field_utils` -@tree_map -def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: - return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a6a302e143..6d0b7d3c10 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -133,7 +133,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 81f216397b..b1e26b40cb 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, field_utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -436,8 +436,8 @@ def verify( out_comp = out or inout assert out_comp is not None - out_comp_ndarray = utils.asnumpy(out_comp) - ref_ndarray = utils.asnumpy(ref) + out_comp_ndarray = field_utils.asnumpy(out_comp) + ref_ndarray = field_utils.asnumpy(ref) assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index cf273a4524..abb23611e0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -292,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_cartesian_shift @pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @@ -801,6 +802,85 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: cases.verify(cartesian_case, simple_scan_operator, (inp1, inp2), out=out, ref=expected) +@pytest.mark.uses_scan +def test_scan_different_domain_in_tuple(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp1_np = np.ones( + ( + i_size + 1, + k_size, + ) + ) # i_size bigger than in the other argument + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo( + inp1: gtx.Field[[IDim, KDim], float], inp2: gtx.Field[[IDim, KDim], float] + ) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, inp1, inp2, out=out, ref=expected) + + +@pytest.mark.uses_scan +def test_scan_tuple_field_scalar_mixed(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + 1.0 + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo(inp1: float, inp2: gtx.Field[[IDim, KDim], float]) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, 1.0, inp2, out=out, ref=expected) + + def test_docstring(cartesian_case): @gtx.field_operator def fieldop_with_docstring(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index fd571514ac..9ba8eef3a3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import utils +from gt4py.next import field_utils from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -158,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(utils.asnumpy(inp)) + ref = ref_function(field_utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 640ed326bb..de511fdabb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,7 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, sub_domain +from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain @pytest.mark.parametrize( @@ -135,3 +135,15 @@ def test_sub_domain(domain, index, expected): expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected + + +def test_iterate_domain(): + domain = common.domain({I: 2, J: 3}) + ref = [] + for i in domain[I][1]: + for j in domain[J][1]: + ref.append(((I, i), (J, j))) + + testee = list(iterate_domain(domain)) + + assert testee == ref diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 3a35570ca2..9238cd4f7a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -19,13 +19,14 @@ import numpy as np import pytest +from gt4py.next import common from gt4py.next.iterator import embedded def _run_within_context( func: Callable[[], Any], *, - column_range: Optional[range] = None, + column_range: Optional[common.NamedRange] = None, offset_provider: Optional[embedded.OffsetProvider] = None, ) -> Any: def wrapped_func(): @@ -59,7 +60,10 @@ def test_func(data_a: int, data_b: int): # Setting an invalid column_range here shouldn't affect other contexts embedded.column_range_cvar.set(range(2, 999)) - _run_within_context(lambda: test_func(2, 3), column_range=range(0, 3)) + _run_within_context( + lambda: test_func(2, 3), + column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + ) def test_column_ufunc_with_scalar():