Skip to content

Commit

Permalink
feat[next][dace]: GTIR-to-SDFG lowering of neighbors and reduce (#1597)
Browse files Browse the repository at this point in the history
This PR adds lowering of neighbors and reduce operators to SDFG.
  • Loading branch information
edopao authored Jul 30, 2024
1 parent abed597 commit 719c487
Show file tree
Hide file tree
Showing 8 changed files with 842 additions and 91 deletions.
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]:
    """Match expressions of the form `reduce(λ(...) → ...)(...)`."""
    # An applied reduce is a call whose callee is itself a call to the
    # `reduce` builtin, i.e. `reduce(op, init)(args...)`.
    if not (isinstance(arg, itir.FunCall) and isinstance(arg.fun, itir.FunCall)):
        return False
    callee = arg.fun.fun
    return isinstance(callee, itir.SymRef) and callee.id == "reduce"


def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `shift(λ(...) → ...)(...)`."""
return (
Expand Down
15 changes: 4 additions & 11 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gt4py.eve import NodeTranslator, traits
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import inline_lambdas


Expand All @@ -29,14 +30,6 @@ def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]:
)


def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:
    # Matches an applied reduce: a call whose callee is itself a call
    # to the `reduce` symbol, i.e. `reduce(...)(...)`.
    if not isinstance(node, ir.FunCall):
        return False
    if not isinstance(node.fun, ir.FunCall):
        return False
    return node.fun.fun == ir.SymRef(id="reduce")


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Expand Down Expand Up @@ -71,7 +64,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:

def visit_FunCall(self, node: ir.FunCall, **kwargs):
node = self.generic_visit(node)
if _is_map(node) or _is_reduce(node):
if _is_map(node) or cpm.is_applied_reduce(node):
if any(_is_map(arg) for arg in node.args):
first_param = (
0 if _is_map(node) else 1
Expand All @@ -83,7 +76,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
inlined_args = []
new_params = []
new_args = []
if _is_reduce(node):
if cpm.is_applied_reduce(node):
# param corresponding to reduce acc
inlined_args.append(ir.SymRef(id=outer_op.params[0].id))
new_params.append(outer_op.params[0])
Expand Down Expand Up @@ -119,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
)
else: # _is_reduce(node)
else: # is_applied_reduce(node)
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]),
args=new_args,
Expand Down
9 changes: 3 additions & 6 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


Expand Down Expand Up @@ -60,16 +61,12 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]:
return [_get_partial_offset_tag(arg) for arg in _get_neighbors_args(reduce_args)]


def _is_reduce(node: itir.FunCall) -> TypeGuard[itir.FunCall]:
    # True when `node` is an applied reduce, i.e. `reduce(...)(...)`:
    # the callee itself must be a call to the `reduce` symbol.
    callee = node.fun
    return isinstance(callee, itir.FunCall) and callee.fun == itir.SymRef(id="reduce")


def _get_connectivity(
applied_reduce_node: itir.FunCall,
offset_provider: dict[str, common.Dimension | common.Connectivity],
) -> common.Connectivity:
"""Return single connectivity that is compatible with the arguments of the reduce."""
if not _is_reduce(applied_reduce_node):
if not cpm.is_applied_reduce(applied_reduce_node):
raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.")

connectivities: list[common.Connectivity] = []
Expand Down Expand Up @@ -158,6 +155,6 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr:

def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Expr:
node = self.generic_visit(node, **kwargs)
if _is_reduce(node):
if cpm.is_applied_reduce(node):
return self._visit_reduce(node, **kwargs)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.iterator.type_system import type_specifications as gtir_ts
from gt4py.next.program_processors.runners.dace_fieldview import (
gtir_python_codegen,
gtir_to_tasklet,
Expand All @@ -38,7 +38,7 @@


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
LetSymbol: TypeAlias = tuple[str, ts.FieldType | ts.ScalarType]
LetSymbol: TypeAlias = tuple[gtir.Literal | gtir.SymRef, ts.FieldType | ts.ScalarType]
TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]


Expand All @@ -62,7 +62,9 @@ def __call__(
sdfg: The SDFG where the primitive subgraph should be instantiated
state: The SDFG state where the result of the primitive function should be made available
sdfg_builder: The object responsible for visiting child nodes of the primitive node.
let_symbols: Mapping of symbols (i.e. lambda parameters) to known temporary fields.
let_symbols: Mapping of symbols (i.e. lambda parameters and/or local constants
like the identity value in a reduction context) to temporary fields
or symbolic expressions.
Returns:
A list of data access nodes and the associated GT4Py data type, which provide
Expand Down Expand Up @@ -105,11 +107,12 @@ def _parse_arg_expr(
)
for dim, _, _ in domain
}
return gtir_to_tasklet.IteratorExpr(
data_node,
arg_type.dims,
indices,
dims = arg_type.dims + (
# we add an extra anonymous dimension in the iterator definition to enable
# dereferencing elements in `ListType`
[gtx_common.Dimension("")] if isinstance(arg_type.dtype, gtir_ts.ListType) else []
)
return gtir_to_tasklet.IteratorExpr(data_node, dims, indices)


def _create_temporary_field(
Expand All @@ -134,27 +137,20 @@ def _create_temporary_field(
field_offset = [-lb for lb in domain_lbs]

if isinstance(output_desc, dace.data.Array):
# extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`)
assert isinstance(output_field_type, ts.FieldType)
if isinstance(node_type.dtype, itir_ts.ListType):
raise NotImplementedError
else:
field_dtype = node_type.dtype
assert output_field_type.dtype == field_dtype
field_dims.extend(output_field_type.dims)
assert isinstance(node_type.dtype, gtir_ts.ListType)
field_dtype = node_type.dtype.element_type
# extend the result arrays with the local dimensions added by the field operator (e.g. `neighbors`)
field_shape.extend(output_desc.shape)
else:
assert isinstance(output_desc, dace.data.Scalar)
assert isinstance(output_field_type, ts.ScalarType)
field_dtype = node_type.dtype
assert output_field_type == field_dtype

# allocate local temporary storage for the result field
temp_name, _ = sdfg.add_temp_transient(
field_shape, dace_fieldview_util.as_dace_type(field_dtype), offset=field_offset
)
field_node = state.add_access(temp_name)
field_type = ts.FieldType(field_dims, field_dtype)
field_type = ts.FieldType(field_dims, node_type.dtype)

return field_node, field_type

Expand All @@ -169,6 +165,7 @@ def translate_as_field_op(
"""Generates the dataflow subgraph for the `as_fieldop` builtin function."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
Expand All @@ -182,13 +179,40 @@ def translate_as_field_op(
domain = dace_fieldview_util.get_domain(domain_expr)
assert isinstance(node.type, ts.FieldType)

reduce_identity: Optional[gtir_to_tasklet.SymbolExpr] = None
if cpm.is_applied_reduce(stencil_expr.expr):
# 'reduce' is a reserved keyword of the DSL, so it can never collide with a
# user-defined symbol; it is therefore safe to use it internally to store
# the reduce identity value as a let-symbol.
if "reduce" in let_symbols:
raise NotImplementedError("nested reductions not supported.")

# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr)

# we store the reduce identity value as a constant let-symbol
let_symbols = let_symbols | {
"reduce": (
gtir.Literal(value=str(reduce_identity.value), type=stencil_expr.expr.type),
reduce_identity.dtype,
)
}

elif "reduce" in let_symbols:
# a parent node is a reduction node, so we are visiting the current node in the context of a reduction
reduce_symbol, _ = let_symbols["reduce"]
assert isinstance(reduce_symbol, gtir.Literal)
reduce_identity = gtir_to_tasklet.SymbolExpr(
reduce_symbol.value, dace_fieldview_util.as_dace_type(reduce_symbol.type)
)

# first visit the list of arguments and build a symbol map
stencil_args = [
_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args
]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder)
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity)
input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
assert isinstance(output_expr, gtir_to_tasklet.ValueExpr)
output_desc = output_expr.node.desc(sdfg)
Expand All @@ -205,7 +229,7 @@ def translate_as_field_op(

# allocate local temporary storage for the result field
field_node, field_type = _create_temporary_field(
sdfg, state, domain, node.type, output_desc, output_expr.field_type
sdfg, state, domain, node.type, output_desc, output_expr.dtype
)

# assume tasklet with single output
Expand Down Expand Up @@ -327,6 +351,54 @@ def translate_cond(
return output_nodes


def _get_symbolic_value(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_to_sdfg.SDFGBuilder,
    symbolic_expr: dace.symbolic.SymExpr,
    scalar_type: ts.ScalarType,
    temp_name: Optional[str] = None,
) -> dace.nodes.AccessNode:
    """Materialize a symbolic expression as a transient scalar in the given state.

    A tasklet evaluates `symbolic_expr` and writes the result into a newly
    allocated transient scalar; the access node for that scalar is returned.
    """
    # tasklet with no inputs that evaluates the symbolic expression
    eval_tasklet = sdfg_builder.add_tasklet(
        "get_value",
        state,
        {},
        {"__out"},
        f"__out = {symbolic_expr}",
    )
    # transient scalar storage for the evaluated value; `find_new_name`
    # avoids collisions with existing data containers
    storage_name, _ = sdfg.add_scalar(
        f"__{temp_name or 'tmp'}",
        dace_fieldview_util.as_dace_type(scalar_type),
        find_new_name=True,
        transient=True,
    )
    value_node = state.add_access(storage_name)
    state.add_edge(
        eval_tasklet,
        "__out",
        value_node,
        None,
        dace.Memlet(data=storage_name, subset="0"),
    )
    return value_node


def translate_literal(
    node: gtir.Node,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_to_sdfg.SDFGBuilder,
    let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
    """Generates the dataflow subgraph for a `ir.Literal` node."""
    assert isinstance(node, gtir.Literal)

    # evaluate the literal value into a transient scalar access node
    value_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, node.type)
    return [(value_node, node.type)]


def translate_symbol_ref(
node: gtir.Node,
sdfg: dace.SDFG,
Expand All @@ -335,57 +407,33 @@ def translate_symbol_ref(
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, (gtir.Literal, gtir.SymRef))

data_type: ts.FieldType | ts.ScalarType
if isinstance(node, gtir.Literal):
sym_value = node.value
data_type = node.type
temp_name = "literal"
assert isinstance(node, gtir.SymRef)

sym_value = str(node.id)
if sym_value in let_symbols:
let_node, sym_type = let_symbols[sym_value]
if isinstance(let_node, gtir.Literal):
# this branch handles the case a let-symbol is mapped to some constant value
return sdfg_builder.visit(let_node)
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that precedes the current state, therefore a new access node needs to
# be created in the state where they are accessed.
sym_value = str(let_node.id)
else:
sym_value = str(node.id)
if sym_value in let_symbols:
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that precedes the current state, therefore a new access node is created
# every time they are accessed. It is therefore possible that multiple access
# nodes are created in one state for the same data container. We rely
# on the simplify to remove duplicated access nodes.
sym_value, data_type = let_symbols[sym_value]
else:
data_type = sdfg_builder.get_symbol_type(sym_value)
temp_name = sym_value

if isinstance(data_type, ts.FieldType):
# add access node to current state
sym_node = state.add_access(sym_value)
sym_type = sdfg_builder.get_symbol_type(sym_value)

# Create new access node in current state. It is possible that multiple
# access nodes are created in one state for the same data container.
# We rely on the dace simplify pass to remove duplicated access nodes.
if isinstance(sym_type, ts.FieldType):
sym_node = state.add_access(sym_value)
else:
# scalar symbols are passed to the SDFG as symbols: build tasklet node
# to write the symbol to a scalar access node
tasklet_node = sdfg_builder.add_tasklet(
f"get_{temp_name}",
state,
{},
{"__out"},
f"__out = {sym_value}",
)
temp_name, _ = sdfg.add_scalar(
f"__{temp_name}",
dace_fieldview_util.as_dace_type(data_type),
find_new_name=True,
transient=True,
)
sym_node = state.add_access(temp_name)
state.add_edge(
tasklet_node,
"__out",
sym_node,
None,
dace.Memlet(data=sym_node.data, subset="0"),
sym_node = _get_symbolic_value(
sdfg, state, sdfg_builder, sym_value, sym_type, temp_name=sym_value
)

return [(sym_node, data_type)]
return [(sym_node, sym_type)]


if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def visit_Lambda(
symbol, the parameter will shadow the previous symbol during traversal of the lambda expression.
"""
lambda_symbols = let_symbols | {
str(p.id): (temp_node.data, type_)
str(p.id): (gtir.SymRef(id=temp_node.data), type_)
for p, (temp_node, type_) in zip(node.params, args, strict=True)
}

Expand All @@ -404,7 +404,7 @@ def visit_Literal(
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
) -> list[gtir_builtin_translators.TemporaryData]:
return gtir_builtin_translators.translate_symbol_ref(
return gtir_builtin_translators.translate_literal(
node, sdfg, head_state, self, let_symbols={}
)

Expand Down
Loading

0 comments on commit 719c487

Please sign in to comment.