Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Extend and refactor constant folding #1810

Open
wants to merge 52 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
e4f9be1
constant fold min/max for domain expressions
SF-N Jan 15, 2025
21de7d7
Extend and refactor constant folding
SF-N Jan 21, 2025
c630b2d
Fix ConvertMinusToUnary
SF-N Jan 21, 2025
fddbc81
Cleanup ConstantFolding
SF-N Jan 21, 2025
16cec90
Cleanup ConstantFolding
SF-N Jan 21, 2025
86f41e0
Merge main
SF-N Jan 22, 2025
4a52c38
Fix imports
SF-N Jan 22, 2025
8dd80a2
Some fixes
SF-N Jan 22, 2025
ce8adec
Add builtin for unary minus
SF-N Jan 23, 2025
d0b99ed
Address review comments
SF-N Jan 23, 2025
8f8a8fb
Add abs to test
SF-N Jan 23, 2025
61ec67b
Extend visit_UnaryOp in foast_to_gtir
SF-N Jan 23, 2025
fb6636f
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
5962164
Use neg in ConstantFolding
SF-N Jan 23, 2025
a9f2791
Fix foast_to_gtir test references
SF-N Jan 23, 2025
586147c
Cleanup and use neg builtin
SF-N Jan 23, 2025
b265db3
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
ed0f062
Take care of tuple_get
SF-N Jan 24, 2025
ed56f02
Minor improvements
SF-N Jan 24, 2025
b1e7197
Move fixed point transformations to a new class
SF-N Jan 27, 2025
ec687b0
Merge branch 'main' into constant_folding_main
SF-N Jan 27, 2025
f250699
Run pre-commit
SF-N Jan 27, 2025
c24ea80
Address review comments
SF-N Jan 27, 2025
4e4869a
Remove self.visit
SF-N Jan 27, 2025
5f09bea
Merge main
SF-N Jan 28, 2025
39eda41
Minor
SF-N Jan 28, 2025
b5f32cf
Update src/gt4py/next/iterator/builtins.py
tehrengruber Jan 28, 2025
054bc41
Update src/gt4py/next/iterator/embedded.py
tehrengruber Jan 28, 2025
93e0ab5
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
3ea7217
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
00332cc
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
c65dbfb
Remove file
SF-N Jan 28, 2025
089ed22
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
9e14480
Merge branch 'main' into constant_folding_main
SF-N Jan 28, 2025
9676334
Add new line
SF-N Jan 28, 2025
1543688
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
231c0e5
Address review comments
SF-N Jan 28, 2025
5df9b8b
Remove unary minus in the end
SF-N Jan 28, 2025
8ac18ff
Merge branch 'main' into constant_folding_main
SF-N Jan 29, 2025
91f2a66
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
87184cd
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 29, 2025
84d813d
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
1445dd5
Merge branch 'main' into constant_folding_main
SF-N Feb 10, 2025
101a25b
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
ed0248e
Fix some test failures
SF-N Feb 11, 2025
d452af9
Clean up constant_folding tests
SF-N Feb 11, 2025
1324765
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
225acf7
Reformat
SF-N Feb 12, 2025
debe10c
Merge branch 'main' into constant_folding_main
tehrengruber Feb 14, 2025
c91b011
Cleanup tests
tehrengruber Feb 14, 2025
39cadaa
Cleanup tests
tehrengruber Feb 14, 2025
3ca3630
Cleanup
tehrengruber Feb 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"floor",
"ceil",
"trunc",
"neg",
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def promote_scalars(val: CompositeOfScalarOrField):
}
decorator = getattr(builtins, math_builtin_name).register(EMBEDDED)
impl: Callable
if math_builtin_name in ["gamma", "not_"]:
if math_builtin_name in ["gamma", "not_", "neg"]:
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
continue # treated explicitly
elif math_builtin_name in python_builtins:
# TODO: Should potentially use numpy fixed size types to be consistent
Expand Down
7 changes: 3 additions & 4 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def apply(
def visit(self, node, **kwargs):
if cpm.is_call_to(node, "as_fieldop"):
kwargs = {**kwargs, "within_stencil": True}

return super().visit(node, **kwargs)

def transform_collapse_make_tuple_tuple_get(
Expand Down Expand Up @@ -269,7 +268,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[
# TODO(tehrengruber): extend to general symbols as long as the tail call in the let
# does not capture
# `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
if cpm.is_let(node.args[1]):
if isinstance(node, ir.FunCall) and cpm.is_let(node.args[1]):
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
idx, let_expr = node.args
return im.call(
im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let
Expand Down Expand Up @@ -438,7 +437,7 @@ def transform_propagate_to_if_on_tuples_cps(
return None

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
if isinstance(node, ir.FunCall) and cpm.is_let(node):
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
outer_vars = {}
inner_vars = {}
Expand All @@ -464,7 +463,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional
return None

def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
if isinstance(node, ir.FunCall) and cpm.is_let(node):
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
Expand Down
253 changes: 223 additions & 30 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,245 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from __future__ import annotations

import dataclasses
import enum
import functools
import operator
from typing import Optional

from gt4py import eve
from gt4py.next.iterator import builtins, embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import fixed_point_transformation


@dataclasses.dataclass(frozen=True, kw_only=True)
class ConstantFolding(
fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor
):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

class Transformation(enum.Flag):
# e.g. `literal + symref` -> `symref + literal` and
# `literal + funcall` -> `funcall + literal` and
# `symref + funcall` -> `funcall + symref`
CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto()

# `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_MINUS = enum.auto()

# `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_MIN_MAX = enum.auto()

# `im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_TUPLE_GET_PLUS = enum.auto()

# `(sym + 1) + 1` -> `sym + 2`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_FUNCALL_LITERAL = enum.auto()

# `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto()

# `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` and
# `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_MIN_MAX_PLUS = enum.auto()

# `sym + 0` -> `sym`
FOLD_SYMREF_PLUS_ZERO = enum.auto()
SF-N marked this conversation as resolved.
Show resolved Hide resolved

# `sym + 1 + (sym + 2)` -> `sym + sym + 2 + 1`
CANONICALIZE_PLUS_SYMREF_LITERAL = enum.auto()
SF-N marked this conversation as resolved.
Show resolved Hide resolved

# `1 + 1` -> `2`
FOLD_ARITHMETIC_BUILTINS = enum.auto()

# `minimum(a, a)` -> `a`
FOLD_MIN_MAX_LITERALS = enum.auto()

# `if_(True, true_branch, false_branch)` -> `true_branch`
FOLD_IF = enum.auto()

@classmethod
def all(self) -> ConstantFolding.Transformation:
return functools.reduce(operator.or_, self.__members__.values())

enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
node = cls().visit(node)
return node

def transform_canonicalize_funcall_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# e.g. `literal + symref` -> `symref + literal` and
# `literal + funcall` -> `funcall + literal` and
# `symref + funcall` -> `funcall + symref`
if (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum"))
):
if (
isinstance(node.args[1], (ir.SymRef, ir.FunCall))
and isinstance(node.args[0], ir.Literal)
) or (isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)`
if isinstance(node, ir.FunCall) and cpm.is_call_to(node, "minus"):
return im.plus(self.fp_transform(im.call("neg")(node.args[1])), node.args[0])
return None

def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())`
if cpm.is_call_to(node, ("maximum", "minimum")):
if (
isinstance(node.args[0], ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and not cpm.is_call_to(node.args[0], ("maximum", "minimum"))
and cpm.is_call_to(node.args[1], ("maximum", "minimum"))
):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def transform_canonicalize_tuple_get_plus(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))`
if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.SymRef) and len(node.args) > 1:
if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def visit_FunCall(self, node: ir.FunCall):
# visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded
new_node = self.generic_visit(node)
def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `(sym + 1) + 1` -> `sym + 2`
if cpm.is_call_to(node, "plus"):
if (
isinstance(node.args[0], ir.FunCall)
and cpm.is_call_to(node.args[0], "plus")
and isinstance(node.args[1], ir.Literal)
):
fun_call, literal = node.args
if (
isinstance(fun_call, ir.FunCall)
and isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall))
and isinstance(fun_call.args[1], ir.Literal)
):
return im.plus(
fun_call.args[0],
self.fp_transform(im.plus(fun_call.args[1], literal)),
)
return None

def transform_fold_min_max_funcall_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)`
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id in ["minimum", "maximum"]
and new_node.args[0] == new_node.args[1]
): # `minimum(a, a)` -> `a`
return new_node.args[0]
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("minimum", "maximum"))
):
if cpm.is_call_to(node.args[0], ("maximum", "minimum")):
fun_call, arg1 = node.args
if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above
return im.call(fun_call.fun.id)(fun_call.args[1], arg1) # type: ignore[attr-defined] # assured by if above
if arg1 == fun_call.args[1]: # type: ignore[attr-defined] # assured by if above
return im.call(fun_call.fun.id)(fun_call.args[0], arg1) # type: ignore[attr-defined] # assured by if above
return None

def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
and isinstance(new_node.args[0], ir.Literal)
): # `if_(True, true_branch, false_branch)` -> `true_branch`
if new_node.args[0].value == "True":
new_node = new_node.args[1]
else:
new_node = new_node.args[2]
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("minimum", "maximum"))
):
arg0, arg1 = node.args
# `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)`
if cpm.is_call_to(arg0, "plus"):
if arg0.args[0] == arg1:
return im.plus(
arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0))
)
# `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)`
if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"):
if arg0.args[0] == arg1.args[0]:
return im.plus(
arg0.args[0],
self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])),
)

