Skip to content

Commit

Permalink
updated k-interval computation
Browse files Browse the repository at this point in the history
  • Loading branch information
twicki committed Jan 31, 2025
1 parent b54135f commit a516e3d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 19 deletions.
37 changes: 28 additions & 9 deletions src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ def visit_FieldAccess(
if interval.start.level == LevelMarker.START and (
include_center_interval or interval.end.level == LevelMarker.START
):
boundary = (max(-interval.start.offset - node.offset.k, boundary[0]), boundary[1])
boundary = (
max(-interval.start.offset - node.offset.k, boundary[0]),
boundary[1],
)
if (
include_center_interval or interval.start.level == LevelMarker.END
) and interval.end.level == LevelMarker.END:
boundary = (boundary[0], max(interval.end.offset + node.offset.k, boundary[1]))
boundary = (
boundary[0],
max(interval.end.offset + node.offset.k, boundary[1]),
)
if node.name in [decl.name for decl in vloop.temporaries] and (
boundary[0] > 0 or boundary[1] > 0
):
Expand All @@ -70,17 +76,30 @@ def compute_k_boundary(
return KBoundaryVisitor().visit(node, include_center_interval=include_center_interval)


def compute_min_k_size(node: gtir.Stencil, include_center_interval=True) -> int:
def compute_min_k_size(node: gtir.Stencil) -> int:
"""Compute the required number of k levels to run a stencil."""

min_size_start = 0
min_size_end = 0
biggest_offset = 0
for vloop in node.vertical_loops:
if vloop.interval.start.level == LevelMarker.START and (
include_center_interval or vloop.interval.end.level == LevelMarker.START
if (
vloop.interval.start.level == LevelMarker.START
and vloop.interval.end.level == LevelMarker.END
):
min_size_start = max(min_size_start, vloop.interval.end.offset)
if not (vloop.interval.start.offset == 0 and vloop.interval.end.offset == 0):
biggest_offset = max(
biggest_offset,
vloop.interval.start.offset - vloop.interval.end.offset + 1,
)
elif (
include_center_interval or vloop.interval.start.level == LevelMarker.END
) and vloop.interval.end.level == LevelMarker.END:
vloop.interval.start.level == LevelMarker.START
and vloop.interval.end.level == LevelMarker.START
):
min_size_start = max(min_size_start, vloop.interval.end.offset)
biggest_offset = max(biggest_offset, vloop.interval.end.offset)
else:
min_size_end = max(min_size_end, -vloop.interval.start.offset)
return min_size_start + min_size_end
biggest_offset = max(biggest_offset, -vloop.interval.start.offset)
minimal_size = max(min_size_start + min_size_end, biggest_offset)
return minimal_size
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from gt4py import cartesian as gt4pyc
from gt4py.cartesian import gtscript as gs
from gt4py.cartesian.backend import from_name
from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary, compute_min_k_size
from gt4py.cartesian.gtc.passes.gtir_k_boundary import (
compute_k_boundary,
compute_min_k_size,
)
from gt4py.cartesian.gtc.passes.gtir_pipeline import prune_unused_parameters
from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil
from gt4py.cartesian.stencil_builder import StencilBuilder
Expand Down Expand Up @@ -48,21 +51,21 @@ def stencil_no_extent_0(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(0, -2), 0), min_k_size=2)
@register_test_case(k_bounds=(0, 0), min_k_size=2)
@typing.no_type_check
def stencil_no_extent_1(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(-1, -2), 0), min_k_size=2)
@register_test_case(k_bounds=(-1, 0), min_k_size=2)
@typing.no_type_check
def stencil_no_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(1, 2):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(max(0, -2), max(-2, -2)), 0), min_k_size=3)
@register_test_case(k_bounds=(0, 0), min_k_size=4)
@typing.no_type_check
def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
Expand All @@ -73,14 +76,14 @@ def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(0, max(-1, 0)), min_k_size=1)
@register_test_case(k_bounds=(0, 0), min_k_size=1)
@typing.no_type_check
def stencil_no_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(-1, None):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(0, -1), max(-2, 0)), min_k_size=3)
@register_test_case(k_bounds=(0, 0), min_k_size=3)
@typing.no_type_check
def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 1):
Expand All @@ -89,6 +92,13 @@ def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(-1, -2), min_k_size=4)
@typing.no_type_check
def stencil_no_extent_6(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(1, -2):
field_a[0, 0, 0] = field_b[0, 0, 0]


# stencils with extent
@register_test_case(k_bounds=(5, -5), min_k_size=0)
@typing.no_type_check
Expand All @@ -111,7 +121,7 @@ def stencil_with_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 5]


@register_test_case(k_bounds=(3, -3), min_k_size=3)
@register_test_case(k_bounds=(3, -3), min_k_size=4)
@typing.no_type_check
def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
Expand All @@ -122,7 +132,7 @@ def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, -3]


@register_test_case(k_bounds=(-5, 5), min_k_size=1)
@register_test_case(k_bounds=(-5, 5), min_k_size=2)
@typing.no_type_check
def stencil_with_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, -1):
Expand Down Expand Up @@ -171,7 +181,10 @@ def test_min_k_size(definition, expected_min_k_size):

@pytest.mark.parametrize("definition,expected", test_data)
def test_k_bounds_exec(definition, expected):
expected_k_bounds, expected_min_k_size = expected["k_bounds"], expected["min_k_size"]
expected_k_bounds, expected_min_k_size = (
expected["k_bounds"],
expected["min_k_size"],
)

required_field_size = expected_min_k_size + expected_k_bounds[0] + expected_k_bounds[1]

Expand Down Expand Up @@ -234,7 +247,10 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b:

@pytest.mark.parametrize(
"definition",
[stencil_with_invalid_temporary_access_start, stencil_with_invalid_temporary_access_end],
[
stencil_with_invalid_temporary_access_start,
stencil_with_invalid_temporary_access_end,
],
)
def test_invalid_temporary_access(definition):
builder = StencilBuilder(definition, backend=from_name("numpy"))
Expand Down

0 comments on commit a516e3d

Please sign in to comment.