diff --git a/openeo_processes_dask/process_implementations/__init__.py b/openeo_processes_dask/process_implementations/__init__.py
index 5dd1333..5f8c7dd 100644
--- a/openeo_processes_dask/process_implementations/__init__.py
+++ b/openeo_processes_dask/process_implementations/__init__.py
@@ -9,6 +9,7 @@
 from .logic import *
 from .math import *
 from .text import *
+from .udf import *
 
 try:
     from .ml import *
diff --git a/openeo_processes_dask/process_implementations/cubes/apply.py b/openeo_processes_dask/process_implementations/cubes/apply.py
index d39b36d..aba8715 100644
--- a/openeo_processes_dask/process_implementations/cubes/apply.py
+++ b/openeo_processes_dask/process_implementations/cubes/apply.py
@@ -34,6 +34,7 @@ def apply(
             "positional_parameters": positional_parameters,
             "named_parameters": named_parameters,
         },
+        keep_attrs=True,
     )
     return result
 
diff --git a/openeo_processes_dask/process_implementations/udf/__init__.py b/openeo_processes_dask/process_implementations/udf/__init__.py
new file mode 100644
index 0000000..e926694
--- /dev/null
+++ b/openeo_processes_dask/process_implementations/udf/__init__.py
@@ -0,0 +1 @@
+from .udf import *
diff --git a/openeo_processes_dask/process_implementations/udf/udf.py b/openeo_processes_dask/process_implementations/udf/udf.py
new file mode 100644
index 0000000..a471e3c
--- /dev/null
+++ b/openeo_processes_dask/process_implementations/udf/udf.py
@@ -0,0 +1,48 @@
+from typing import Optional
+
+import dask.array as da
+import xarray as xr
+from openeo.udf import UdfData
+from openeo.udf.run_code import run_udf_code
+from openeo.udf.xarraydatacube import XarrayDataCube
+
+from openeo_processes_dask.process_implementations.data_model import RasterCube
+
+__all__ = ["run_udf"]
+
+
+def run_udf(
+    data: da.Array, udf: str, runtime: str, context: Optional[dict] = None
+) -> RasterCube:
+    """Run a Python UDF on a single data cube and return the resulting cube.
+
+    The input is wrapped in an ``XarrayDataCube``/``UdfData`` pair and handed
+    to ``openeo.udf.run_code.run_udf_code`` together with the UDF source.
+
+    :param data: Cube to process; anything ``xr.DataArray`` can wrap.
+    :param udf: Python source code of the UDF.
+    :param runtime: UDF runtime identifier; only Python is supported here.
+    :param context: Optional user context forwarded to the UDF.
+    :return: The array of the single data cube produced by the UDF.
+    :raises ValueError: If ``runtime`` is not a Python runtime, or if the UDF
+        does not produce exactly one data cube.
+    """
+    if not isinstance(runtime, str) or not runtime.lower().startswith("python"):
+        raise ValueError(
+            f"Unsupported UDF runtime {runtime!r}: only Python UDFs can be run."
+        )
+
+    input_cube = XarrayDataCube(xr.DataArray(data))
+    udf_data = UdfData(datacube_list=[input_cube], user_context=context)
+    # NOTE(security): ``udf`` is user-supplied source that gets executed by
+    # run_udf_code; callers must only pass trusted or sandboxed code.
+    result = run_udf_code(code=udf, data=udf_data)
+    # get_datacube_list() may return None; treat that like an empty list so the
+    # check below raises the intended ValueError rather than a TypeError.
+    cubes = result.get_datacube_list() or []
+    if len(cubes) != 1:
+        raise ValueError(
+            f"The provided UDF should return one datacube, but got: {result}"
+        )
+    result_array: xr.DataArray = cubes[0].array
+    return result_array
diff --git a/pyproject.toml b/pyproject.toml
index f90ffc2..f486a9f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ xvec = { version = "0.2.0", optional = true }
 joblib = { version = ">=1.3.2", optional = true }
 geoparquet = "^0.0.3"
 pyarrow = "^15.0.2"
+openeo = ">=0.36.0"
 numpy = { version = "<2.0.0", optional = false }
 pystac = { version = "<1.12.0", optional = false }
 
diff --git a/tests/test_udf.py b/tests/test_udf.py
new file mode 100644
index 0000000..2ee40ae
--- /dev/null
+++ b/tests/test_udf.py
@@ -0,0 +1,38 @@
+import numpy as np
+import openeo
+import pytest
+import xarray as xr
+
+from openeo_processes_dask.process_implementations.udf.udf import run_udf
+from tests.general_checks import general_output_checks
+from tests.mockdata import create_fake_rastercube
+
+
+@pytest.mark.parametrize("size", [(6, 5, 4, 4)])
+@pytest.mark.parametrize("dtype", [np.float32])
+def test_run_udf(temporal_interval, bounding_box, random_raster_data):
+    input_cube = create_fake_rastercube(
+        data=random_raster_data,
+        spatial_extent=bounding_box,
+        temporal_extent=temporal_interval,
+        bands=["B02", "B03", "B04", "B08"],
+        backend="dask",
+    )
+
+    udf = """
+import xarray as xr
+def apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:
+    return cube + 1
+"""
+
+    output_cube = run_udf(data=input_cube, udf=udf, runtime="Python")
+
+    general_output_checks(
+        input_cube=input_cube,
+        output_cube=output_cube,
+        verify_attrs=True,
+        verify_crs=True,
+        expected_results=input_cube + 1,
+    )
+
+    xr.testing.assert_equal(output_cube, input_cube + 1)