return None

def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `sym + 0` -> `sym`
if (
isinstance(new_node, ir.FunCall)
and isinstance(new_node.fun, ir.SymRef)
and len(new_node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in new_node.args)
): # `1 + 1` -> `2`
cpm.is_call_to(node, "plus")
and isinstance(node.args[1], ir.Literal)
and node.args[1].value.isdigit()
and int(node.args[1].value) == 0
):
return node.args[0]
return None

def transform_canonicalize_plus_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1`
if cpm.is_call_to(node, "plus"):
if (
cpm.is_call_to(node.args[0], "plus")
and cpm.is_call_to(node.args[1], "plus")
and isinstance(node.args[0].args[1], ir.Literal)
and isinstance(node.args[1].args[1], ir.Literal)
):
return im.plus(
self.fp_transform(im.plus(node.args[0].args[0], node.args[1].args[0])),
self.fp_transform(im.plus(node.args[0].args[1], node.args[1].args[1])),
)
return None

def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `1 + 1` -> `2`
if (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and len(node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in node.args)
):
try:
if new_node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
if node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(node.fun.id))
arg_values = [
getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition
for arg in new_node.args
for arg in node.args
]
new_node = im.literal_from_value(fun(*arg_values))
return im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for inf and neginf
return None

def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `minimum(a, a)` -> `a`
if cpm.is_call_to(node, ("minimum", "maximum")):
if node.args[0] == node.args[1]:
return node.args[0]
return None

return new_node
def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `if_(True, true_branch, false_branch)` -> `true_branch`
if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal):
if node.args[0].value == "True":
return node.args[1]
else:
return node.args[2]
return None
44 changes: 44 additions & 0 deletions src/gt4py/next/iterator/transforms/fixed_point_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import enum
from typing import ClassVar, Optional, Type

from gt4py import eve
from gt4py.next.iterator import ir
from gt4py.next.iterator.type_system import inference as itir_type_inference


@dataclasses.dataclass(frozen=True, kw_only=True)
class FixedPointTransform(eve.PreserveLocationVisitor, eve.NodeTranslator):
Flag: ClassVar[Type[enum.Flag]]
flags: enum.Flag

def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node:
while True:
new_node = self.transform(node, **kwargs)
if new_node is None:
break
assert new_node != node
node = new_node
return node

def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
for transformation in self.Flag:
if self.flags & transformation:
assert isinstance(transformation.name, str)
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert (
result is not node
) # transformation should have returned None, since nothing changed
itir_type_inference.reinfer(result)
return result
return None
Loading
Loading