diff --git a/.gitpod.yml b/.gitpod.yml
index 1d579d88eb..802d87796a 100644
--- a/.gitpod.yml
+++ b/.gitpod.yml
@@ -5,7 +5,7 @@ image:
tasks:
- name: Setup venv and dev tools
init: |
- ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode
+ ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode
python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools wheel
diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py
index d796189ab3..558730cb82 100644
--- a/src/gt4py/next/embedded/common.py
+++ b/src/gt4py/next/embedded/common.py
@@ -14,6 +14,10 @@
from __future__ import annotations
+import functools
+import itertools
+import operator
+
from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions
@@ -90,6 +94,19 @@ def _absolute_sub_domain(
return common.Domain(*named_ranges)
+def intersect_domains(*domains: common.Domain) -> common.Domain:
+ return functools.reduce(
+ operator.and_,
+ domains,
+ common.Domain(dims=tuple(), ranges=tuple()),
+ )
+
+
+def iterate_domain(domain: common.Domain):
+ for i in itertools.product(*[list(r) for r in domain.ranges]):
+ yield tuple(zip(domain.dims, i))
+
+
def _expand_ellipsis(
indices: common.RelativeIndexSequence, target_size: int
) -> tuple[common.IntIndex | slice, ...]:
diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py
index 5fbdbc6f25..93942a5959 100644
--- a/src/gt4py/next/embedded/context.py
+++ b/src/gt4py/next/embedded/context.py
@@ -24,7 +24,7 @@
#: Column range used in column mode (`column_axis != None`) in the current embedded iterator
#: closure execution context.
-closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range")
+closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range")
_undefined_offset_provider: common.OffsetProvider = {}
@@ -37,7 +37,7 @@
@contextlib.contextmanager
def new_context(
*,
- closure_column_range: range | eve.NothingType = eve.NOTHING,
+ closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
):
import gt4py.next.embedded.context as this_module
diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py
index ff6a2ceac7..6b69e8f8cc 100644
--- a/src/gt4py/next/embedded/nd_array_field.py
+++ b/src/gt4py/next/embedded/nd_array_field.py
@@ -16,7 +16,6 @@
import dataclasses
import functools
-import operator
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import ClassVar
@@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)
- domain_intersection = functools.reduce(
- operator.and_,
- [f.domain for f in fields if common.is_field(f)],
- common.Domain(dims=tuple(), ranges=tuple()),
+ domain_intersection = embedded_common.intersect_domains(
+ *[f.domain for f in fields if common.is_field(f)]
)
+
transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = []
for f in fields:
if common.is_field(f):
diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py
new file mode 100644
index 0000000000..f50ace7687
--- /dev/null
+++ b/src/gt4py/next/embedded/operators.py
@@ -0,0 +1,168 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import dataclasses
+from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar
+
+from gt4py import eve
+from gt4py._core import definitions as core_defs
+from gt4py.next import common, constructors, utils
+from gt4py.next.embedded import common as embedded_common, context as embedded_context
+
+
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
+
+@dataclasses.dataclass(frozen=True)
+class EmbeddedOperator(Generic[_R, _P]):
+ fun: Callable[_P, _R]
+
+ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
+ return self.fun(*args, **kwargs)
+
+
+@dataclasses.dataclass(frozen=True)
+class ScanOperator(EmbeddedOperator[_R, _P]):
+ forward: bool
+ init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]
+ axis: common.Dimension
+
+ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun
+ scan_range = embedded_context.closure_column_range.get()
+ assert self.axis == scan_range[0]
+ scan_axis = scan_range[0]
+ domain_intersection = _intersect_scan_args(*args, *kwargs.values())
+ non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])
+
+ out_domain = common.Domain(
+ *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection]
+ )
+ if scan_axis not in out_domain.dims:
+ # even if the scan dimension is not in the input, we can scan over it
+ out_domain = common.Domain(*out_domain, (scan_range))
+
+ res = _construct_scan_array(out_domain)(self.init)
+
+ def scan_loop(hpos):
+ acc = self.init
+ for k in scan_range[1] if self.forward else reversed(scan_range[1]):
+ pos = (*hpos, (scan_axis, k))
+ new_args = [_tuple_at(pos, arg) for arg in args]
+ new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()}
+ acc = self.fun(acc, *new_args, **new_kwargs)
+ _tuple_assign_value(pos, res, acc)
+
+ if len(non_scan_domain) == 0:
+ # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
+ scan_loop(())
+ else:
+ for hpos in embedded_common.iterate_domain(non_scan_domain):
+ scan_loop(hpos)
+
+ return res
+
+
+def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
+ if "out" in kwargs:
+ # called from program or direct field_operator as program
+ offset_provider = kwargs.pop("offset_provider", None)
+
+ new_context_kwargs = {}
+ if embedded_context.within_context():
+ # called from program
+ assert offset_provider is None
+ else:
+ # field_operator as program
+ new_context_kwargs["offset_provider"] = offset_provider
+
+ out = kwargs.pop("out")
+ domain = kwargs.pop("domain", None)
+
+ flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,))
+ assert all(f.domain == flattened_out[0].domain for f in flattened_out)
+
+ out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain
+
+ new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain)
+
+ with embedded_context.new_context(**new_context_kwargs) as ctx:
+ res = ctx.run(op, *args, **kwargs)
+ _tuple_assign_field(
+ out,
+ res,
+ domain=out_domain,
+ )
+ else:
+ # called from other field_operator
+ return op(*args, **kwargs)
+
+
+def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType:
+ vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL]
+ assert len(vertical_dim_filtered) <= 1
+ return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING
+
+
+def _tuple_assign_field(
+ target: tuple[common.MutableField | tuple, ...] | common.MutableField,
+ source: tuple[common.Field | tuple, ...] | common.Field,
+ domain: common.Domain,
+):
+ @utils.tree_map
+ def impl(target: common.MutableField, source: common.Field):
+ target[domain] = source[domain]
+
+ impl(target, source)
+
+
+def _intersect_scan_args(
+ *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
+) -> common.Domain:
+ return embedded_common.intersect_domains(
+ *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)]
+ )
+
+
+def _construct_scan_array(domain: common.Domain):
+ @utils.tree_map
+ def impl(init: core_defs.Scalar) -> common.Field:
+ return constructors.empty(domain, dtype=type(init))
+
+ return impl
+
+
+def _tuple_assign_value(
+ pos: Sequence[common.NamedIndex],
+ target: common.MutableField | tuple[common.MutableField | tuple, ...],
+ source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...],
+) -> None:
+ @utils.tree_map
+ def impl(target: common.MutableField, source: core_defs.Scalar):
+ target[pos] = source
+
+ impl(target, source)
+
+
+def _tuple_at(
+ pos: Sequence[common.NamedIndex],
+ field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...],
+) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]:
+ @utils.tree_map
+ def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
+ res = field[pos] if common.is_field(field) else field
+ assert core_defs.is_scalar_type(res)
+ return res
+
+ return impl(field) # type: ignore[return-value]
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 67272f88b8..a8956716ec 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -32,8 +32,9 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
-from gt4py.next import allocators as next_allocators, common, embedded as next_embedded
+from gt4py.next import allocators as next_allocators, embedded as next_embedded
from gt4py.next.common import Dimension, DimensionKind, GridType
+from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
dialect_ast_enums,
field_operator_ast as foast,
@@ -545,6 +546,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None
+ operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(default_factory=dict)
@classmethod
@@ -581,6 +583,7 @@ def from_function(
definition=definition,
backend=backend,
grid_type=grid_type,
+ operator_attributes=operator_attributes,
)
def __gt_type__(self) -> ts.CallableType:
@@ -687,68 +690,38 @@ def __call__(
*args,
**kwargs,
) -> None:
- # TODO(havogt): Don't select mode based on existence of kwargs,
- # because now we cannot provide nice error messages. E.g. set context var
- # if we are reaching this from a program call.
- if "out" in kwargs:
- out = kwargs.pop("out")
+ if not next_embedded.context.within_context() and self.backend is not None:
+ # non embedded execution
offset_provider = kwargs.pop("offset_provider", None)
- if self.backend is not None:
- # "out" and "offset_provider" -> field_operator as program
- # When backend is None, we are in embedded execution and for now
- # we disable the program generation since it would involve generating
- # Python source code from a PAST node.
- args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
- # TODO(tehrengruber): check all offset providers are given
- # deduce argument types
- arg_types = []
- for arg in args:
- arg_types.append(type_translation.from_value(arg))
- kwarg_types = {}
- for name, arg in kwargs.items():
- kwarg_types[name] = type_translation.from_value(arg)
-
- return self.as_program(arg_types, kwarg_types)(
- *args, out, offset_provider=offset_provider, **kwargs
- )
- else:
- # "out" -> field_operator called from program in embedded execution or
- # field_operator called directly from Python in embedded execution
- domain = kwargs.pop("domain", None)
- if not next_embedded.context.within_context():
- # field_operator from Python in embedded execution
- with next_embedded.context.new_context(offset_provider=offset_provider) as ctx:
- res = ctx.run(self.definition, *args, **kwargs)
- else:
- # field_operator from program in embedded execution (offset_provicer is already set)
- assert (
- offset_provider is None
- or next_embedded.context.offset_provider.get() is offset_provider
- )
- res = self.definition(*args, **kwargs)
- _tuple_assign_field(
- out, res, domain=None if domain is None else common.domain(domain)
- )
- return
+ out = kwargs.pop("out")
+ args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
+ # TODO(tehrengruber): check all offset providers are given
+ # deduce argument types
+ arg_types = []
+ for arg in args:
+ arg_types.append(type_translation.from_value(arg))
+ kwarg_types = {}
+ for name, arg in kwargs.items():
+ kwarg_types[name] = type_translation.from_value(arg)
+
+ return self.as_program(arg_types, kwarg_types)(
+ *args, out, offset_provider=offset_provider, **kwargs
+ )
else:
- # field_operator called from other field_operator in embedded execution
- assert self.backend is None
- return self.definition(*args, **kwargs)
-
-
-def _tuple_assign_field(
- target: tuple[common.Field | tuple, ...] | common.Field,
- source: tuple[common.Field | tuple, ...] | common.Field,
- domain: Optional[common.Domain],
-):
- if isinstance(target, tuple):
- if not isinstance(source, tuple):
- raise RuntimeError(f"Cannot assign {source} to {target}.")
- for t, s in zip(target, source):
- _tuple_assign_field(t, s, domain)
- else:
- domain = domain or target.domain
- target[domain] = source[domain]
+ if self.operator_attributes is not None and any(
+ has_scan_op_attribute := [
+ attribute in self.operator_attributes
+ for attribute in ["init", "axis", "forward"]
+ ]
+ ):
+ assert all(has_scan_op_attribute)
+ forward = self.operator_attributes["forward"]
+ init = self.operator_attributes["init"]
+ axis = self.operator_attributes["axis"]
+ op = embedded_operators.ScanOperator(self.definition, forward, init, axis)
+ else:
+ op = embedded_operators.EmbeddedOperator(self.definition)
+ return embedded_operators.field_operator_call(op, args, kwargs)
@typing.overload
diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py
new file mode 100644
index 0000000000..14b7c3838c
--- /dev/null
+++ b/src/gt4py/next/field_utils.py
@@ -0,0 +1,22 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+
+from gt4py.next import common, utils
+
+
+@utils.tree_map
+def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
+ return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index b02d6c8d72..b00e53bfd9 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -196,7 +196,7 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None:
#: Column range used in column mode (`column_axis != None`) in the current closure execution context.
-column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range
+column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range
#: Offset provider dict in the current closure execution context.
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider
@@ -211,8 +211,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673
- column_range = column_range_cvar.get()
- self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data)
+ column_range: common.NamedRange = column_range_cvar.get()
+ self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data)
def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
@@ -746,7 +746,7 @@ def _make_tuple(
except embedded_exceptions.IndexOutOfBounds:
return _UNDEFINED
else:
- column_range = column_range_cvar.get()
+ column_range = column_range_cvar.get()[1]
assert column_range is not None
col: list[
@@ -823,7 +823,7 @@ def deref(self) -> Any:
assert isinstance(k_pos, int)
# the following range describes a range in the field
# (negative values are relative to the origin, not relative to the size)
- slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range))
+ slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1]))
assert _is_concrete_position(shifted_pos)
position = {**shifted_pos, **slice_column}
@@ -864,7 +864,7 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
- column_range = column_range_cvar.get()
+ column_range = column_range_cvar.get()[1]
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert column_range is not None
new_pos[column_axis] = column_range.start
@@ -1479,7 +1479,7 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
- column_range = column_range_cvar.get()
+ column_range = column_range_cvar.get()[1]
if column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")
@@ -1532,7 +1532,10 @@ def closure(
column = ColumnDescriptor(column_axis.value, domain[column_axis.value])
del domain[column_axis.value]
- column_range = column.col_range
+ column_range = (
+ column_axis,
+ common.UnitRange(column.col_range.start, column.col_range.stop),
+ )
out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out)
diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py
index baae8361c5..ec459906e0 100644
--- a/src/gt4py/next/utils.py
+++ b/src/gt4py/next/utils.py
@@ -15,10 +15,6 @@
import functools
from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast
-import numpy as np
-
-from gt4py.next import common
-
class RecursionGuard:
"""
@@ -57,7 +53,6 @@ def __exit__(self, *exc):
_T = TypeVar("_T")
-
_P = ParamSpec("_P")
_R = TypeVar("_R")
@@ -66,8 +61,17 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]:
return isinstance(v, tuple) and all(isinstance(e, t) for e in v)
+# TODO(havogt): remove flatten duplications in the whole codebase
+def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]:
+ if isinstance(value, tuple):
+ return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting
+ else:
+ return (value,)
+
+
def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]:
- """Apply `fun` to each entry of (possibly nested) tuples.
+ """
+ Apply `fun` to each entry of (possibly nested) tuples.
Examples:
>>> tree_map(lambda x: x + 1)(((1, 2), 3))
@@ -88,9 +92,3 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
) # mypy doesn't understand that `args` at this point is of type `_P.args`
return impl
-
-
-# TODO(havogt): consider moving to module like `field_utils`
-@tree_map
-def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
- return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition
diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py
index a6a302e143..6d0b7d3c10 100644
--- a/tests/next_tests/exclusion_matrices.py
+++ b/tests/next_tests/exclusion_matrices.py
@@ -133,7 +133,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
]
EMBEDDED_SKIP_LIST = [
- (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
(CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE),
]
diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py
index 81f216397b..b1e26b40cb 100644
--- a/tests/next_tests/integration_tests/cases.py
+++ b/tests/next_tests/integration_tests/cases.py
@@ -28,7 +28,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import Self
-from gt4py.next import common, constructors, utils
+from gt4py.next import common, constructors, field_utils
from gt4py.next.ffront import decorator
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -436,8 +436,8 @@ def verify(
out_comp = out or inout
assert out_comp is not None
- out_comp_ndarray = utils.asnumpy(out_comp)
- ref_ndarray = utils.asnumpy(ref)
+ out_comp_ndarray = field_utils.asnumpy(out_comp)
+ ref_ndarray = field_utils.asnumpy(ref)
assert comparison(ref_ndarray, out_comp_ndarray), (
f"Verification failed:\n"
f"\tcomparison={comparison.__name__}(ref, out)\n"
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index cf273a4524..abb23611e0 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -292,6 +292,7 @@ def testee_op(
cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected)
+@pytest.mark.uses_cartesian_shift
@pytest.mark.uses_scan
@pytest.mark.uses_index_fields
def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures
@@ -801,6 +802,85 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float:
cases.verify(cartesian_case, simple_scan_operator, (inp1, inp2), out=out, ref=expected)
+@pytest.mark.uses_scan
+def test_scan_different_domain_in_tuple(cartesian_case):
+ init = 1.0
+ i_size = cartesian_case.default_sizes[IDim]
+ k_size = cartesian_case.default_sizes[KDim]
+
+ inp1_np = np.ones(
+ (
+ i_size + 1,
+ k_size,
+ )
+ ) # i_size bigger than in the other argument
+ inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float)
+ inp1 = cartesian_case.as_field([IDim, KDim], inp1_np)
+ inp2 = cartesian_case.as_field([IDim, KDim], inp2_np)
+ out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size)))
+
+ def prev_levels_iterator(i):
+ return range(i + 1)
+
+ expected = np.asarray(
+ [
+ reduce(
+ lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k],
+ prev_levels_iterator(k),
+ init,
+ )
+ for k in range(k_size)
+ ]
+ ).transpose()
+
+ @gtx.scan_operator(axis=KDim, forward=True, init=init)
+ def scan_op(carry: float, a: tuple[float, float]) -> float:
+ return carry + a[0] + a[1]
+
+ @gtx.field_operator
+ def foo(
+ inp1: gtx.Field[[IDim, KDim], float], inp2: gtx.Field[[IDim, KDim], float]
+ ) -> gtx.Field[[IDim, KDim], float]:
+ return scan_op((inp1, inp2))
+
+ cases.verify(cartesian_case, foo, inp1, inp2, out=out, ref=expected)
+
+
+@pytest.mark.uses_scan
+def test_scan_tuple_field_scalar_mixed(cartesian_case):
+ init = 1.0
+ i_size = cartesian_case.default_sizes[IDim]
+ k_size = cartesian_case.default_sizes[KDim]
+
+ inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float)
+ inp2 = cartesian_case.as_field([IDim, KDim], inp2_np)
+ out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size)))
+
+ def prev_levels_iterator(i):
+ return range(i + 1)
+
+ expected = np.asarray(
+ [
+ reduce(
+ lambda prev, k: prev + 1.0 + inp2_np[:, k],
+ prev_levels_iterator(k),
+ init,
+ )
+ for k in range(k_size)
+ ]
+ ).transpose()
+
+ @gtx.scan_operator(axis=KDim, forward=True, init=init)
+ def scan_op(carry: float, a: tuple[float, float]) -> float:
+ return carry + a[0] + a[1]
+
+ @gtx.field_operator
+ def foo(inp1: float, inp2: gtx.Field[[IDim, KDim], float]) -> gtx.Field[[IDim, KDim], float]:
+ return scan_op((inp1, inp2))
+
+ cases.verify(cartesian_case, foo, 1.0, inp2, out=out, ref=expected)
+
+
def test_docstring(cartesian_case):
@gtx.field_operator
def fieldop_with_docstring(a: cases.IField) -> cases.IField:
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
index fd571514ac..9ba8eef3a3 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
@@ -16,7 +16,7 @@
import pytest
import gt4py.next as gtx
-from gt4py.next import utils
+from gt4py.next import field_utils
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
@@ -158,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct
k_size = 5
inp = inp_function(k_size)
- ref = ref_function(utils.asnumpy(inp))
+ ref = ref_function(field_utils.asnumpy(inp))
out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32))
diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py
index 640ed326bb..de511fdabb 100644
--- a/tests/next_tests/unit_tests/embedded_tests/test_common.py
+++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py
@@ -19,7 +19,7 @@
from gt4py.next import common
from gt4py.next.common import UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions
-from gt4py.next.embedded.common import _slice_range, sub_domain
+from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain
@pytest.mark.parametrize(
@@ -135,3 +135,15 @@ def test_sub_domain(domain, index, expected):
expected = common.domain(expected)
result = sub_domain(domain, index)
assert result == expected
+
+
+def test_iterate_domain():
+ domain = common.domain({I: 2, J: 3})
+ ref = []
+ for i in domain[I][1]:
+ for j in domain[J][1]:
+ ref.append(((I, i), (J, j)))
+
+ testee = list(iterate_domain(domain))
+
+ assert testee == ref
diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py
index 3a35570ca2..9238cd4f7a 100644
--- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py
+++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py
@@ -19,13 +19,14 @@
import numpy as np
import pytest
+from gt4py.next import common
from gt4py.next.iterator import embedded
def _run_within_context(
func: Callable[[], Any],
*,
- column_range: Optional[range] = None,
+ column_range: Optional[common.NamedRange] = None,
offset_provider: Optional[embedded.OffsetProvider] = None,
) -> Any:
def wrapped_func():
@@ -59,7 +60,10 @@ def test_func(data_a: int, data_b: int):
# Setting an invalid column_range here shouldn't affect other contexts
embedded.column_range_cvar.set(range(2, 999))
- _run_within_context(lambda: test_func(2, 3), column_range=range(0, 3))
+ _run_within_context(
+ lambda: test_func(2, 3),
+ column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)),
+ )
def test_column_ufunc_with_scalar():