Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First implementation of local UDF #308

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions openeo_processes_dask/process_implementations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .logic import *
from .math import *
from .text import *
from .udf import *
VincentVerelst marked this conversation as resolved.
Show resolved Hide resolved

try:
from .ml import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def apply(
"positional_parameters": positional_parameters,
"named_parameters": named_parameters,
},
keep_attrs=True,
)
return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def drop_col(df, keep_var):
if isinstance(predictors, gpd.GeoDataFrame):
predictors = dask_geopandas.from_geopandas(predictors, npartitions=1)

if isinstance(predictors, dask_geopandas.core.GeoDataFrame):
if isinstance(predictors, dask_geopandas.expr.GeoDataFrame):
data_ddf = (
predictors.to_dask_dataframe().reset_index().repartition(npartitions=1)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .udf import *
26 changes: 26 additions & 0 deletions openeo_processes_dask/process_implementations/udf/udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

import dask.array as da
import xarray as xr
from openeo.udf import UdfData
from openeo.udf.run_code import run_udf_code
from openeo.udf.xarraydatacube import XarrayDataCube

from openeo_processes_dask.process_implementations.data_model import RasterCube

__all__ = ["run_udf"]


def run_udf(
    data: da.Array, udf: str, runtime: str, context: Optional[dict] = None
) -> RasterCube:
    """Execute user-supplied UDF code on a single datacube.

    The input is wrapped in an ``xr.DataArray`` and then an
    ``XarrayDataCube``, packed into a ``UdfData`` container together with
    the user context, and handed to ``run_udf_code``.

    Args:
        data: The cube to process; it is passed to ``xr.DataArray(...)``.
        udf: Source code of the UDF, forwarded as ``code`` to
            ``run_udf_code``.
        runtime: Not read by this implementation.
        context: Optional user context, forwarded as ``user_context``.

    Returns:
        The ``array`` attribute of the single datacube found in the UDF
        result.

    Raises:
        ValueError: If the UDF result does not hold exactly one datacube.
    """
    input_cube = XarrayDataCube(xr.DataArray(data))
    udf_data = UdfData(datacube_list=[input_cube], user_context=context)

    result = run_udf_code(code=udf, data=udf_data)
    cubes = result.get_datacube_list()

    # Exactly one output cube is accepted; anything else is reported
    # together with the full UDF result.
    if len(cubes) == 1:
        return cubes[0].array
    raise ValueError(
        f"The provided UDF should return one datacube, but got: {result}"
    )
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pandas = { version = ">=2.0.0", optional = true }
xarray = { version = ">=2022.11.0,<=2024.3.0", optional = true }
dask = {extras = ["array", "dataframe", "distributed"], version = ">=2023.4.0", optional = true}
rasterio = { version = "^1.3.4", optional = true }
dask-geopandas = { version = ">=0.2.0,<1", optional = true }
dask-geopandas = { version = "0.4.3", optional = true }
xgboost = { version = ">=1.5.1", optional = true }
rioxarray = { version = ">=0.12.0,<1", optional = true }
openeo-pg-parser-networkx = { version = ">=2024.7", optional = true }
Expand All @@ -43,6 +43,7 @@ xvec = { version = "0.2.0", optional = true }
joblib = { version = ">=1.3.2", optional = true }
geoparquet = "^0.0.3"
pyarrow = "^15.0.2"
openeo = ">=0.36.0"
numpy = { version = "<2.0.0", optional = false }

[tool.poetry.group.dev.dependencies]
Expand Down
38 changes: 38 additions & 0 deletions tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import openeo
import pytest
import xarray as xr

from openeo_processes_dask.process_implementations.udf.udf import run_udf
from tests.general_checks import general_output_checks
from tests.mockdata import create_fake_rastercube


@pytest.mark.parametrize("size", [(6, 5, 4, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_run_udf(temporal_interval, bounding_box, random_raster_data):
    """Check that ``run_udf`` applies a ``cube + 1`` UDF to a fake raster cube.

    A dask-backed cube with four bands is built from the fixtures, passed
    through ``run_udf`` and compared against ``input_cube + 1``.
    """
    # The UDF source must start at column 0: it is handed over as a code
    # string, not as part of this function's body.
    udf_code = """
import xarray as xr
def apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:
    return cube + 1
"""

    band_names = ["B02", "B03", "B04", "B08"]
    input_cube = create_fake_rastercube(
        data=random_raster_data,
        spatial_extent=bounding_box,
        temporal_extent=temporal_interval,
        bands=band_names,
        backend="dask",
    )
    expected_cube = input_cube + 1

    output_cube = run_udf(data=input_cube, udf=udf_code, runtime="Python")

    general_output_checks(
        input_cube=input_cube,
        output_cube=output_cube,
        verify_attrs=True,
        verify_crs=True,
        expected_results=expected_cube,
    )

    xr.testing.assert_equal(output_cube, expected_cube)
Loading