Skip to content

Commit

Permalink
Finish domain inference for (nested) concat_where and transform to as…
Browse files Browse the repository at this point in the history
…_fieldop
  • Loading branch information
SF-N committed Oct 25, 2024
1 parent d96cebd commit 71026ff
Show file tree
Hide file tree
Showing 18 changed files with 281 additions and 126 deletions.
24 changes: 16 additions & 8 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
type_info as ti_ffront,
type_specifications as ts_ffront,
)
from gt4py.next.iterator import ir as itir
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


Expand Down Expand Up @@ -565,9 +565,17 @@ def _deduce_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# check both types compatible
if (isinstance(left.type, ts.DimensionType) and isinstance(right.type, ts.ScalarType) and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())):
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[left.type.dim])
if (isinstance(right.type, ts.DimensionType) and isinstance(left.type, ts.ScalarType) and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())):
if (
isinstance(right.type, ts.DimensionType)
and isinstance(left.type, ts.ScalarType)
and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[right.type.dim])
# TODO
for arg in (left, right):
Expand Down Expand Up @@ -948,9 +956,11 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
true_branch_type = node.args[1].type
false_branch_type = node.args[2].type
if true_branch_type != false_branch_type:
raise errors.DSLError(node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"'{true_branch_type}' and '{false_branch_type}' to be equal.")
raise errors.DSLError(
node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"'{true_branch_type}' and '{false_branch_type}' to be equal.",
)
return_type = true_branch_type

return foast.Call(
Expand All @@ -961,8 +971,6 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
location=node.location,
)



def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts
Expand Down
18 changes: 14 additions & 4 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,15 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
left, right = node.left, node.right
if (isinstance(left.type, ts.DimensionType) and isinstance(right.type, ts.ScalarType) and node.right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) or (isinstance(right.type, ts.DimensionType) and isinstance(left.type, ts.ScalarType) and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())):
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
) or (
isinstance(right.type, ts.DimensionType)
and isinstance(left.type, ts.ScalarType)
and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
):
lowered_args = [self.visit(arg, **kwargs) for arg in (node.left, node.right)]
return im.call(node.op.value)(*lowered_args)
return self._map(node.op.value, node.left, node.right)
Expand Down Expand Up @@ -404,7 +412,6 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:

# TODO: tuple case


def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
expr = self.visit(node.args[0], **kwargs)
if isinstance(node.args[0].type, ts.ScalarType):
Expand Down Expand Up @@ -483,8 +490,11 @@ def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall:
op = im.call("map_")(op)
return im.op_as_fieldop(im.call(op))(*lowered_args)

