Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Jan 24, 2025
1 parent 8303463 commit bc25c40
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 61 deletions.
2 changes: 1 addition & 1 deletion tach.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ source_roots = [
"src",
]
exact = true
#forbid_circular_dependencies = true
forbid_circular_dependencies = true

[[modules]]
path = "gt4py._core"
Expand Down
11 changes: 9 additions & 2 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import dataclasses
import enum
import importlib
from typing import Final

import pytest

Expand Down Expand Up @@ -53,11 +54,17 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):

@dataclasses.dataclass(frozen=True)
class EmbeddedDummyBackend:
name: str
allocator: next_allocators.FieldBufferAllocatorProtocol
executor: Final = None


numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator())
cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator())
numpy_execution = EmbeddedDummyBackend(
"EmbeddedNumPy", next_allocators.StandardCPUFieldBufferAllocator()
)
cupy_execution = EmbeddedDummyBackend(
"EmbeddedCuPy", next_allocators.StandardGPUFieldBufferAllocator()
)


class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,30 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import typing

import pytest

from gt4py import next as gtx
from gt4py.next import common
from gt4py.next.iterator.transforms import extractors

from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
IDim,
JDim,
KDim,
)


if typing.TYPE_CHECKING:
from types import ModuleType
from typing import Optional

try:
import dace

from gt4py.next.program_processors.runners.dace import run_dace_cpu
except ImportError:
from gt4py.next import backend as next_backend

dace: Optional[ModuleType] = None
run_dace_cpu: Optional[next_backend.Backend] = None


@pytest.fixture(params=[pytest.param(run_dace_cpu, marks=pytest.mark.requires_dace), gtx.gtfn_cpu])
def gtir_dace_backend(request):
yield request.param


@pytest.fixture
def cartesian(request, gtir_dace_backend):
if gtir_dace_backend is None:
yield None

yield cases.Case(
backend=gtir_dace_backend,
offset_provider={
"Ioff": IDim,
"Joff": JDim,
"Koff": KDim,
},
default_sizes={IDim: 10, JDim: 10, KDim: 10},
grid_type=common.GridType.CARTESIAN,
allocator=gtir_dace_backend.allocator,
)


@pytest.mark.skipif(dace is None, reason="DaCe not found")
def test_input_names_extractor_cartesian(cartesian):
@gtx.field_operator(backend=cartesian.backend)
def test_input_names_extractor_cartesian():
@gtx.field_operator
def testee_op(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
) -> gtx.Field[[IDim, JDim, KDim], gtx.int]:
return a

@gtx.program(backend=cartesian.backend)
@gtx.program
def testee(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
b: gtx.Field[[IDim, JDim, KDim], gtx.int],
Expand All @@ -81,15 +42,14 @@ def testee(
assert input_field_names == {"a", "b"}


@pytest.mark.skipif(dace is None, reason="DaCe not found")
def test_output_names_extractor(cartesian):
@gtx.field_operator(backend=cartesian.backend)
def test_output_names_extractor():
@gtx.field_operator
def testee_op(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
) -> gtx.Field[[IDim, JDim, KDim], gtx.int]:
return a

@gtx.program(backend=cartesian.backend)
@gtx.program
def testee(
a: gtx.Field[[IDim, JDim, KDim], gtx.int],
b: gtx.Field[[IDim, JDim, KDim], gtx.int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ def pnabla(


def test_ffront_compute_zavgS(exec_alloc_descriptor):
_, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator
setup = nabla_setup(allocator=exec_alloc_descriptor.allocator)

setup = nabla_setup(allocator=allocator)

zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator)
zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=exec_alloc_descriptor.allocator)

compute_zavgS.with_backend(
None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor
Expand All @@ -83,12 +81,10 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor):


def test_ffront_nabla(exec_alloc_descriptor):
_, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator

setup = nabla_setup(allocator=allocator)
setup = nabla_setup(allocator=exec_alloc_descriptor.allocator)

pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator)
pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator)
pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator)
pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator)

pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)(
setup.input_field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from . import util


# dace = pytest.importorskip("dace")
dace = pytest.importorskip("dace")
from dace.sdfg import nodes as dace_nodes
import dace

Expand Down

0 comments on commit bc25c40

Please sign in to comment.