Skip to content

Commit

Permalink
feat[next][dace]: Support for reduce-unroll special case (#1381)
Browse files Browse the repository at this point in the history
During integration of icon4py stencils with the DaCe backend, it was found that reduce-unroll can generate an ITIR containing can_deref on a scalar value. Such expression should always evaluate to true, so it can be evaluated at compile-time.
Note that in theory such case could be detected by the ITIR pass, once ITIR type inference is replaced by a new solution. At that time, the solution proposed here should be removed.
  • Loading branch information
edopao authored Dec 4, 2023
1 parent b1f9c9a commit 8a22ba7
Showing 1 changed file with 25 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,25 @@ def builtin_can_deref(
assert shift_callable.fun.id == "shift"
iterator = transformer._visit_shift(can_deref_callable)

# this iterator is accessing a neighbor table, so it should return an index
assert iterator.dtype in dace.dtypes.INTEGER_TYPES
# TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
if not isinstance(iterator, IteratorExpr):
assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr)
# We can always deref a value expression, therefore hard-code `can_deref` to True.
# Returning a SymbolExpr would be preferable, but it requires update to type-checking.
result_name = unique_var_name()
transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True)
result_node = transformer.context.state.add_access(result_name)
transformer.context.state.add_edge(
transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"),
"_out",
result_node,
None,
dace.Memlet.simple(result_name, "0"),
)
return [ValueExpr(result_node, dace.dtypes.bool)]

# create tasklet to check that field indices are non-negative (-1 is invalid)
args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()]
args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()]
internals = [f"{arg.value.data}_v" for arg in args]
expr_code = " and ".join([f"{v} >= 0" for v in internals])

Expand Down Expand Up @@ -833,14 +848,20 @@ def _make_shift_for_rest(self, rest, iterator):
fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator]
)

def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
shift = node.fun
assert isinstance(shift, itir.FunCall)
tail, rest = self._split_shift_args(shift.args)
if rest:
iterator = self.visit(self._make_shift_for_rest(rest, node.args[0]))
else:
iterator = self.visit(node.args[0])
if not isinstance(iterator, IteratorExpr):
# shift cannot be applied because the argument is not iterable
# TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
assert isinstance(iterator, list) and len(iterator) == 1
assert isinstance(iterator[0], ValueExpr)
return iterator

assert isinstance(tail[0], itir.OffsetLiteral)
offset_dim = tail[0].value
Expand Down

0 comments on commit 8a22ba7

Please sign in to comment.