Skip to content

Commit

Permalink
Fix within_set_at_expr context
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 4, 2025
1 parent ae8c3bc commit 45d017e
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,6 @@ def apply(
node, within_set_at_expr=within_set_at_expr
)

def visit_SetAt(self, node: itir.SetAt, **kwargs):
return itir.SetAt(
expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}),
domain=node.domain,
target=node.target,
)

def transform_fuse_make_tuple(self, node: itir.Node, **kwargs):
if not cpm.is_call_to(node, "make_tuple"):
return None
Expand Down Expand Up @@ -448,7 +441,8 @@ def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs

def generic_visit(self, node, **kwargs):
if cpm.is_applied_as_fieldop(node): # don't descend in stencil
return im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args, **kwargs))
return im.as_fieldop(*node.fun.args)(*self.visit(node.args, **kwargs))

# TODO(tehrengruber): This is a common pattern that should be absorbed in
# `FixedPointTransformation`.
if kwargs.get("recurse", True):
Expand All @@ -457,8 +451,17 @@ def generic_visit(self, node, **kwargs):
return node

def visit(self, node, **kwargs):
if isinstance(node, itir.SetAt):
return itir.SetAt(
expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}),
# rest doesn't need to be visited
domain=node.domain,
target=node.target,
)

# don't execute transformations unless inside `SetAt` node
if not kwargs.get("within_set_at_expr"):
return node
return self.generic_visit(node, **kwargs)

# inline all fields with list dtype. This needs to happen before the children are visited
# such that the `as_fieldop` can be fused.
Expand Down

0 comments on commit 45d017e

Please sign in to comment.