Skip to content

Commit

Permalink
feat[next]: GTIR as_fieldop fusion pass (#1670)
Browse files Browse the repository at this point in the history
Adds a pass that transforms expressions like
```
as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(
  as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3
)
```
into
```
as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3)
```
  • Loading branch information
tehrengruber authored Oct 11, 2024
1 parent 92e83f3 commit ed9d82d
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def domain(
)


def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call:
"""
Create an `as_fieldop` call.
Expand Down
204 changes: 204 additions & 0 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Optional

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts
from gt4py.next.iterator.type_system import (
inference as type_inference,
type_specifications as it_ts,
)
from gt4py.next.type_system import type_info, type_specifications as ts


def _merge_arguments(
args1: dict[str, itir.Expr], arg2: dict[str, itir.Expr]
) -> dict[str, itir.Expr]:
new_args = {**args1}
for stencil_param, stencil_arg in arg2.items():
if stencil_param not in new_args:
new_args[stencil_param] = stencil_arg
else:
assert new_args[stencil_param] == stencil_arg
return new_args


def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall:
"""
Canonicalize applied `as_fieldop`s.
In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified
format to work with (e.g. each parameter has a name without the need to special case).
"""
assert cpm.is_applied_as_fieldop(expr)

stencil = expr.fun.args[0] # type: ignore[attr-defined]
domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined]
if cpm.is_ref_to(stencil, "deref"):
stencil = im.lambda_("arg")(im.deref("arg"))
new_expr = im.as_fieldop(stencil, domain)(*expr.args)
type_inference.copy_type(from_=expr, to=new_expr)

return new_expr

return expr


@dataclasses.dataclass
class FuseAsFieldOp(eve.NodeTranslator):
"""
Merge multiple `as_fieldop` calls into one.
>>> from gt4py import next as gtx
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> IDim = gtx.Dimension("IDim")
>>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))
>>> d = im.domain("cartesian_domain", {IDim: (0, 1)})
>>> nested_as_fieldop = im.op_as_fieldop("plus", d)(
... im.op_as_fieldop("multiplies", d)(
... im.ref("inp1", field_type), im.ref("inp2", field_type)
... ),
... im.ref("inp3", field_type),
... )
>>> print(nested_as_fieldop)
as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(
as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3
)
>>> print(
... FuseAsFieldOp.apply(
... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True
... )
... )
as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3)
""" # noqa: RUF002 # ignore ambiguous multiplication character

uids: eve_utils.UIDGenerator

def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]:
assert cpm.is_applied_as_fieldop(arg)
arg = _canonicalize_as_fieldop(arg)

stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop`
inner_args: list[itir.Expr] = arg.args
extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg

stencil_params: list[itir.Sym] = []
stencil_body: itir.Expr = stencil.expr

for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True):
if isinstance(inner_arg, itir.SymRef):
stencil_params.append(inner_param)
extracted_args[inner_arg.id] = inner_arg
elif isinstance(inner_arg, itir.Literal):
# note: only literals, not all scalar expressions are required as it doesn't make sense
# for them to be computed per grid point.
stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))(
stencil_body
)
else:
# a scalar expression, a previously not inlined `as_fieldop` call or an opaque
# expression e.g. containing a tuple
stencil_params.append(inner_param)
new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop")
extracted_args[new_outer_stencil_param] = inner_arg

return im.lift(im.lambda_(*stencil_params)(stencil_body))(
*extracted_args.keys()
), extracted_args

@classmethod
def apply(
cls,
node: itir.Program,
*,
offset_provider,
uids: Optional[eve_utils.UIDGenerator] = None,
allow_undeclared_symbols=False,
):
node = type_inference.infer(
node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols
)

if not uids:
uids = eve_utils.UIDGenerator()

return cls(uids=uids).visit(node)

def visit_FunCall(self, node: itir.FunCall):
node = self.generic_visit(node)

if cpm.is_call_to(node.fun, "as_fieldop"):
node = _canonicalize_as_fieldop(node)

if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda):
stencil: itir.Lambda = node.fun.args[0]
domain = node.fun.args[1] if len(node.fun.args) > 1 else None

shifts = trace_shifts.trace_stencil(stencil)

args: list[itir.Expr] = node.args

new_args: dict[str, itir.Expr] = {}
new_stencil_body: itir.Expr = stencil.expr

for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.extract_dtype(arg.type)
# TODO(tehrengruber): make this configurable
should_inline = isinstance(arg, itir.Literal) or (
isinstance(arg, itir.FunCall)
and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_"))
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
)
if should_inline:
if cpm.is_applied_as_fieldop(arg):
pass
elif cpm.is_call_to(arg, "if_"):
# TODO(tehrengruber): revisit if we want to inline if_
type_ = arg.type
arg = im.op_as_fieldop("if_")(*arg.args)
arg.type = type_
elif isinstance(arg, itir.Literal):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()

inline_expr, extracted_args = self._inline_as_fieldop_arg(arg)

new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body)

new_args = _merge_arguments(new_args, extracted_args)
else:
new_param: str
if isinstance(
arg, itir.SymRef
): # use name from outer scope (optional, just to get a nice IR)
new_param = arg.id
new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body)
else:
new_param = stencil_param.id
new_args = _merge_arguments(new_args, {new_param: arg})

# simplify stencil directly to keep the tree small
new_stencil_body = inline_lambdas.InlineLambdas.apply(
new_stencil_body, opcount_preserving=True
)
new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body)

new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)(
*new_args.values()
)
type_inference.copy_type(from_=node, to=new_node)

return new_node
return node
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
node.type = type_


def copy_type(from_: itir.Node, to: itir.Node) -> None:
"""
Copy type from one node to another.
This function mainly exists for readability reasons.
"""
assert isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type)


def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None:
"""
Execute `callback` as soon as all `args` have a type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Callable, Optional

