Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Extend and refactor constant folding #1810

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e4f9be1
constant fold min/max for domain expressions
SF-N Jan 15, 2025
21de7d7
Extend and refactor constant folding
SF-N Jan 21, 2025
c630b2d
Fix ConvertMinusToUnary
SF-N Jan 21, 2025
fddbc81
Cleanup ConstantFolding
SF-N Jan 21, 2025
16cec90
Cleanup ConstantFolding
SF-N Jan 21, 2025
86f41e0
Merge main
SF-N Jan 22, 2025
4a52c38
Fix imports
SF-N Jan 22, 2025
8dd80a2
Some fixes
SF-N Jan 22, 2025
ce8adec
Add builtin for unary minus
SF-N Jan 23, 2025
d0b99ed
Address review comments
SF-N Jan 23, 2025
8f8a8fb
Add abs to test
SF-N Jan 23, 2025
61ec67b
Extend visit_UnaryOp in foast_to_gtir
SF-N Jan 23, 2025
fb6636f
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
5962164
Use neg in ConstantFolding
SF-N Jan 23, 2025
a9f2791
Fix foast_to_gtir test references
SF-N Jan 23, 2025
586147c
Cleanup and use neg builtin
SF-N Jan 23, 2025
b265db3
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
ed0f062
Take care of tuple_get
SF-N Jan 24, 2025
ed56f02
Minor improvements
SF-N Jan 24, 2025
b1e7197
Move fixed point transformations to a new class
SF-N Jan 27, 2025
ec687b0
Merge branch 'main' into constant_folding_main
SF-N Jan 27, 2025
f250699
Run pre-commit
SF-N Jan 27, 2025
c24ea80
Address review comments
SF-N Jan 27, 2025
4e4869a
Remove self.visit
SF-N Jan 27, 2025
5f09bea
Merge main
SF-N Jan 28, 2025
39eda41
Minor
SF-N Jan 28, 2025
b5f32cf
Update src/gt4py/next/iterator/builtins.py
tehrengruber Jan 28, 2025
054bc41
Update src/gt4py/next/iterator/embedded.py
tehrengruber Jan 28, 2025
93e0ab5
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
3ea7217
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
00332cc
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
c65dbfb
Remove file
SF-N Jan 28, 2025
089ed22
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
9e14480
Merge branch 'main' into constant_folding_main
SF-N Jan 28, 2025
9676334
Add new line
SF-N Jan 28, 2025
1543688
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
231c0e5
Address review comments
SF-N Jan 28, 2025
5df9b8b
Remove unary minus in the end
SF-N Jan 28, 2025
8ac18ff
Merge branch 'main' into constant_folding_main
SF-N Jan 29, 2025
91f2a66
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
87184cd
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 29, 2025
84d813d
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
1445dd5
Merge branch 'main' into constant_folding_main
SF-N Feb 10, 2025
101a25b
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
ed0248e
Fix some test failures
SF-N Feb 11, 2025
d452af9
Clean up constant_folding tests
SF-N Feb 11, 2025
1324765
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 223 additions & 30 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,245 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from __future__ import annotations

import dataclasses
import enum
import functools
import operator
from typing import Optional

from gt4py import eve
from gt4py.next.iterator import builtins, embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import fixed_point_transformation


@dataclasses.dataclass(frozen=True, kw_only=True)
class ConstantFolding(
fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor
):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

class Transformation(enum.Flag):
# e.g. `literal + symref` -> `symref + literal` and
# `literal + funcall` -> `funcall + literal` and
# `symref + funcall` -> `funcall + symref`
CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto()

# `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_MINUS = enum.auto()

# `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_MIN_MAX = enum.auto()

# `im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
CANONICALIZE_TUPLE_GET_PLUS = enum.auto()

# `(sym + 1) + 1` -> `sym + 2`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_FUNCALL_LITERAL = enum.auto()

# `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto()

# `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` and
# `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)`
SF-N marked this conversation as resolved.
Show resolved Hide resolved
FOLD_MIN_MAX_PLUS = enum.auto()

# `sym + 0` -> `sym`
FOLD_SYMREF_PLUS_ZERO = enum.auto()
SF-N marked this conversation as resolved.
Show resolved Hide resolved

# `sym + 1 + (sym + 2)` -> `sym + sym + 2 + 1`
CANONICALIZE_PLUS_SYMREF_LITERAL = enum.auto()
SF-N marked this conversation as resolved.
Show resolved Hide resolved

# `1 + 1` -> `2`
FOLD_ARITHMETIC_BUILTINS = enum.auto()

# `minimum(a, a)` -> `a`
FOLD_MIN_MAX_LITERALS = enum.auto()

# `if_(True, true_branch, false_branch)` -> `true_branch`
FOLD_IF = enum.auto()

