-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: Embedded field scan #1365
Merged
Merged
Changes from 121 commits
Commits
Show all changes
122 commits
Select commit
Hold shift + click to select a range
a6269f7
Add definitions and placeholders for basic embedded field view concepts
egparedes 006093f
Fixes
egparedes c4f9443
make it pass tests
havogt bf649f0
save state
havogt c73c296
define DomainLike
havogt 555e9f2
verify
havogt 40a2d37
wip
havogt 0afad47
refactor type translation
havogt 0122863
small fixes
havogt 80c9c12
finish refactoring builtins
havogt 1475f17
cleanup
havogt 5ab9a54
builtin tests
havogt 5df8d9d
more tests
havogt 1993f32
comments
havogt 4f749ca
cleanup
havogt 5555ad7
fix and ignore
havogt 6042750
some more qa fixes
havogt 87f0975
a few more mypy errors
havogt cc32ed2
typing
havogt 0e4ba03
more cleanup, non-dispatch test
havogt 007a8d7
remove debug print
havogt b6e91b6
reenable test
havogt b3fc186
more typing fixes
havogt aa4c94c
Suggestions and fixes
egparedes b8279ae
More fixes
egparedes c20b610
Missing changes from previous commit
egparedes cfcb927
Add docs to FieldABC
egparedes c1444f5
fix some more typing problems
havogt 85b4eca
fix remaining typing problems, move back registry
havogt 8ead6d2
more typing
havogt e414e18
fix test matrix
havogt 0543b85
fix builtin params
havogt 33bdc76
move ndarray_field to embedded
havogt 59c1962
fix check
havogt dadfe18
cartesian connectivity wip
havogt 81f4bae
Merge remote-tracking branch 'origin/main' into embedded-field-view-impl
havogt d45ba0c
slightly cleaner
havogt 45b8151
minor
havogt 34a1391
improve implementation, basic program test
havogt c6fc23e
Merge remote-tracking branch 'origin/main' into embedded-field-view-impl
havogt 27aad55
Merge remote-tracking branch 'origin/embedded-field-view-impl' into c…
havogt 0c8ee0f
improve typing
havogt a8a0324
Introduce Fieldview Domain (#1310)
samkellerhals d3a7e44
Apply suggestions from code review
havogt aab3436
fix test
havogt 202fc3b
fix imports
havogt 60b2534
fix import
havogt 4fbf5f6
fix import
havogt 8bc332a
Merge remote-tracking branch 'origin/main' into embedded-field-view-impl
havogt 9038dd1
Merge remote-tracking branch 'upstream/embedded-field-view-impl' into…
havogt c63b927
fix bug in Domain broadcast
havogt 483dc17
undo unrelated changes
havogt 3868dd3
Merge remote-tracking branch 'upstream/embedded-field-view-impl' into…
havogt de0a534
Merge remote-tracking branch 'upstream/main' into cartesian_connectivity
havogt a5718ef
Merge remote-tracking branch 'upstream/main' into cartesian_connectivity
havogt 6be2db7
Merge branch 'main' into cartesian_connectivity
egparedes 0b39199
Minor fixes after merge
egparedes bf378d0
Fix style
egparedes f8b7bc7
More fixes after the merge
egparedes 3b16173
Fixes
egparedes f7c996b
Fix style
egparedes 182b254
Fix tests
egparedes 08ba6fe
Fix value_type leftovers
egparedes 16a983c
Fix general inverse image (WIP remap and tests, partially broken even…
egparedes 346f6f0
First working version of the general remap (cartesian still WIP)
egparedes a65b58c
First working version of refactored cartesian remap
egparedes 6e007d2
Merge branch 'main' into cartesian_connectivity
egparedes 08802c9
Fix merge issues
egparedes 7a6053f
Add higher level connectivity constructors
egparedes c79b1c1
add sketch of scan
havogt a68ab85
prototyping a bit more
havogt 643cfb1
refactor
havogt 7d3d57c
make indirect call work
havogt ffc75e7
remove levftover
havogt 23bec64
Fixes for most of cartesian shift tests
egparedes 58209f7
Format fixes
egparedes f141408
fix test cases
havogt 28b2620
fix intersect_domains
havogt 01dcf72
fix last test
havogt 68dec86
cleanups and typing
havogt e46e0f6
First embedded unstructured shift test working
egparedes c899c2c
Merge remote-tracking branch 'havogt/cartesian_connectivity' into car…
egparedes 447d8c8
Fix test cases definitions
egparedes 3b4b222
WIP adding neighbor sums
egparedes 257a6b7
add reduction builtins
havogt 22ca726
some typing
havogt 7dbcc36
more typing
havogt bf6299c
more typing
havogt 5183ac6
more
havogt af9bcb6
even more typing
havogt 334bed0
last type problem
havogt f7399dd
fix bug
havogt 2583a4d
Merge remote-tracking branch 'upstream/main' into cartesian_connectivity
havogt 920455d
Apply suggestions from code review
havogt 31ba6be
format and typing
havogt bbfa0d9
Enhance Domain with review suggestions
egparedes f0abba2
Fixes to domain changes
egparedes 581cece
tiny tests and comment cleanup
havogt ac802d7
Merge remote-tracking branch 'origin/cartesian_connectivity' into emb…
havogt c342ad4
Merge remote-tracking branch 'upstream/main' into embedded_field_scan
havogt f7beac9
update column_range to NamedRange
havogt 6cad860
refactor with context vars
havogt 87303be
address review comments
havogt f5f022e
refactor tuple assigns
havogt 1d4b4e1
delete combine_pos
havogt 31c5902
add docstring
havogt c1d6629
fix strange test
havogt fb7f6b5
isolate embedded from decorators module
havogt 62a6ec8
remove reference to foast_node in embedded execution
havogt 13fe506
Merge remote-tracking branch 'upstream/main' into embedded_field_scan
havogt 077a106
support different horizontal domains in scan tuples
havogt 3f45a6b
fix tree_reduce
havogt 4a0b1c9
add doctest for get_common_tuple_value
havogt 650b052
annotation
havogt 41fa12c
refactor dispatching logic
havogt 8142eb4
isolate embedded operators
havogt 440107e
rename back
havogt ed1bbae
fix mixed tuple and cleanup
havogt c8f8cc2
inline scanop
havogt 61f563a
address review comment
havogt b763766
delete get_common_tuple_value
havogt a698189
delete a typevar
havogt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
import dataclasses | ||
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar | ||
|
||
from gt4py import eve | ||
from gt4py._core import definitions as core_defs | ||
from gt4py.next import common, constructors, utils | ||
from gt4py.next.embedded import common as embedded_common, context as embedded_context | ||
|
||
|
||
_P = ParamSpec("_P") | ||
_R = TypeVar("_R") | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class EmbeddedOperator(Generic[_R, _P]): | ||
fun: Callable[_P, _R] | ||
|
||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: | ||
return self.fun(*args, **kwargs) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ScanOperator(EmbeddedOperator[_R, _P]): | ||
forward: bool | ||
init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] | ||
axis: common.Dimension | ||
|
||
def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun | ||
scan_range = embedded_context.closure_column_range.get() | ||
assert self.axis == scan_range[0] | ||
scan_axis = scan_range[0] | ||
domain_intersection = _intersect_scan_args(*args, *kwargs.values()) | ||
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) | ||
|
||
out_domain = common.Domain( | ||
*[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] | ||
) | ||
if scan_axis not in out_domain.dims: | ||
# even if the scan dimension is not in the input, we can scan over it | ||
out_domain = common.Domain(*out_domain, (scan_range)) | ||
|
||
res = _construct_scan_array(out_domain)(self.init) | ||
|
||
def scan_loop(hpos): | ||
acc = self.init | ||
for k in scan_range[1] if self.forward else reversed(scan_range[1]): | ||
pos = (*hpos, (scan_axis, k)) | ||
new_args = [_tuple_at(pos, arg) for arg in args] | ||
new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} | ||
acc = self.fun(acc, *new_args, **new_kwargs) | ||
_tuple_assign_value(pos, res, acc) | ||
|
||
if len(non_scan_domain) == 0: | ||
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop | ||
scan_loop(()) | ||
else: | ||
for hpos in embedded_common.iterate_domain(non_scan_domain): | ||
scan_loop(hpos) | ||
|
||
return res | ||
|
||
|
||
def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): | ||
if "out" in kwargs: | ||
# called from program or direct field_operator as program | ||
offset_provider = kwargs.pop("offset_provider", None) | ||
|
||
new_context_kwargs = {} | ||
if embedded_context.within_context(): | ||
# called from program | ||
assert offset_provider is None | ||
else: | ||
# field_operator as program | ||
new_context_kwargs["offset_provider"] = offset_provider | ||
|
||
out = kwargs.pop("out") | ||
domain = kwargs.pop("domain", None) | ||
|
||
flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) | ||
assert all(f.domain == flattened_out[0].domain for f in flattened_out) | ||
|
||
out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain | ||
|
||
new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) | ||
|
||
with embedded_context.new_context(**new_context_kwargs) as ctx: | ||
res = ctx.run(op, *args, **kwargs) | ||
_tuple_assign_field( | ||
out, | ||
res, | ||
domain=out_domain, | ||
) | ||
else: | ||
# called from other field_operator | ||
return op(*args, **kwargs) | ||
|
||
|
||
def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: | ||
vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] | ||
assert len(vertical_dim_filtered) <= 1 | ||
return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING | ||
|
||
|
||
def _tuple_assign_field( | ||
target: tuple[common.MutableField | tuple, ...] | common.MutableField, | ||
source: tuple[common.Field | tuple, ...] | common.Field, | ||
domain: common.Domain, | ||
): | ||
@utils.tree_map | ||
def impl(target: common.MutableField, source: common.Field): | ||
target[domain] = source[domain] | ||
|
||
impl(target, source) | ||
|
||
|
||
def _intersect_scan_args( | ||
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] | ||
) -> common.Domain: | ||
return embedded_common.intersect_domains( | ||
*[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] | ||
) | ||
|
||
|
||
def _construct_scan_array(domain: common.Domain): | ||
@utils.tree_map | ||
def impl(init: core_defs.Scalar) -> common.Field: | ||
return constructors.empty(domain, dtype=type(init)) | ||
|
||
return impl | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _tuple_assign_value( | ||
pos: Sequence[common.NamedIndex], | ||
target: common.MutableField | tuple[common.MutableField | tuple, ...], | ||
source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], | ||
) -> None: | ||
@utils.tree_map | ||
def impl(target: common.MutableField, source: core_defs.Scalar): | ||
target[pos] = source | ||
|
||
impl(target, source) | ||
|
||
|
||
def _tuple_at( | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pos: Sequence[common.NamedIndex], | ||
field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...], | ||
) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: | ||
@utils.tree_map | ||
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: | ||
res = field[pos] if common.is_field(field) else field | ||
assert core_defs.is_scalar_type(res) | ||
return res | ||
|
||
return impl(field) # type: ignore[return-value] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2023, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# This file is part of the GT4Py project and the GridTools framework. | ||
# GT4Py is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
import numpy as np | ||
|
||
from gt4py.next import common, utils | ||
|
||
|
||
@utils.tree_map | ||
def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: | ||
return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why?