Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Extend and refactor constant folding #1810

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e4f9be1
constant fold min/max for domain expressions
SF-N Jan 15, 2025
21de7d7
Extend and refactor constant folding
SF-N Jan 21, 2025
c630b2d
Fix ConvertMinusToUnary
SF-N Jan 21, 2025
fddbc81
Cleanup ConstantFolding
SF-N Jan 21, 2025
16cec90
Cleanup ConstantFolding
SF-N Jan 21, 2025
86f41e0
Merge main
SF-N Jan 22, 2025
4a52c38
Fix imports
SF-N Jan 22, 2025
8dd80a2
Some fixes
SF-N Jan 22, 2025
ce8adec
Add builtin for unary minus
SF-N Jan 23, 2025
d0b99ed
Address review comments
SF-N Jan 23, 2025
8f8a8fb
Add abs to test
SF-N Jan 23, 2025
61ec67b
Extend visit_UnaryOp in foast_to_gtir
SF-N Jan 23, 2025
fb6636f
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
5962164
Use neg in ConstantFolding
SF-N Jan 23, 2025
a9f2791
Fix foast_to_gtir test references
SF-N Jan 23, 2025
586147c
Cleanup and use neg builtin
SF-N Jan 23, 2025
b265db3
Merge branch 'neg_builtin' into constant_folding_main
SF-N Jan 23, 2025
ed0f062
Take care of tuple_get
SF-N Jan 24, 2025
ed56f02
Minor improvements
SF-N Jan 24, 2025
b1e7197
Move fixed point transformations to a new class
SF-N Jan 27, 2025
ec687b0
Merge branch 'main' into constant_folding_main
SF-N Jan 27, 2025
f250699
Run pre-commit
SF-N Jan 27, 2025
c24ea80
Address review comments
SF-N Jan 27, 2025
4e4869a
Remove self.visit
SF-N Jan 27, 2025
5f09bea
Merge main
SF-N Jan 28, 2025
39eda41
Minor
SF-N Jan 28, 2025
b5f32cf
Update src/gt4py/next/iterator/builtins.py
tehrengruber Jan 28, 2025
054bc41
Update src/gt4py/next/iterator/embedded.py
tehrengruber Jan 28, 2025
93e0ab5
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
3ea7217
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
00332cc
Update src/gt4py/next/iterator/transforms/collapse_tuple.py
tehrengruber Jan 28, 2025
c65dbfb
Remove file
SF-N Jan 28, 2025
089ed22
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
9e14480
Merge branch 'main' into constant_folding_main
SF-N Jan 28, 2025
9676334
Add new line
SF-N Jan 28, 2025
1543688
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 28, 2025
231c0e5
Address review comments
SF-N Jan 28, 2025
5df9b8b
Remove unary minus in the end
SF-N Jan 28, 2025
8ac18ff
Merge branch 'main' into constant_folding_main
SF-N Jan 29, 2025
91f2a66
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
87184cd
Merge branch 'constant_folding_main' of github.com:SF-N/gt4py into co…
SF-N Jan 29, 2025
84d813d
Cleanup UndoCanonicalizeMinus
SF-N Jan 29, 2025
1445dd5
Merge branch 'main' into constant_folding_main
SF-N Feb 10, 2025
101a25b
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
ed0248e
Fix some test failures
SF-N Feb 11, 2025
d452af9
Clean up constant_folding tests
SF-N Feb 11, 2025
1324765
Merge branch 'main' into constant_folding_main
SF-N Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import functools
import inspect
import math
import operator
from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in
from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast

Expand Down Expand Up @@ -203,7 +204,7 @@ def astype(
return core_defs.dtype(type_).scalar_type(value)


_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs}
_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs, "neg": operator.neg}
UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()]

_UNARY_MATH_FP_BUILTIN_IMPL: Final = {
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._lower_and_map("not_", node.operand)

return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)
if node.op in [dialect_ast_enums.UnaryOperator.USUB]:
return self._lower_and_map("neg", node.operand)
if node.op in [dialect_ast_enums.UnaryOperator.UADD]:
return self.visit(node.operand)
else:
raise NotImplementedError(f"Unary operator '{node.op}' is not supported.")

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map(node.op.value, node.left, node.right)
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ def trunc(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def neg(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def isfinite(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -397,7 +402,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
"sin",
Expand All @@ -420,6 +425,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"floor",
"ceil",
"trunc",
"neg",
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
}
UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"}
BINARY_MATH_NUMBER_BUILTINS = {
Expand Down
10 changes: 9 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ def not_(a):
return not a


@builtins.neg.register(EMBEDDED)
def neg(a):
if isinstance(a, Column):
return np.negative(a)
return np.negative(a)


@builtins.gamma.register(EMBEDDED)
def gamma(a):
gamma_ = np.vectorize(math.gamma)
Expand Down Expand Up @@ -538,10 +545,11 @@ def promote_scalars(val: CompositeOfScalarOrField):
"and_": operator.and_,
"or_": operator.or_,
"xor_": operator.xor,
"neg": operator.neg,
}
decorator = getattr(builtins, math_builtin_name).register(EMBEDDED)
impl: Callable
if math_builtin_name in ["gamma", "not_"]:
if math_builtin_name in ["gamma", "not_", "neg"]:
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
continue # treated explicitly
elif math_builtin_name in python_builtins:
# TODO: Should potentially use numpy fixed size types to be consistent
Expand Down
30 changes: 2 additions & 28 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import operator
from typing import Optional

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next import common
from gt4py.next.iterator import ir
Expand All @@ -23,6 +22,7 @@
ir_makers as im,
misc as ir_misc,
)
from gt4py.next.iterator.transforms.fixed_point_transform import FixedPointTransform
from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts
Expand Down Expand Up @@ -87,7 +87,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
# reads a little convoluted and is also different to how we write other transformations. We
# should revisit the pattern here and try to find a more general mechanism.
@dataclasses.dataclass(frozen=True)
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
class CollapseTuple(FixedPointTransform):
"""
Simplifies `make_tuple`, `tuple_get` calls.

Expand Down Expand Up @@ -217,32 +217,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
node = self.generic_visit(node, **kwargs)
return self.fp_transform(node, **kwargs)

def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node:
while True:
new_node = self.transform(node, **kwargs)
if new_node is None:
break
assert new_node != node
node = new_node
return node

def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if not isinstance(node, ir.FunCall):
return None

for transformation in self.Flag:
if self.flags & transformation:
assert isinstance(transformation.name, str)
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert (
result is not node
) # transformation should have returned None, since nothing changed
itir_type_inference.reinfer(result)
return result
return None

def transform_collapse_make_tuple_tuple_get(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
Expand Down
Loading
Loading