assert all(isinstance(t, (ts.ScalarType, ts.DimensionType)) for arg in args
for t in type_info.primitive_constituents(arg.type))
assert all(
isinstance(t, (ts.ScalarType, ts.DimensionType))
for arg in args
for t in type_info.primitive_constituents(arg.type)
)
return im.call(op)(*lowered_args)


Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
"concat_where",
"inf", # TODO: discuss
"neg_inf", #TODO: discuss
"in",
"inf", # TODO: discuss
"neg_inf", # TODO: discuss
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}
Expand Down
19 changes: 11 additions & 8 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) intersection of a list of domains."""
new_domain_ranges = {}
Expand All @@ -195,22 +196,24 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain:
dims_dict = {}
for dim in domain.ranges.keys():
lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop
if lb == im.ref('neg_inf'):
dims_dict[dim] = SymbolicRange(int(ub.value), "inf")
elif ub == im.ref('inf'):
dims_dict[dim] = SymbolicRange("neg_inf", int(lb.value))
if lb == im.ref("neg_inf"):
dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf"))
elif ub == im.ref("inf"):
dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb)
else:
raise ValueError("Invalid domain ranges")
return SymbolicDomain(domain.grid_type, dims_dict)
return SymbolicDomain(domain.grid_type, dims_dict)


def promote_to_same_dimensions(domain_small: SymbolicDomain, domain_large: SymbolicDomain) -> SymbolicDomain:
def promote_to_same_dimensions(
domain_small: SymbolicDomain, domain_large: SymbolicDomain
) -> SymbolicDomain:
"""Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions."""
dims_dict = {}
for dim in domain_large.ranges.keys():
if dim in domain_small.ranges.keys():
lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop
dims_dict[dim] = SymbolicRange(lb, ub)
else:
dims_dict[dim] = SymbolicRange("neg_inf", "inf")
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf"))
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def if_(cond, true_val, false_val):
"""Create a if_ FunCall, shorthand for ``call("if_")(expr)``."""
return call("if_")(cond, true_val, false_val)


def concat_where(cond, true_field, false_field):
"""Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``."""
return call("concat_where")(cond, true_field, false_field)
Expand Down
40 changes: 22 additions & 18 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
Expand All @@ -22,37 +21,42 @@ def visit_FunCall(self, node: ir.FunCall):
new_node = self.generic_visit(node)

if (
cpm.is_call_to(new_node, ("minimum","maximum"))
cpm.is_call_to(new_node, ("minimum", "maximum"))
and new_node.args[0] == new_node.args[1]
): # `minimum(a, a)` -> `a`
return new_node.args[0]

if (
cpm.is_call_to(new_node, "minimum")
):
if cpm.is_call_to(new_node, "minimum"): # TODO: add tests
# `minimum(neg_inf, neg_inf)` -> `neg_inf`
if cpm.is_ref_to(new_node.args[0],"neg_inf") or cpm.is_ref_to(new_node.args[1],"neg_inf"):
if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to(
new_node.args[1], "neg_inf"
):
return im.ref("neg_inf")
# `minimum(inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0],"inf"):
elif cpm.is_ref_to(new_node.args[0], "inf"):
return new_node.args[1]
# `minimum(a, inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1],"inf"):
elif cpm.is_ref_to(new_node.args[1], "inf"):
return new_node.args[0]

if (
cpm.is_call_to(new_node, "maximum")
):
if cpm.is_call_to(new_node, "maximum"): # TODO: add tests
# `minimum(inf, inf)` -> `inf`
if cpm.is_ref_to(new_node.args[0],"inf") or cpm.is_ref_to(new_node.args[1],"inf"):
if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"):
return im.ref("inf")
# `minimum(neg_inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0],"neg_inf"):
return new_node.args[1]
elif cpm.is_ref_to(new_node.args[0], "neg_inf"):
return new_node.args[1]
# `minimum(a, neg_inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1],"neg_inf"):
return new_node.args[0]

elif cpm.is_ref_to(new_node.args[1], "neg_inf"):
return new_node.args[0]
if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to(
new_node.args[0], "neg_inf"
):
return im.literal_from_value(True) # TODO: add tests
if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to(
new_node.args[0], "inf"
):
return im.literal_from_value(True) # TODO: add tests
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
Expand Down
39 changes: 39 additions & 0 deletions src/gt4py/next/iterator/transforms/expand_library_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
domain_utils,
ir_makers as im,
)


class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node):
return cls().visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
if cpm.is_call_to(node, "in"):
ret = []
pos, domain = node.args
for i, (k, v) in enumerate(
domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items()
):
ret.append(
im.and_(
im.less_equal(v.start, im.tuple_get(i, pos)),
im.less(im.tuple_get(i, pos), v.stop),
)
) # TODO: avoid pos duplication
return reduce(im.and_, ret)
return self.generic_visit(node)
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def infer_concat_where(
symbolic_domain_sizes: Optional[dict[str, str]],
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
assert cpm.is_call_to(expr, "concat_where")
assert isinstance(domain, domain_utils.SymbolicDomain)
infered_args_expr = []
actual_domains: ACCESSED_DOMAINS = {}
cond, true_field, false_field = expr.args
Expand All @@ -332,7 +333,9 @@ def infer_concat_where(
domain_ = domain_utils.domain_intersection(domain, extended_cond)
elif arg == false_field:
cond_complement = domain_utils.domain_complement(symbolic_cond)
extended_cond_complement = domain_utils.promote_to_same_dimensions(cond_complement, domain)
extended_cond_complement = domain_utils.promote_to_same_dimensions(
cond_complement, domain
)
domain_ = domain_utils.domain_intersection(domain, extended_cond_complement)

infered_arg_expr, actual_domains_arg = infer_expr(
Expand All @@ -345,7 +348,6 @@ def infer_concat_where(
return result_expr, actual_domains



def _infer_expr(
expr: itir.Expr,
domain: DOMAIN,
Expand Down
Loading

0 comments on commit 71026ff

Please sign in to comment.