@classmethod
def all(self) -> ConstantFolding.Transformation:
return functools.reduce(operator.or_, self.__members__.values())

enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
node = cls().visit(node)
return node

def transform_canonicalize_funcall_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# e.g. `literal + symref` -> `symref + literal` and
# `literal + funcall` -> `funcall + literal` and
# `symref + funcall` -> `funcall + symref`
if (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum"))
):
if (
isinstance(node.args[1], (ir.SymRef, ir.FunCall))
and isinstance(node.args[0], ir.Literal)
) or (isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)`
if isinstance(node, ir.FunCall) and cpm.is_call_to(node, "minus"):
return im.plus(self.fp_transform(im.call("neg")(node.args[1])), node.args[0])
return None

def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())`
if cpm.is_call_to(node, ("maximum", "minimum")):
if (
isinstance(node.args[0], ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and not cpm.is_call_to(node.args[0], ("maximum", "minimum"))
and cpm.is_call_to(node.args[1], ("maximum", "minimum"))
):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def transform_canonicalize_tuple_get_plus(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))`
if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.SymRef) and len(node.args) > 1:
if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"):
return im.call(node.fun.id)(node.args[1], node.args[0])
return None

def visit_FunCall(self, node: ir.FunCall):
# visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded
new_node = self.generic_visit(node)
def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `(sym + 1) + 1` -> `sym + 2`
if cpm.is_call_to(node, "plus"):
if (
isinstance(node.args[0], ir.FunCall)
and cpm.is_call_to(node.args[0], "plus")
and isinstance(node.args[1], ir.Literal)
):
fun_call, literal = node.args
if (
isinstance(fun_call, ir.FunCall)
and isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall))
and isinstance(fun_call.args[1], ir.Literal)
):
return im.plus(
fun_call.args[0],
self.fp_transform(im.plus(fun_call.args[1], literal)),
)
return None

def transform_fold_min_max_funcall_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)`
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id in ["minimum", "maximum"]
and new_node.args[0] == new_node.args[1]
): # `minimum(a, a)` -> `a`
return new_node.args[0]
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("minimum", "maximum"))
):
if cpm.is_call_to(node.args[0], ("maximum", "minimum")):
fun_call, arg1 = node.args
if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above
return im.call(fun_call.fun.id)(fun_call.args[1], arg1) # type: ignore[attr-defined] # assured by if above
if arg1 == fun_call.args[1]: # type: ignore[attr-defined] # assured by if above
return im.call(fun_call.fun.id)(fun_call.args[0], arg1) # type: ignore[attr-defined] # assured by if above
return None

def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
and isinstance(new_node.args[0], ir.Literal)
): # `if_(True, true_branch, false_branch)` -> `true_branch`
if new_node.args[0].value == "True":
new_node = new_node.args[1]
else:
new_node = new_node.args[2]
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("minimum", "maximum"))
):
arg0, arg1 = node.args
# `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)`
if cpm.is_call_to(arg0, "plus"):
if arg0.args[0] == arg1:
return im.plus(
arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0))
)
# `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)`
if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"):
if arg0.args[0] == arg1.args[0]:
return im.plus(
arg0.args[0],
self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])),
)

return None

def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `sym + 0` -> `sym`
if (
isinstance(new_node, ir.FunCall)
and isinstance(new_node.fun, ir.SymRef)
and len(new_node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in new_node.args)
): # `1 + 1` -> `2`
cpm.is_call_to(node, "plus")
and isinstance(node.args[1], ir.Literal)
and node.args[1].value.isdigit()
and int(node.args[1].value) == 0
):
return node.args[0]
return None

def transform_canonicalize_plus_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1`
if cpm.is_call_to(node, "plus"):
if (
cpm.is_call_to(node.args[0], "plus")
and cpm.is_call_to(node.args[1], "plus")
and isinstance(node.args[0].args[1], ir.Literal)
and isinstance(node.args[1].args[1], ir.Literal)
):
return im.plus(
self.fp_transform(im.plus(node.args[0].args[0], node.args[1].args[0])),
self.fp_transform(im.plus(node.args[0].args[1], node.args[1].args[1])),
)
return None

def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `1 + 1` -> `2`
if (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and len(node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in node.args)
):
try:
if new_node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
if node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(node.fun.id))
arg_values = [
getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition
for arg in new_node.args
for arg in node.args
]
new_node = im.literal_from_value(fun(*arg_values))
return im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for inf and neginf
return None

def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `minimum(a, a)` -> `a`
if cpm.is_call_to(node, ("minimum", "maximum")):
if node.args[0] == node.args[1]:
return node.args[0]
return None

return new_node
def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `if_(True, true_branch, false_branch)` -> `true_branch`
if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal):
if node.args[0].value == "True":
return node.args[1]
else:
return node.args[2]
return None
Loading
Loading