Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Embedded field scan #1365

Merged
merged 122 commits into from
Dec 12, 2023
Merged

Conversation

havogt
Copy link
Contributor

@havogt havogt commented Nov 20, 2023

Adds the scalar scan operator for embedded field view.

@havogt havogt requested a review from tehrengruber November 22, 2023 08:31
Copy link
Contributor

@tehrengruber tehrengruber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functionally fine, let's address the rest in the beginning of next week.

Copy link
Contributor

@tehrengruber tehrengruber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed.

3rd point: Tests for all the different code paths.

src/gt4py/next/embedded/operators.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/operators.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/operators.py Outdated Show resolved Hide resolved
src/gt4py/next/iterator/embedded.py Outdated Show resolved Hide resolved
src/gt4py/next/utils.py Outdated Show resolved Hide resolved
@havogt
Copy link
Contributor Author

havogt commented Dec 5, 2023

Tests covering all cases are in test_arg_call_interfaces, we concluded that is covering what we need.

@tehrengruber tehrengruber changed the title feat[next] Embedded field scan feat[next]: Embedded field scan Dec 5, 2023
@havogt havogt requested a review from tehrengruber December 6, 2023 05:41
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]

return _scan(self.fun, self.forward, self.init, scan_range, args, kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not place the function here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

axis: common.Dimension

def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun
scan_range = embedded_context.closure_column_range.get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scan_range = embedded_context.closure_column_range.get()
scan_range: NamedRange = embedded_context.closure_column_range.get()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Comment on lines 130 to 134
for hpos in embedded_common.iterate_domain(non_scan_domain):
scan_loop(hpos)
if len(non_scan_domain) == 0:
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
scan_loop(())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for hpos in embedded_common.iterate_domain(non_scan_domain):
scan_loop(hpos)
if len(non_scan_domain) == 0:
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
scan_loop(())
if len(non_scan_domain) == 0:
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
scan_loop(())
else:
for hpos in embedded_common.iterate_domain(non_scan_domain):
scan_loop(hpos)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@havogt havogt requested a review from tehrengruber December 7, 2023 15:42
Comment on lines 66 to 70
def flatten_nested_tuple(value: tuple[_S | tuple, ...]) -> tuple[_S, ...]:
if isinstance(value, tuple):
return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting
else:
return (value,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this?

Suggested change
def flatten_nested_tuple(value: tuple[_S | tuple, ...]) -> tuple[_S, ...]:
if isinstance(value, tuple):
return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting
else:
return (value,)
def tree_leaves(value: S_ | tuple[_S | tuple, ...]) -> Iterable[_S, ...]:
if isinstance(value, tuple):
for el in value:
yield from flatten_nested_tuple(el)
else:
yield value

This also works on values (which would already be useful in one of your use-cases) and has a name that is aligned with tree_map. Only mentioning it here, but meant for the follow up PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe your version is better because it is lazy, but the functionality is the same. Only in my version I am lying about the typing because otherwise you cannot pass a tuple because it immediately gets matched to _S and then you get a typing error. I opted for the version with a typing error inside but now obviously you cannot pass a value without getting a typing error. The best version uses a protocol NestedSequence that Enrique found in some codebase. I'll add an issue about this cleanup.

k_size,
)
) # i_size bigger than in the other argument
inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small gibe: Is this an example where local-view is easier to comprehend?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it...

src/gt4py/next/utils.py Outdated Show resolved Hide resolved
@havogt havogt merged commit a14ad09 into GridTools:main Dec 12, 2023
33 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants