From b1f9c9a567e01c14e7236b41fa95f09ae1bb3e2a Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Dec 2023 09:15:16 +0100 Subject: [PATCH] feat[next][dace]: Support for sparse fields and reductions over lift expressions (#1377) This PR adds support to DaCe backend for sparse fields and reductions over lift expressions. --- .../runners/dace_iterator/itir_to_sdfg.py | 6 +- .../runners/dace_iterator/itir_to_tasklet.py | 181 +++++++++++++----- .../runners/dace_iterator/utility.py | 7 + tests/next_tests/exclusion_matrices.py | 3 - .../ffront_tests/test_execution.py | 1 + 5 files changed, 142 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 94878fd46d..271a79c04b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -150,7 +150,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) - last_state = program_sdfg.add_state("program_entry") + last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) # Filter neighbor tables from offset providers. @@ -216,7 +216,7 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") + closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) @@ -423,7 +423,7 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start") + start_state = scan_sdfg.add_state("start", True) lambda_state = scan_sdfg.add_state("lambda_compute") end_state = scan_sdfg.add_state("end") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index da54f9be14..de18446bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,6 +18,7 @@ import dace import numpy as np +from dace import subsets from dace.transformation.dataflow import MapFusion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols @@ -39,6 +40,7 @@ filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, + new_array_symbols, unique_name, unique_var_name, ) @@ -131,9 +133,13 @@ def get_reduce_identity_value(op_name_: str, type_: Any): } +# Define type of variables used for field indexing +_INDEX_DTYPE = _TYPE_MAPPING["int64"] + + @dataclasses.dataclass class SymbolExpr: - value: str | dace.symbolic.sympy.Basic + value: dace.symbolic.SymbolicType dtype: dace.typeclass @@ -226,7 +232,7 @@ def builtin_neighbors( outputs={"__result"}, ) idx_name = unique_var_name() - sdfg.add_scalar(idx_name, dace.int64, transient=True) + sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( state.add_access(table_name), me, @@ -283,10 +289,12 @@ def builtin_can_deref( assert shift_callable.fun.id == "shift" iterator = transformer._visit_shift(can_deref_callable) + # this iterator is accessing a neighbor table, so it should return an index + assert iterator.dtype in dace.dtypes.INTEGER_TYPES # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " && ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join([f"{v} >= 0" for v in internals]) # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( @@ -309,6 +317,26 @@ def builtin_if( return transformer.add_expr_tasklet(expr_args, expr, type_, "if") +def builtin_list_get( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + args = list(itertools.chain(*transformer.visit(node_args))) + assert len(args) == 2 + # index node + assert isinstance(args[0], (SymbolExpr, ValueExpr)) + # 1D-array node + assert isinstance(args[1], ValueExpr) + # source node should be a 1D array + assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1 + + expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] + internals = [ + arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args + ] + expr = f"{internals[1]}[{internals[0]}]" + return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + + def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -340,16 +368,13 @@ def builtin_tuple_get( raise ValueError("Tuple can only be subscripted with compile-time constants") -def builtin_undefined(*args: Any) -> Any: - raise NotImplementedError() - - _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { "can_deref": builtin_can_deref, "cast_": builtin_cast, "if_": builtin_if, + "list_get": builtin_list_get, "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, "tuple_get": builtin_tuple_get, @@ -387,16 +412,11 @@ def _add_symbol(self, param, arg): elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) - shape = tuple( - dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) - ) - strides = tuple( - dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) - ) + shape, strides = new_array_symbols(param, ndims) self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=dace.int64) + self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol field = self._state.add_access(param) indices = { @@ -513,14 +533,7 @@ def visit_Lambda( # Add connectivities as arrays for name in connectivity_names: - shape = ( - dace.symbol(unique_var_name() + "__shp", dace.int64), - dace.symbol(unique_var_name() + "__shp", dace.int64), - ) - strides = ( - dace.symbol(unique_var_name() + "__strd", dace.int64), - dace.symbol(unique_var_name() + "__strd", dace.int64), - ) + shape, strides = new_array_symbols(name, ndim=2) dtype = self.context.body.arrays[name].dtype lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -542,11 +555,9 @@ def visit_Lambda( result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) result_access = lambda_state.add_access(result_name) - lambda_state.add_edge( + lambda_state.add_nedge( expr.value, - None, result_access, - None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), ) @@ -587,12 +598,13 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_reduce(node) if isinstance(node.fun, itir.SymRef): - if str(node.fun.id) in _MATH_BUILTINS_MAPPING: + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: return self._visit_numeric_builtin(node) - elif str(node.fun.id) in _GENERAL_BUILTIN_MAPPING: + elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError() + raise NotImplementedError(f"{builtin_name} not implemented") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): @@ -697,7 +709,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices ] internals = [f"{arg.value.data}_v" for arg in args] @@ -726,14 +738,88 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] - else: + elif all([dim in iterator.indices for dim in iterator.dimensions]): + # The deref iterator has index values on all dimensions: the result will be a scalar args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + else: + # Not all dimensions are included in the deref index list: + # this means the ND-field will be sliced along one or more dimensions and the result will be an array + field_array = self.context.body.arrays[iterator.field.data] + result_shape = tuple( + dim_size + for dim, dim_size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + ) + result_name = unique_var_name() + self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) + result_array = self.context.body.arrays[result_name] + result_node = self.context.state.add_access(result_name) + + deref_connectors = ["_inp"] + [ + f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices + ] + deref_nodes = [iterator.field] + [ + iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices + ] + deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ + dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] + ] + + # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset + deref_sdfg = dace.SDFG("deref") + deref_sdfg.add_array( + "_inp", field_array.shape, iterator.dtype, strides=field_array.strides + ) + for connector in deref_connectors[1:]: + deref_sdfg.add_scalar(connector, _INDEX_DTYPE) + deref_sdfg.add_array("_out", result_shape, iterator.dtype) + deref_init_state = deref_sdfg.add_state("init", True) + deref_access_state = deref_sdfg.add_state("access") + deref_sdfg.add_edge( + deref_init_state, + deref_access_state, + dace.InterstateEdge( + assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} + ), + ) + # we access the size in source field shape as symbols set on the nested sdfg + source_subset = tuple( + f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + for dim, size in zip(sorted_dims, field_array.shape) + ) + deref_access_state.add_nedge( + deref_access_state.add_access("_inp"), + deref_access_state.add_access("_out"), + dace.Memlet( + data="_out", + subset=subsets.Range.from_array(result_array), + other_subset=",".join(source_subset), + ), + ) + + deref_node = self.context.state.add_nested_sdfg( + deref_sdfg, + self.context.body, + inputs=set(deref_connectors), + outputs={"_out"}, + ) + for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): + self.context.state.add_edge(node, None, deref_node, connector, memlet) + self.context.state.add_edge( + deref_node, + "_out", + result_node, + None, + dace.Memlet.from_array(result_name, result_array), + ) + return [ValueExpr(result_node, iterator.dtype)] + def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: @@ -760,6 +846,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_dim = tail[0].value assert isinstance(offset_dim, str) offset_node = self.visit(tail[1])[0] + assert offset_node.dtype in dace.dtypes.INTEGER_TYPES if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] @@ -769,7 +856,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: target_dim = offset_provider.neighbor_axis.value args = [ ValueExpr(connectivity, offset_provider.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -780,7 +867,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -791,14 +878,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = self.offset_provider[offset_dim].value target_dim = shifted_dim args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -811,7 +898,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) offset_node = self.context.state.add_access(offset_var) tasklet_node = self.context.state.add_tasklet( "get_offset", {}, {"__out"}, f"__out = {offset}" @@ -906,7 +993,7 @@ def _visit_reduce(self, node: itir.FunCall): # initialize the reduction result based on type of operation init_value = get_reduce_identity_value(op_name.id, result_dtype) - init_state = self.context.body.add_state_before(self.context.state, "init") + init_state = self.context.body.add_state_before(self.context.state, "init", True) init_tasklet = init_state.add_tasklet( "init_reduce", {}, {"__out"}, f"__out = {init_value}" ) @@ -1044,13 +1131,13 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") - state = body.add_state("tasklet_toplevel_entry") + state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} idx_accesses = {} for dim, idx in domain.items(): name = f"{idx}_value" - body.add_scalar(name, dtype=dace.int64, transient=True) + body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") access = state.add_access(name) idx_accesses[dim] = access @@ -1058,15 +1145,10 @@ def closure_to_tasklet_sdfg( for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) - shape = [ - dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim) - ] - stride = [ - dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) - ] + shape, strides = new_array_symbols(name, ndim) dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, dtype=dtype) + body.add_array(name, shape=shape, strides=strides, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) @@ -1076,9 +1158,8 @@ def closure_to_tasklet_sdfg( body.add_scalar(name, dtype=dtype) symbol_map[name] = ValueExpr(state.add_access(name), dtype) for arr, name in connectivities: - shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)] - stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)] - body.add_array(name, shape=shape, strides=stride, dtype=arr.dtype) + shape, strides = new_array_symbols(name, ndim=2) + body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) translator = PythonTaskletCodegen(offset_provider, context, node_types) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index c17a39ef2d..5ae4676cd7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -166,6 +166,13 @@ def unique_var_name(): return unique_name("__var") +def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: + dtype = dace.int64 + shape = [dace.symbol(unique_name(f"{name}_shp{i}"), dtype) for i in range(ndim)] + strides = [dace.symbol(unique_name(f"{name}_strd{i}"), dtype) for i in range(ndim)] + return shape, strides + + def flatten_list(node_list: list[Any]) -> list[Any]: return list( itertools.chain.from_iterable( diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a6a302e143..84287e209f 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -122,12 +122,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index cf273a4524..7f37b41383 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -698,6 +698,7 @@ def testee( ) +@pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case):