diff --git a/rio_tiler/io/xarray.py b/rio_tiler/io/xarray.py
index 1fb999da..099c6959 100644
--- a/rio_tiler/io/xarray.py
+++ b/rio_tiler/io/xarray.py
@@ -2,8 +2,9 @@
 
 from __future__ import annotations
 
+import contextlib
 import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 import numpy
@@ -38,8 +39,10 @@
 
 try:
     import xarray
+    from xarray import open_dataset
 except ImportError:  # pragma: nocover
     xarray = None  # type: ignore
+    open_dataset = None  # type: ignore
 
 try:
     import rioxarray
@@ -48,8 +51,8 @@
 
 
 @attr.s
-class XarrayReader(BaseReader):
-    """Xarray Reader.
+class DataArrayReader(BaseReader):
+    """Xarray DataArray Reader.
 
     Attributes:
         dataset (xarray.DataArray): Xarray DataArray dataset.
@@ -614,3 +617,262 @@ def feature(
         img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)
 
         return img
+
+
+@attr.s
+class DatasetReader(BaseReader):
+    """Xarray Dataset Reader.
+
+    Opens (or wraps) an `xarray.Dataset` and delegates every read to a
+    `DataArrayReader` built from the requested variable.
+
+    Attributes:
+        input (str): dataset path.
+        dataset (xarray.Dataset): Xarray dataset. Opened from `input` with `opener` when not provided.
+        tms (morecantile.TileMatrixSet, optional): TileMatrixSet grid definition. Defaults to `WebMercatorQuad`.
+        opener (Callable): Xarray dataset opener. Defaults to `xarray.open_dataset`.
+        opener_options (dict): Options to forward to the opener callable.
+
+    Examples:
+        >>> with DatasetReader(
+                "s3://mur-sst/zarr-v1",
+                opener_options={"engine": "zarr"}
+            ) as src:
+                print(src)
+                print(src.variables)
+                img = src.tile(x, y, z, variable="analysed_sst")
+
+    """
+
+    input: str = attr.ib()
+    dataset: xarray.Dataset = attr.ib(default=None)
+
+    tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)
+
+    opener: Callable[..., xarray.Dataset] = attr.ib(default=open_dataset)
+    opener_options: Dict = attr.ib(factory=dict)
+
+    _ctx_stack: contextlib.ExitStack = attr.ib(init=False, factory=contextlib.ExitStack)
+
+    def __attrs_post_init__(self):
+        """Open the dataset when none was provided and reset bounds and CRS."""
+        assert xarray is not None, "xarray must be installed to use DatasetReader"
+        assert rioxarray is not None, "rioxarray must be installed to use DatasetReader"
+
+        # `bool(xarray.Dataset)` is False when the dataset has no data variables,
+        # so compare against None to avoid re-opening a user-provided dataset.
+        if self.dataset is None:
+            # Enter the opened dataset on the exit stack so `close()` releases it.
+            self.dataset = self._ctx_stack.enter_context(
+                self.opener(self.input, **self.opener_options)
+            )
+
+        # Spatial metadata is resolved per variable (see `spatial_info`).
+        self.bounds = None
+        self.crs = None
+
+    def close(self):
+        """Close the dataset opened by this reader.
+
+        A dataset passed in through `dataset` is never entered on the exit stack and is left open.
+        """
+        self._ctx_stack.close()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Support using with Context Managers."""
+        self.close()
+
+    @property
+    def variables(self) -> List[str]:
+        """Return dataset variable names."""
+        return list(self.dataset.data_vars)
+
+    def _arrange_dims(self, da: xarray.DataArray) -> xarray.DataArray:
+        """Arrange coordinates and time dimensions.
+
+        An rioxarray.exceptions.InvalidDimensionOrder error is raised if the coordinates are not in the correct order time, y, and x.
+        See: https://github.com/corteva/rioxarray/discussions/674
+
+        We conform to using x and y as the spatial dimension names.
+
+        Raises:
+            ValueError: when no `x`/`y` dimension exists and a latitude or longitude dimension cannot be found.
+
+        """
+        if "x" not in da.dims and "y" not in da.dims:
+            try:
+                latitude_var_name = next(
+                    name
+                    for name in ["lat", "latitude", "LAT", "LATITUDE", "Lat"]
+                    if name in da.dims
+                )
+                longitude_var_name = next(
+                    name
+                    for name in ["lon", "longitude", "LON", "LONGITUDE", "Lon"]
+                    if name in da.dims
+                )
+            except StopIteration as e:
+                raise ValueError(f"Couldn't find X/Y dimensions in {da.dims}") from e
+
+            da = da.rename({latitude_var_name: "y", longitude_var_name: "x"})
+
+        if "TIME" in da.dims:
+            da = da.rename({"TIME": "time"})
+
+        if extra_dims := [d for d in da.dims if d not in ["x", "y"]]:
+            da = da.transpose(*extra_dims, "y", "x")
+        else:
+            da = da.transpose("y", "x")
+
+        # If min/max values are stored in `valid_range` we add them in `valid_min/valid_max`
+        vmin, vmax = da.attrs.get("valid_min"), da.attrs.get("valid_max")
+        if "valid_range" in da.attrs and not (vmin is not None and vmax is not None):
+            valid_range = da.attrs.get("valid_range")
+            da.attrs.update({"valid_min": valid_range[0], "valid_max": valid_range[1]})
+
+        return da
+
+    def get_variable(
+        self, variable: str, drop_dim: Optional[str] = None
+    ) -> xarray.DataArray:
+        """Get DataArray from xarray Dataset.
+
+        Args:
+            variable (str): name of the variable to read.
+            drop_dim (str, optional): `{dim}={value}` selector; the dimension is reduced to that value and dropped.
+
+        Returns:
+            xarray.DataArray: 2D or 3D array with `y`/`x` as last dimensions and a CRS set.
+
+        """
+        da = self.dataset[variable]
+
+        if drop_dim:
+            # Split on the first `=` only: the value itself may contain `=`.
+            dim_to_drop, dim_val = drop_dim.split("=", 1)
+            da = da.sel({dim_to_drop: dim_val}).drop_vars(dim_to_drop)
+
+        da = self._arrange_dims(da)
+
+        # Make sure we have a valid CRS
+        crs = da.rio.crs or "epsg:4326"
+        da = da.rio.write_crs(crs)
+
+        if crs == "epsg:4326" and (da.x > 180).any():
+            # Adjust the longitude coordinates to the -180 to 180 range
+            da = da.assign_coords(x=(da.x + 180) % 360 - 180)
+
+            # Sort the dataset by the updated longitude coordinates
+            da = da.sortby(da.x)
+
+        assert len(da.dims) in [
+            2,
+            3,
+        ], "rio_tiler.io.xarray.DatasetReader can only work with 2D or 3D DataArray"
+
+        return da
+
+    def spatial_info(self, variable: str, drop_dim: Optional[str] = None):
+        """Return CRS, bounds and min/max zoom for a variable."""
+        # Forward the TMS (as `tile` does) so the zooms match this reader's grid.
+        da = DataArrayReader(
+            self.get_variable(variable, drop_dim=drop_dim),
+            tms=self.tms,
+        )
+        return {
+            "crs": da.crs,
+            "bounds": da.bounds,
+            "minzoom": da.minzoom,
+            "maxzoom": da.maxzoom,
+        }
+
+    def get_geographic_bounds(  # type: ignore
+        self, crs: CRS, variable: str, drop_dim: Optional[str] = None
+    ) -> BBox:
+        """Return Geographic Bounds for a Geographic CRS."""
+        return DataArrayReader(
+            self.get_variable(variable, drop_dim=drop_dim),
+        ).get_geographic_bounds(crs)
+
+    def info(self, variable: str, drop_dim: Optional[str] = None) -> Info:  # type: ignore
+        """Return xarray.DataArray info."""
+        return DataArrayReader(
+            self.get_variable(variable, drop_dim=drop_dim),
+        ).info()
+
+    def statistics(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Dict[str, BandStatistics]:
+        """Return statistics from a dataset."""
+        return DataArrayReader(
+            self.get_variable(variable, drop_dim=drop_dim),
+        ).statistics(*args, **kwargs)
+
+    def tile(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> ImageData:
+        """Read a Web Map tile from a dataset."""
+        return DataArrayReader(
+            self.get_variable(variable, drop_dim=drop_dim),
+            tms=self.tms,
+        ).tile(*args, **kwargs)
+
+    def part(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> ImageData:
+        """Read part of a dataset."""
+        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).part(
+            *args, **kwargs
+        )
+
+    def preview(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> ImageData:
+        """Return a preview of a dataset."""
+        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).preview(
+            *args, **kwargs
+        )
+
+    def point(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> PointData:
+        """Read a pixel value from a dataset."""
+        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).point(
+            *args, **kwargs
+        )
+
+    def feature(  # type: ignore
+        self,
+        *args: Any,
+        variable: str,
+        drop_dim: Optional[str] = None,
+        **kwargs: Any,
+    ) -> ImageData:
+        """Read part of a dataset defined by a geojson feature."""
+        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).feature(
+            *args, **kwargs
+        )
+
+
+# Compat
+XarrayReader = DataArrayReader