from gt4py import next as gtx
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import fuse_as_fieldop
from gt4py.next.type_system import type_specifications as ts

IDim = gtx.Dimension("IDim")
field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))


def test_trivial():
d = im.domain("cartesian_domain", {IDim: (0, 1)})
testee = im.op_as_fieldop("plus", d)(
im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)),
im.ref("inp3", field_type),
)
expected = im.as_fieldop(
im.lambda_("inp1", "inp2", "inp3")(
im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp3"))
),
d,
)(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type))
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider={}, allow_undeclared_symbols=True
)
assert actual == expected


def test_trivial_literal():
d = im.domain("cartesian_domain", {})
testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3)
expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)()
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider={}, allow_undeclared_symbols=True
)
assert actual == expected


def test_symref_used_twice():
d = im.domain("cartesian_domain", {IDim: (0, 1)})
testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)(
im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)(
im.ref("inp1", field_type), im.ref("inp2", field_type)
),
im.ref("inp1", field_type),
)
expected = im.as_fieldop(
im.lambda_("inp1", "inp2")(
im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp1"))
),
d,
)("inp1", "inp2")
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider={}, allow_undeclared_symbols=True
)
assert actual == expected


def test_no_inline():
d1 = im.domain("cartesian_domain", {IDim: (1, 2)})
d2 = im.domain("cartesian_domain", {IDim: (0, 3)})
testee = im.as_fieldop(
im.lambda_("a")(
im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a")))
),
d1,
)(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)))
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True
)
assert actual == testee


def test_partial_inline():
d1 = im.domain("cartesian_domain", {IDim: (1, 2)})
d2 = im.domain("cartesian_domain", {IDim: (0, 3)})
testee = im.as_fieldop(
# first argument read at multiple locations -> not inlined
# second argument only reat at a single location -> inlined
im.lambda_("a", "b")(
im.plus(
im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))),
im.deref("b"),
)
),
d1,
)(
im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)),
im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)),
)
expected = im.as_fieldop(
im.lambda_("a", "inp1")(
im.plus(
im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))),
im.deref("inp1"),
)
),
d1,
)(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1")
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True
)
assert actual == expected

0 comments on commit ed9d82d

Please sign in to comment.