Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: prototype minimal premap-relocation (staggering) #1617

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,12 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
# for cartesian axes (e.g. I, J) the index of the subscript only
# signifies the displacement in the respective dimension,
# but does not change the target type.
if source != target:
raise errors.DSLError(
new_value.location,
"Source and target must be equal for offsets with a single target.",
)
new_type = new_value.type
# if source != target:
# raise errors.DSLError(
# new_value.location,
# "Source and target must be equal for offsets with a single target.",
# )
new_type = ts.OffsetType(source=source, target=(target,))
case _:
raise errors.DSLError(
new_value.location, "Could not deduce type of subscript expression."
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _deduce_grid_type(
"""

def is_cartesian_offset(o: fbuiltins.FieldOffset) -> bool:
return len(o.target) == 1 and o.source == o.target[0]
return len(o.target) == 1 # and o.source == o.target[0]

deduced_grid_type = common.GridType.CARTESIAN
for o in offsets_and_dimensions:
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,15 @@ def execute_shift(
else:
raise AssertionError()
return new_pos
elif isinstance(offset_implementation, tuple):
source, target = offset_implementation
new_pos = copy.copy(pos)
new_pos.pop(target.value)
if common.is_int_index(value := pos[target.value]):
new_pos[source.value] = value + index
else:
raise AssertionError()
return new_pos
else:
assert isinstance(offset_implementation, common.Connectivity)
assert offset_implementation.origin_axis.value in pos
Expand Down
11 changes: 11 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType:
@_register_builtin_type_synthesizer
def deref(it: it_ts.IteratorType) -> ts.DataType:
assert isinstance(it, it_ts.IteratorType)
print(it)
assert _is_derefable_iterator_type(it)
return it.element_type

Expand Down Expand Up @@ -339,6 +340,7 @@ def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType:
assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance(
offset_axis.value, common.Dimension
)
print(offset_provider)
provider = offset_provider[offset_axis.value.value] # TODO: naming
if isinstance(provider, common.Dimension):
pass
Expand All @@ -350,6 +352,15 @@ def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType:
new_position_dims[i] = provider.neighbor_axis
found = True
assert found
elif isinstance(provider, tuple):
source, target = provider
found = False
for i, dim in enumerate(new_position_dims):
if dim.value == target.value:
assert not found
new_position_dims[i] = source
found = True
assert found
else:
raise NotImplementedError()
return it_ts.IteratorType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def _process_connectivity_args(
)
elif isinstance(connectivity, Dimension):
pass
elif isinstance(connectivity, tuple):
assert all(isinstance(dim, Dimension) for dim in connectivity)
else:
raise AssertionError(
f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', "
Expand Down
121 changes: 99 additions & 22 deletions src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from gt4py.eve import utils as eve_utils
from gt4py.eve.concepts import SymbolName
from gt4py.next import common
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import (
Expand Down Expand Up @@ -144,24 +145,37 @@ def _collect_offset_definitions(
offset_definitions = {}

for offset_name, dim_or_connectivity in offset_provider.items():
if isinstance(dim_or_connectivity, common.Dimension):
dim: common.Dimension = dim_or_connectivity
if grid_type == common.GridType.CARTESIAN:
# create alias from offset to dimension
offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value))
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
else:
assert grid_type == common.GridType.UNSTRUCTURED
if not dim.kind == common.DimensionKind.VERTICAL:
raise ValueError(
"Mapping an offset to a horizontal dimension in unstructured is not allowed."
)
# create alias from vertical offset to vertical dimension
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
if isinstance(dim_or_connectivity, (common.Dimension, tuple)):
if isinstance(dim_or_connectivity, common.Dimension):
dim_or_connectivity = (dim_or_connectivity,)
for dim in dim_or_connectivity:
assert isinstance(dim, common.Dimension)
if grid_type == common.GridType.CARTESIAN:
# create alias from offset to dimension
if offset_name in offset_definitions:
offset_definitions[dim.value] = TagDefinition(
name=Sym(id=dim.value), alias=SymRef(id=offset_name)
)
else:
offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value))
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
else:
assert grid_type == common.GridType.UNSTRUCTURED
if not dim.kind == common.DimensionKind.VERTICAL:
raise ValueError(
"Mapping an offset to a horizontal dimension in unstructured is not allowed."
)
if offset_name in offset_definitions:
offset_definitions[dim.value] = TagDefinition(
name=Sym(id=dim.value), alias=SymRef(id=offset_name)
)
else:
# create alias from vertical offset to vertical dimension
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
elif isinstance(dim_or_connectivity, common.Connectivity):
assert grid_type == common.GridType.UNSTRUCTURED
offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name))
Expand Down Expand Up @@ -229,6 +243,67 @@ def apply(
return cls().visit(node)


@dataclasses.dataclass
class ResolveRelocation(eve.NodeTranslator):
offset_remapping: dict[str, common.Dimension]
dims_remapping: dict[common.Dimension, common.Dimension]

@classmethod
def apply(
cls,
node: itir.Program,
*,
offset_provider: dict,
) -> Program:
offset_remapping = {}
dims_remapping = {}
offset_provider = offset_provider.copy()
for v in [*offset_provider.values()]:
if isinstance(v, fbuiltins.FieldOffset):
assert len(v.target) == 1
if v.target[0] in dims_remapping.values():
assert v.source in dims_remapping
else:
dims_remapping[v.source] = v.target[0]
offset_remapping[v.value] = dims_remapping.get(v.target[0], v.target[0])
offset_provider[v.source.value] = common.Dimension(
v.target[0].value
) # super ugly hack: this introduces in gtfn an alias from the original dimension to the mapped one
offset_provider[v.value] = v.target[0]
return cls(offset_remapping=offset_remapping, dims_remapping=dims_remapping).visit(
node
), offset_provider

def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> itir.OffsetLiteral:
# if isinstance(node.value, str):
# return itir.OffsetLiteral(value=self.offset_remapping.get(node.value, node).value)
return node

def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall:
if node.fun == itir.SymRef(id="named_range"):
return itir.FunCall(
fun=node.fun,
args=[
self.dims_remapping.get(a, a) if isinstance(a, common.Dimension) else a
for a in node.args
],
)
return self.generic_visit(node)

def visit_Node(self, node: itir.Node) -> itir.Node:
if node.type is not None:
if isinstance(node.type, ts.FieldType) and any(
d in self.dims_remapping for d in type_info.extract_dims(node.type)
):
node.type = ts.FieldType(
dims=[self.dims_remapping.get(d, d) for d in node.type.dims],
dtype=node.type.dtype,
)
print(node.type)
return node
return self.generic_visit(node)


@dataclasses.dataclass(frozen=True)
class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
_binary_op_map: ClassVar[dict[str, str]] = {
Expand Down Expand Up @@ -270,6 +345,8 @@ def apply(
if not isinstance(node, itir.Program):
raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.")

node, offset_provider = ResolveRelocation.apply(node, offset_provider=offset_provider)
print(node)
node = itir_type_inference.infer(node, offset_provider=offset_provider)
grid_type = _get_gridtype(node.body)
if grid_type == common.GridType.UNSTRUCTURED:
Expand Down Expand Up @@ -559,10 +636,10 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program:
executions = self.visit(node.body, extracted_functions=extracted_functions)
executions = self._merge_scans(executions)
function_definitions = self.visit(node.function_definitions) + extracted_functions
offset_definitions = {
**_collect_dimensions_from_domain(node.body),
**_collect_offset_definitions(node, self.grid_type, self.offset_provider),
}
offset_definitions = _collect_offset_definitions(node, self.grid_type, self.offset_provider)
for k, v in _collect_dimensions_from_domain(node.body).items():
if k not in offset_definitions:
offset_definitions[k] = v
return Program(
id=SymbolName(node.id),
params=self.visit(node.params),
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def extract_connectivity_args(
# copying to device here is a fallback for easy testing and might be removed later
conn_arg = _ensure_is_on_device(conn.table, device)
args.append((conn_arg, tuple([0] * 2)))
elif isinstance(conn, common.Dimension):
elif isinstance(conn, (common.Dimension, tuple)):
pass
else:
raise AssertionError(
Expand Down
2 changes: 2 additions & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,11 @@ def allocate(
Useful for shifted fields, which must start off bigger
than the output field in the shifted dimension.
"""
print(sizes)
sizes = extend_sizes(
case.default_sizes | (sizes or {}), extend
) # TODO: this should take into account the Domain of the allocated field
print(sizes)
arg_type = get_param_types(fieldview_prog)[name]
if strategy is None:
if name in ["out", RETURN]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
E2V,
V2E,
E2VDim,
Edge,
IDim,
Ioff,
JDim,
KDim,
Koff,
V2EDim,
Vertex,
Edge,
cartesian_case,
unstructured_case,
)
Expand Down Expand Up @@ -85,6 +85,30 @@ def testee(a: cases.IJKField) -> cases.IJKField:
cases.verify(cartesian_case, testee, a, out=out, ref=a[1:])


@pytest.mark.uses_cartesian_shift
def test_cartesian_shift_staggering(cartesian_case):
IHalfDim = gtx.Dimension("IHalfDim")
IPlusHalf = gtx.FieldOffset("IPlusHalf", source=IDim, target=(IHalfDim,))

@gtx.field_operator
def testee(a: cases.IJKField) -> gtx.Field[[IHalfDim, JDim, KDim], np.int32]:
return a(IPlusHalf[0])

a = cases.allocate(cartesian_case, testee, "a")()
out = cases.allocate(
cartesian_case, testee, cases.RETURN, sizes={IHalfDim: cartesian_case.default_sizes[IDim]}
)()

cases.verify(
cartesian_case,
testee,
a,
out=out,
ref=a,
offset_provider={"IPlusHalf": (IDim, IHalfDim)}, # source to target
)


@pytest.mark.uses_unstructured_shift
def test_unstructured_shift(unstructured_case):
@gtx.field_operator
Expand Down
Loading