Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into gtir-dace-transient_m…
Browse files Browse the repository at this point in the history
…apped_nsdfg
  • Loading branch information
edopao committed Jan 29, 2025
2 parents 324f0f6 + d67bd7e commit 6c0f69d
Show file tree
Hide file tree
Showing 52 changed files with 507 additions and 416 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/deploy-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
run: |
python -m build --sdist --wheel --outdir dist/
- name: Upload artifact
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: gt4py-dist
path: ./dist/**
Expand All @@ -42,7 +42,7 @@ jobs:
id-token: write
steps:
- name: Download wheel
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: gt4py-dist
path: dist
Expand All @@ -60,7 +60,7 @@ jobs:
id-token: write
steps:
- name: Download wheel
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: gt4py-dist
path: dist
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-eve.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
tox run -e eve-py${pyversion_no_dot}
# mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json
# - name: Upload coverage.json artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}
# path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json
Expand All @@ -64,7 +64,7 @@ jobs:
# echo ${{ github.event.pull_request.head.sha }} >> info.txt
# echo ${{ github.run_id }} >> info.txt
# - name: Upload info artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: info-py${{ matrix.python-version }}-${{ matrix.os }}
# path: info.txt
4 changes: 2 additions & 2 deletions .github/workflows/test-next.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu
# mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json
# - name: Upload coverage.json artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu
# path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json
Expand All @@ -74,7 +74,7 @@ jobs:
# echo ${{ github.event.pull_request.head.sha }} >> info.txt
# echo ${{ github.run_id }} >> info.txt
# - name: Upload info artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu
# path: info.txt
4 changes: 2 additions & 2 deletions .github/workflows/test-storage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
tox run -e storage-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu
# mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json
# - name: Upload coverage.json artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}
# path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json
Expand All @@ -66,7 +66,7 @@ jobs:
# echo ${{ github.event.pull_request.head.sha }} >> info.txt
# echo ${{ github.run_id }} >> info.txt
# - name: Upload info artifact
# uses: actions/upload-artifact@v3
# uses: actions/upload-artifact@v4
# with:
# name: info-py${{ matrix.python-version }}-${{ matrix.os }}
# path: info.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ tags: [backend, dace, optimization]
- **Status**: valid
- **Authors**: Philip Müller (@philip-paul-mueller)
- **Created**: 2024-08-27
- **Updated**: 2025-01-15

In the context of the implementation of the new DaCe fieldview we decided about a particular form of the SDFG.
Their main intent is to reduce the complexity of the GT4Py specific transformations.
Expand All @@ -22,6 +23,12 @@ In the pipeline we distinguish between:

The current (GT4Py) pipeline mainly focus on intrastate optimization and relays on DaCe, especially its simplify pass, for interstate optimizations.

## Changelog

#### 2025-01-15:

- Made the rules clearer. Specifically, made a restriction on global memory more explicit.

## Decision

The canonical form is defined by several rules that affect different aspects of an SDFG and what a transformation can assume.
Expand All @@ -38,20 +45,24 @@ The following rules especially affects transformations and how they operate:
- [Note 2]: It is allowed for an _intrastate_ transformation to act in a way that allows state fusion by later intrastate transformations.
- [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2.

2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py.
2. It is invalid to call DaCe's simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py.

- [Rationale]: It was observed that some sub-passes in _simplify()_ have a negative impact and that additional passes might be needed in the future.
By using a single function later modifications to _simplify()_ are easy.
- [Note]: One issue is that the remove redundant array transformation is not able to handle all cases.

#### Global Memory

The only restriction we impose on global memory is:
Global memory has to adhere to the same rules as transient memory.
However, the following rule takes precedence, i.e. if this rule is fulfilled then rules 6 to 10 may be violated.

3. The same global memory is allowed to be used as input and output at the same time, either in the SDFG or in a state, if and only if the output depends _elementwise_ on the input.

3. The same global memory is allowed to be used as input and output at the same time, if and only if the output depends _elementwise_ on the input.
- [Rationale 1]: This allows the removal of double buffering, that DaCe may not remove. See also rule 2.
- [Rationale 2]: This formulation allows writing expressions such as `a += 1`, with only memory for `a`.
Phrased more technically, using global memory for input and output is allowed if and only if the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent.
- [Note]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both.
- [Note 1]: This rule also forbids expressions such as `A[0:10] = A[1:11]`, where `A` refers to a global memory.
- [Note 2]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both.

#### State Machine

Expand All @@ -63,6 +74,7 @@ For the SDFG state machine we assume that:
- [Note]: Running _simplify()_ might actually result in the violation of this rule, see note of rule 9.

5. The state graph does not contain any cycles, i.e. the implementation of a for/while loop using states is not allowed, the new loop construct or serial maps must be used in that case.

- [Rationale]: This is a simplification that makes it much simpler to define what "later in the computation" means, as we will never have a cycle.
- [Note]: Currently the code generator does not support the `LoopRegion` construct and it is transformed to a state machine.

Expand Down Expand Up @@ -93,7 +105,7 @@ It is important to note that these rules only have to be met after _simplify()_
8. No two access nodes in a state can refer to the same array.

- [Rationale]: Together with rule 5 this guarantees SSA style.
- [Note]: An SDFG can still be constructed using different access node for the same underlying data; _simplify()_ will combine them.
- [Note]: An SDFG can still be constructed using different access node for the same underlying data in the same state; _simplify()_ will combine them.

9. Every access node that reads from an array (having an outgoing edge) that was not written to in the same state must be a source node.

Expand All @@ -103,6 +115,7 @@ It is important to note that these rules only have to be met after _simplify()_
Excess interstate transients, that will be kept alive that way, will be removed by later calls to _simplify()_.

10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should either be `dace.dtypes.StorageType.Default` or _preferably_ `dace.dtypes.StorageType.Register`.

- [Rationale 1]: This makes optimizations operating inside maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside.
- [Rationale 2]: The storage type avoids the need to dynamically allocate memory inside a kernel.

Expand All @@ -120,6 +133,7 @@ For maps we assume the following:
- [Rationale]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12.

12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range.

- [Rationale 1]: Because of rule 11, we will only fuse maps that actually makes sense to fuse.
- [Rationale 2]: This allows fusing maps without renaming the map variables.
- [Note]: This rule might be dropped in the future.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ allow_incomplete_defs = true
allow_untyped_defs = true
follow_imports = 'silent'
module = 'gt4py.cartesian.*'
warn_unused_ignores = false

[[tool.mypy.overrides]]
ignore_errors = true
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin):
}
languages = {"computation": "cuda", "bindings": ["python"]}
storage_info = gt_storage.layout.CUDALayout
PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore
PYEXT_GENERATOR_CLASS = CudaExtGenerator
MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator
GT_BACKEND_T = "gpu"

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ class DaCeCUDAPyExtModuleGenerator(DaCePyExtModuleGenerator, CUDAPyExtModuleGene

class BaseDaceBackend(BaseGTBackend, CLIBackendMixin):
GT_BACKEND_T = "dace"
PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore
PYEXT_GENERATOR_CLASS = DaCeExtGenerator

def generate(self) -> Type[StencilObject]:
self.check_options(self.builder.options)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/backend/gtcpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def apply(cls, root, *, module_name="stencil", **kwargs) -> str:

class GTBaseBackend(BaseGTBackend, CLIBackendMixin):
options = BaseGTBackend.GT_BACKEND_OPTS
PYEXT_GENERATOR_CLASS = GTExtGenerator # type: ignore
PYEXT_GENERATOR_CLASS = GTExtGenerator

def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda)
Expand Down
10 changes: 5 additions & 5 deletions src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ class GTCPreconditionError(eve.exceptions.EveError, RuntimeError):
message_template = "GTC pass precondition error: [{info}]"

def __init__(self, *, expected: str, **kwargs: Any) -> None:
super().__init__(expected=expected, **kwargs) # type: ignore
super().__init__(expected=expected, **kwargs)


class GTCPostconditionError(eve.exceptions.EveError, RuntimeError):
message_template = "GTC pass postcondition error: [{info}]"

def __init__(self, *, expected: str, **kwargs: Any) -> None:
super().__init__(expected=expected, **kwargs) # type: ignore
super().__init__(expected=expected, **kwargs)


class AssignmentKind(eve.StrEnum):
Expand Down Expand Up @@ -267,7 +267,7 @@ def verify_and_get_common_dtype(
) -> Optional[DataType]:
assert len(exprs) > 0
if all(e.dtype is not DataType.AUTO for e in exprs):
dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None
dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None
dtype = dtypes[0]
if strict:
if all(dt == dtype for dt in dtypes):
Expand Down Expand Up @@ -908,7 +908,7 @@ def op_to_ufunc(
@functools.lru_cache(maxsize=None)
def typestr_to_data_type(typestr: str) -> DataType:
if not isinstance(typestr, str) or len(typestr) < 3 or not typestr[2:].isnumeric():
return DataType.INVALID # type: ignore
return DataType.INVALID
table = {
("b", 1): DataType.BOOL,
("i", 1): DataType.INT8,
Expand All @@ -919,4 +919,4 @@ def typestr_to_data_type(typestr: str) -> DataType:
("f", 8): DataType.FLOAT64,
}
key = (typestr[1], int(typestr[2:]))
return table.get(key, DataType.INVALID) # type: ignore
return table.get(key, DataType.INVALID)
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/gtc/cuir/cuir.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ class Stmt(common.Stmt):
pass


class Literal(common.Literal, Expr): # type: ignore
class Literal(common.Literal, Expr):
pass


class ScalarAccess(common.ScalarAccess, Expr): # type: ignore
class ScalarAccess(common.ScalarAccess, Expr):
pass


class VariableKOffset(common.VariableKOffset[Expr]):
pass


class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore
class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr):
pass


Expand Down Expand Up @@ -113,7 +113,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr):
_dtype_propagation = common.ternary_op_dtype_propagation(strict=True)


class Cast(common.Cast[Expr], Expr): # type: ignore
class Cast(common.Cast[Expr], Expr):
pass


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr):
_dtype_propagation = common.ternary_op_dtype_propagation(strict=True)


class Cast(common.Cast[Expr], Expr): # type: ignore
class Cast(common.Cast[Expr], Expr):
pass


Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ def visit_HorizontalExecution(
k_interval,
**kwargs: Any,
):
# skip type checking due to https://github.com/python/mypy/issues/5485
extent = global_ctx.library_node.get_extents(node) # type: ignore
extent = global_ctx.library_node.get_extents(node)
decls = [self.visit(decl, **kwargs) for decl in node.declarations]
targets: Set[str] = set()
stmts = [
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/gtc/gtcpp/gtcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ class Offset(common.CartesianOffset):
pass


class Literal(common.Literal, Expr): # type: ignore
class Literal(common.Literal, Expr):
pass


class LocalAccess(common.ScalarAccess, Expr): # type: ignore
class LocalAccess(common.ScalarAccess, Expr):
pass


class VariableKOffset(common.VariableKOffset[Expr]):
pass


class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore
class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr):
pass


Expand Down Expand Up @@ -88,7 +88,7 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr):
_dtype_propagation = common.native_func_call_dtype_propagation(strict=True)


class Cast(common.Cast[Expr], Expr): # type: ignore
class Cast(common.Cast[Expr], Expr):
pass


Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/gtc/gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ class BlockStmt(common.BlockStmt[Stmt], Stmt):
pass


class Literal(common.Literal, Expr): # type: ignore
class Literal(common.Literal, Expr):
pass


class VariableKOffset(common.VariableKOffset[Expr]):
pass


class ScalarAccess(common.ScalarAccess, Expr): # type: ignore
class ScalarAccess(common.ScalarAccess, Expr):
pass


class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore
class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr):
pass


Expand Down Expand Up @@ -163,7 +163,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr):
_dtype_propagation = common.ternary_op_dtype_propagation(strict=False)


class Cast(common.Cast[Expr], Expr): # type: ignore
class Cast(common.Cast[Expr], Expr):
pass


Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/gtc/oir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ class Stmt(common.Stmt):
pass


class Literal(common.Literal, Expr): # type: ignore
class Literal(common.Literal, Expr):
pass


class ScalarAccess(common.ScalarAccess, Expr): # type: ignore
class ScalarAccess(common.ScalarAccess, Expr):
pass


class VariableKOffset(common.VariableKOffset[Expr]):
pass


class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore
class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr):
pass


Expand Down Expand Up @@ -88,7 +88,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr):
_dtype_propagation = common.ternary_op_dtype_propagation(strict=True)


class Cast(common.Cast[Expr], Expr): # type: ignore
class Cast(common.Cast[Expr], Expr):
pass


Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/passes/gtir_upcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _upcast_node(target_dtype: DataType, node: Expr) -> Expr:

def _upcast_nodes(*exprs: Expr, upcasting_rule: Callable) -> Iterator[Expr]:
assert all(e.dtype for e in exprs)
dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None
dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None
target_dtypes = upcasting_rule(*dtypes)
return iter(_upcast_node(target_dtype, arg) for target_dtype, arg in zip(target_dtypes, exprs))

Expand Down
Loading

0 comments on commit 6c0f69d

Please sign in to comment.