feat[next]: Embedded field scan #1365
Functionally fine, let's address the rest in the beginning of next week.
As discussed.
3rd point: Tests for all the different code paths.
Tests covering all cases are in |
src/gt4py/next/embedded/operators.py
Outdated
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]

return _scan(self.fun, self.forward, self.init, scan_range, args, kwargs)
Why not place the function here?
done
axis: common.Dimension

def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field:  # type: ignore[override] # we cannot properly type annotate relative to self.fun
    scan_range = embedded_context.closure_column_range.get()
Suggested change:
- scan_range = embedded_context.closure_column_range.get()
+ scan_range: NamedRange = embedded_context.closure_column_range.get()
why?
src/gt4py/next/embedded/operators.py
Outdated
for hpos in embedded_common.iterate_domain(non_scan_domain):
    scan_loop(hpos)
if len(non_scan_domain) == 0:
    # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
    scan_loop(())
Suggested change:
- for hpos in embedded_common.iterate_domain(non_scan_domain):
-     scan_loop(hpos)
- if len(non_scan_domain) == 0:
-     # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
-     scan_loop(())
+ if len(non_scan_domain) == 0:
+     # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
+     scan_loop(())
+ else:
+     for hpos in embedded_common.iterate_domain(non_scan_domain):
+         scan_loop(hpos)
done
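The control flow agreed on above can be sketched in plain Python without gt4py. Everything here is hypothetical: `iterate_positions` stands in for `embedded_common.iterate_domain` and is written on the assumption that the real helper yields nothing for an empty domain, and `running_sum` models a field as a dict from index tuples to numbers.

```python
import itertools


def iterate_positions(shape):
    """Hypothetical stand-in: yield every index tuple of `shape`, nothing if `shape` is empty."""
    if len(shape) == 0:
        return
    yield from itertools.product(*(range(n) for n in shape))


def running_sum(inp, non_scan_shape, k_size):
    """Running sum along the last index of `inp`, a dict keyed by (*pos, k)."""
    out = {}

    def scan_loop(pos):
        # One pass along the scan axis for a fixed orthogonal position.
        carry = 0
        for k in range(k_size):
            carry += inp[pos + (k,)]
            out[pos + (k,)] = carry

    if len(non_scan_shape) == 0:
        # No dimension orthogonal to the scan axis: run exactly one scan.
        scan_loop(())
    else:
        for pos in iterate_positions(non_scan_shape):
            scan_loop(pos)
    return out
```

With the explicit `if`/`else`, the single-column case and the multi-column case are mutually exclusive, so a reader does not have to know that the loop is empty for an empty domain.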
src/gt4py/next/utils.py
Outdated
def flatten_nested_tuple(value: tuple[_S | tuple, ...]) -> tuple[_S, ...]:
    if isinstance(value, tuple):
        return sum((flatten_nested_tuple(v) for v in value), start=())  # type: ignore[arg-type] # cannot properly express nesting
    else:
        return (value,)
How about this?
Suggested change:
- def flatten_nested_tuple(value: tuple[_S | tuple, ...]) -> tuple[_S, ...]:
-     if isinstance(value, tuple):
-         return sum((flatten_nested_tuple(v) for v in value), start=())  # type: ignore[arg-type] # cannot properly express nesting
-     else:
-         return (value,)
+ def tree_leaves(value: _S | tuple[_S | tuple, ...]) -> Iterable[_S]:
+     if isinstance(value, tuple):
+         for el in value:
+             yield from tree_leaves(el)
+     else:
+         yield value
This also works on values (which would already be useful in one of your use-cases) and has a name that is aligned with tree_map. Only mentioning it here, but meant for the follow-up PR.
I believe your version is better because it is lazy, but the functionality is the same. In my version I am lying about the typing: otherwise you cannot pass a tuple, because it immediately gets matched to _S and you get a typing error. I opted for the version with a typing error inside, but now obviously you cannot pass a value without getting a typing error. The best version uses a protocol, NestedSequence, that Enrique found in some codebase. I'll add an issue about this cleanup.
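The "lazy, but the functionality is the same" point can be checked with a small sketch. Both functions are loosely typed with `Any` on purpose, since the thread notes that the nesting is hard to express; this is an illustration under that assumption, not the signature that will land in `utils.py`.

```python
import itertools
from typing import Any, Iterator


def flatten_nested_tuple(value: Any) -> tuple:
    # Eager version: builds an intermediate tuple at every nesting level.
    if isinstance(value, tuple):
        return sum((flatten_nested_tuple(v) for v in value), start=())
    return (value,)


def tree_leaves(value: Any) -> Iterator[Any]:
    # Lazy version: yields leaves one by one, left to right.
    if isinstance(value, tuple):
        for el in value:
            yield from tree_leaves(el)
    else:
        yield value


nested = (1, (2, (3, 4)), 5)
eager = flatten_nested_tuple(nested)
# Only the first two leaves are ever produced here.
lazy_first_two = tuple(itertools.islice(tree_leaves(nested), 2))
```

Consuming `tree_leaves(nested)` fully gives the same leaves as `flatten_nested_tuple(nested)`, and both accept a bare value as well as a tuple.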
k_size,
)
)  # i_size bigger than in the other argument
inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float)
A small gibe: Is this an example where local-view is easier to comprehend?
I don't get it...
Adds the scalar scan operator for embedded field view.
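For readers unfamiliar with scans, here is a rough sketch of what a scalar scan does along a single column. The parameter names `fun`, `forward`, and `init` mirror the `_scan(...)` call in the diff; `scan_column` itself and the assumption that the carry is also the value written to the output are illustrative only and not taken from the PR.

```python
def scan_column(fun, forward, init, column):
    """Apply `fun(carry, value)` along one column, storing each new carry at its position."""
    size = len(column)
    # A backward scan visits the same positions in reverse order.
    indices = range(size) if forward else range(size - 1, -1, -1)
    out = [None] * size
    carry = init
    for k in indices:
        carry = fun(carry, column[k])
        out[k] = carry
    return out
```

For example, `scan_column(lambda c, x: c + x, True, 0, [1, 2, 3])` accumulates from the first element, while passing `False` accumulates from the last element and still writes each result at its original position.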