Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Extend and refactor constant folding #1810

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e4f9be1
constant fold min/max for domain expressions
SF-N Jan 15, 2025
21de7d7
Extend and refactor constant folding
SF-N Jan 21, 2025
c630b2d
Fix ConvertMinusToUnary
SF-N Jan 21, 2025
fddbc81
Cleanup ConstantFolding
SF-N Jan 21, 2025
16cec90
Cleanup ConstantFolding
SF-N Jan 21, 2025
86f41e0
Merge main
SF-N Jan 22, 2025
4a52c38
Fix imports
SF-N Jan 22, 2025
8dd80a2
Some fixes
SF-N Jan 22, 2025
ce8adec
Add builtin for unary minus
SF-N Jan 23, 2025
d0b99ed
Address review comments
SF-N Jan 23, 2025
8f8a8fb
Add abs to test
SF-N Jan 23, 2025
61ec67b
Extend visit_UnaryOp in foast_to_gtir
SF-N Jan 23, 2025
fb6636f
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
5962164
Use neg in ConstantFolding
SF-N Jan 23, 2025
a9f2791
Fix foast_to_gtir test references
SF-N Jan 23, 2025
586147c
Cleanup and use neg builtin
SF-N Jan 23, 2025
b265db3
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
ed0f062
Take care of tuple_get
SF-N Jan 24, 2025
ed56f02
Minor improvements
SF-N Jan 24, 2025
b1e7197
Move fixed point transformations to a new class
SF-N Jan 27, 2025
ec687b0
Merge branch 'main' into constant_folding_main
SF-N Jan 27, 2025
f250699
Run pre-commit
SF-N Jan 27, 2025
c24ea80
Address review comments
SF-N Jan 27, 2025
4e4869a
Remove self.visit
SF-N Jan 27, 2025
5f09bea
Merge main
SF-N Jan 28, 2025
39eda41
Minor
SF-N Jan 28, 2025
b5f32cf
Update src/gt4py/next/iterator/builtins.py
tehrengruber Jan 28, 2025
054bc41
Update src/gt4py/next/iterator/embedded.py
tehrengruber Jan 28, 2025
93e0ab5
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
3ea7217
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
00332cc
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
c65dbfb
Remove file
SF-N Jan 28, 2025
089ed22
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
9e14480
Merge branch 'main' into constant_folding_main
SF-N Jan 28, 2025
9676334
Add new line
SF-N Jan 28, 2025
1543688
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
231c0e5
Address review comments
SF-N Jan 28, 2025
5df9b8b
Remove unary minus in the end
SF-N Jan 28, 2025
8ac18ff
Merge branch 'main' into constant_folding_main
SF-N Jan 29, 2025
91f2a66
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
87184cd
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 29, 2025
84d813d
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
1445dd5
Merge branch 'main' into constant_folding_main
SF-N Feb 10, 2025
101a25b
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
ed0248e
Fix some test failures
SF-N Feb 11, 2025
d452af9
Clean up constant_folding tests
SF-N Feb 11, 2025
1324765
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 203 additions & 31 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,57 +6,229 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from __future__ import annotations

import dataclasses
import enum
import functools
import operator
from typing import Optional

from gt4py import eve
from gt4py.next.iterator import builtins, embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import fixed_point_transformation


def value_from_literal(literal: ir.Literal):
return getattr(embedded, str(literal.type))(literal.value)


class UndoCanonicalizeMinus(eve.NodeTranslator):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
node = super().generic_visit(node, **kwargs)
# `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b`
if cpm.is_call_to(node, "plus"):
a, b = node.args
if cpm.is_call_to(b, "neg"):
return im.minus(a, b.args[0])
if isinstance(b, ir.Literal) and value_from_literal(b) < 0:
return im.minus(a, -value_from_literal(b))
if cpm.is_call_to(a, "neg"):
return im.minus(b, a.args[0])
if isinstance(a, ir.Literal) and value_from_literal(a) < 0:
return im.minus(b, -value_from_literal(a))
return node

class ConstantFolding(PreserveLocationVisitor, NodeTranslator):

@dataclasses.dataclass(frozen=True, kw_only=True)
class ConstantFolding(
fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor
):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

class Transformation(enum.Flag):
# `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP
# `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP
# `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS
CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL = enum.auto()

# `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS
CANONICALIZE_MINUS = enum.auto()

# `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX
CANONICALIZE_MIN_MAX = enum.auto()

# `(a + 1) + 1` -> `a + (1 + 1)`
FOLD_FUNCALL_LITERAL = enum.auto()

# `maximum(maximum(a, 1), a)` -> `maximum(a, 1)`
# `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)`
FOLD_MIN_MAX = enum.auto()

# `maximum(plus(a, 1), a)` -> `plus(a, 1)`
# `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))`
FOLD_MIN_MAX_PLUS = enum.auto()

# `a + 0` -> `a`, `a * 1` -> `a`
FOLD_NEUTRAL_OP = enum.auto()

# `1 + 1` -> `2`
FOLD_ARITHMETIC_BUILTINS = enum.auto()

# `minimum(a, a)` -> `a`
FOLD_MIN_MAX_LITERALS = enum.auto()

# `if_(True, true_branch, false_branch)` -> `true_branch`
FOLD_IF = enum.auto()

@classmethod
def all(self) -> ConstantFolding.Transformation:
return functools.reduce(operator.or_, self.__members__.values())

enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
node = cls().visit(node)
return UndoCanonicalizeMinus().visit(node)

def transform_canonicalize_op_funcall_symref_literal(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
# `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP
# `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP
# `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS
if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")):
if (
isinstance(node.args[0], ir.Literal) and not isinstance(node.args[1], ir.Literal)
) or (
(not cpm.is_call_to(node.args[0], ("plus", "multiplies", "minimum", "maximum")))
and cpm.is_call_to(node.args[1], ("plus", "multiplies", "minimum", "maximum"))
):
return im.call(node.fun)(node.args[1], node.args[0])
return None

def visit_FunCall(self, node: ir.FunCall):
# visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded
new_node = self.generic_visit(node)
def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS
if cpm.is_call_to(node, "minus"):
return im.plus(node.args[0], self.fp_transform(im.call("neg")(node.args[1])))
return None

def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX
if cpm.is_call_to(node, ("maximum", "minimum")):
op = node.fun.id # type: ignore[attr-defined] # assured by if above
if cpm.is_call_to(node.args[1], op) and not cpm.is_call_to(node.args[0], op):
return im.call(op)(node.args[1], node.args[0])
return None

def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `(a + 1) + 1` -> `a + (1 + 1)`
if cpm.is_call_to(node, "plus"):
if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal):
fun_call, literal = node.args
if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( # type: ignore[attr-defined] # assured by if above
fun_call.args[1], # type: ignore[attr-defined] # assured by if above
ir.Literal,
):
return im.plus(
fun_call.args[0], # type: ignore[attr-defined] # assured by if above
self.fp_transform(im.plus(fun_call.args[1], literal)), # type: ignore[attr-defined] # assured by if above
)
return None

def transform_fold_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `maximum(maximum(a, 1), a)` -> `maximum(a, 1)`
# `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)`
if cpm.is_call_to(node, ("minimum", "maximum")):
op = node.fun.id # type: ignore[attr-defined] # assured by if above
if cpm.is_call_to(node.args[0], op):
fun_call, arg1 = node.args
if arg1 in fun_call.args: # type: ignore[attr-defined] # assured by if above
return fun_call
return None

def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id in ["minimum", "maximum"]
and new_node.args[0] == new_node.args[1]
): # `minimum(a, a)` -> `a`
return new_node.args[0]
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and cpm.is_call_to(node, ("minimum", "maximum"))
):
arg0, arg1 = node.args
# `maximum(plus(a, 1), a)` -> `plus(a, 1)`
if cpm.is_call_to(arg0, "plus"):
if arg0.args[0] == arg1:
return im.plus(
arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0))
)
# `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))`
if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"):
if arg0.args[0] == arg1.args[0]:
return im.plus(
arg0.args[0],
self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])),
)

return None

def transform_fold_neutral_op(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `a + 0` -> `a`, `a * 1` -> `a`
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
and isinstance(new_node.args[0], ir.Literal)
): # `if_(True, true_branch, false_branch)` -> `true_branch`
if new_node.args[0].value == "True":
new_node = new_node.args[1]
else:
new_node = new_node.args[2]
cpm.is_call_to(node, "plus")
and isinstance(node.args[1], ir.Literal)
and node.args[1].value.isdigit()
and int(node.args[1].value) == 0
) or (
cpm.is_call_to(node, "multiplies")
and isinstance(node.args[1], ir.Literal)
and node.args[1].value.isdigit()
and int(node.args[1].value) == 1
):
return node.args[0]
return None

@classmethod
def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `1 + 1` -> `2`
if (
isinstance(new_node, ir.FunCall)
and isinstance(new_node.fun, ir.SymRef)
and len(new_node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in new_node.args)
): # `1 + 1` -> `2`
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.SymRef)
and len(node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in node.args)
):
try:
if new_node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
if node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(node.fun.id))
arg_values = [
getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition
for arg in new_node.args
value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition
for arg in node.args
]
new_node = im.literal_from_value(fun(*arg_values))
return im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for inf and neginf
return None

return new_node
def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `minimum(a, a)` -> `a`
if cpm.is_call_to(node, ("minimum", "maximum")):
if node.args[0] == node.args[1]:
return node.args[0]
return None

def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
# `if_(True, true_branch, false_branch)` -> `true_branch`
if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal):
if node.args[0].value == "True":
return node.args[1]
else:
return node.args[2]
return None
Loading