Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add xarray Dataset Reader #779

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 240 additions & 3 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

import contextlib
import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import attr
import numpy
Expand Down Expand Up @@ -38,8 +39,10 @@

try:
import xarray
from xarray import open_dataset
except ImportError: # pragma: nocover
xarray = None # type: ignore
open_dataset = None # type: ignore

try:
import rioxarray
Expand All @@ -48,8 +51,8 @@


@attr.s
class XarrayReader(BaseReader):
"""Xarray Reader.
class DataArrayReader(BaseReader):
"""Xarray DataArray Reader.

Attributes:
dataset (xarray.DataArray): Xarray DataArray dataset.
Expand Down Expand Up @@ -614,3 +617,237 @@ def feature(
img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)

return img


@attr.s
class DatasetReader(BaseReader):
"""Xarray Reader.

Attributes:
input (str): dataset path.
dataset (xarray.Dataset): Xarray dataset.
tms (morecantile.TileMatrixSet, optional): TileMatrixSet grid definition. Defaults to `WebMercatorQuad`.
opener (Callable): Xarray dataset opener. Defaults to `xarray.open_dataset`.
opener_options (dict): Options to forward to the opener callable.

Examples:
>>> with DatasetReader(
"s3://mur-sst/zarr-v1",
opener_options={"engine": "zarr"}
) as src:
print(src)
print(src.variables)
img = src.tile(x, y, z, tmax)

"""

input: str = attr.ib()
dataset: xarray.Dataset = attr.ib(default=None)

tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)

opener: Callable[..., xarray.Dataset] = attr.ib(default=open_dataset)
opener_options: Dict = attr.ib(factory=dict)

_ctx_stack: contextlib.ExitStack = attr.ib(init=False, factory=contextlib.ExitStack)

def __attrs_post_init__(self):
"""Set bounds and CRS."""
assert xarray is not None, "xarray must be installed to use XarrayReader"
assert rioxarray is not None, "rioxarray must be installed to use XarrayReader"

if not self.dataset:
self.dataset = self._ctx_stack.enter_context(
self.opener(self.input, **self.opener_options)
)

self.bounds = None
self.crs = None

def close(self):
"""Close xarray dataset."""
self._ctx_stack.close()

def __exit__(self, exc_type, exc_value, traceback):
"""Support using with Context Managers."""
self.close()

@property
def variables(self) -> List[str]:
"""Return dataset variable names"""
return list(self.dataset.data_vars)

def _arrange_dims(self, da: xarray.DataArray) -> xarray.DataArray:
"""Arrange coordinates and time dimensions.

An rioxarray.exceptions.InvalidDimensionOrder error is raised if the coordinates are not in the correct order time, y, and x.
See: https://github.com/corteva/rioxarray/discussions/674

We conform to using x and y as the spatial dimension names..

"""
if "x" not in da.dims and "y" not in da.dims:
try:
latitude_var_name = next(
name
for name in ["lat", "latitude", "LAT", "LATITUDE", "Lat"]
if name in da.dims
)
longitude_var_name = next(
name
for name in ["lon", "longitude", "LON", "LONGITUDE", "Lon"]
if name in da.dims
)
except StopIteration as e:
raise ValueError(f"Couldn't find X/Y dimensions in {da.dims}") from e

da = da.rename({latitude_var_name: "y", longitude_var_name: "x"})

if "TIME" in da.dims:
da = da.rename({"TIME": "time"})

if extra_dims := [d for d in da.dims if d not in ["x", "y"]]:
da = da.transpose(*extra_dims, "y", "x")
else:
da = da.transpose("y", "x")

# If min/max values are stored in `valid_range` we add them in `valid_min/valid_max`
vmin, vmax = da.attrs.get("valid_min"), da.attrs.get("valid_max")
if "valid_range" in da.attrs and not (vmin is not None and vmax is not None):
valid_range = da.attrs.get("valid_range")
da.attrs.update({"valid_min": valid_range[0], "valid_max": valid_range[1]})

return da

def get_variable(
self, variable: str, drop_dim: Optional[str] = None
) -> xarray.DataArray:
"""Get DataArray from xarray Dataset."""
da = self.dataset[variable]

if drop_dim:
dim_to_drop, dim_val = drop_dim.split("=")
da = da.sel({dim_to_drop: dim_val}).drop_vars(dim_to_drop)

da = self._arrange_dims(da)

# Make sure we have a valid CRS
crs = da.rio.crs or "epsg:4326"
da = da.rio.write_crs(crs)

if crs == "epsg:4326" and (da.x > 180).any():
# Adjust the longitude coordinates to the -180 to 180 range
da = da.assign_coords(x=(da.x + 180) % 360 - 180)

# Sort the dataset by the updated longitude coordinates
da = da.sortby(da.x)

assert len(da.dims) in [
2,
3,
], "rio_tiler.io.xarray.DatasetReader can only work with 2D or 3D DataArray"

return da

def spatial_info(self, variable: str, drop_dim: Optional[str] = None):
"""Return xarray.DataArray info."""
da = DataArrayReader(
self.get_variable(variable, drop_dim=drop_dim),
)
return {
"crs": da.crs,
"bounds": da.bounds,
"minzoom": da.minzoom,
"maxzoom": da.maxzoom,
}

def get_geographic_bounds( # type: ignore
self, crs: CRS, variable: str, drop_dim: Optional[str] = None
) -> BBox:
"""Return Geographic Bounds for a Geographic CRS."""
return DataArrayReader(
self.get_variable(variable, drop_dim=drop_dim),
).get_geographic_bounds(crs)

def info(self, variable: str, drop_dim: Optional[str] = None) -> Info: # type: ignore
"""Return xarray.DataArray info."""
return DataArrayReader(
self.get_variable(variable, drop_dim=drop_dim),
).info()

def statistics( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> Dict[str, BandStatistics]:
"""Return statistics from a dataset."""
return DataArrayReader(
self.get_variable(variable, drop_dim=drop_dim),
).statistics(*args, **kwargs)

def tile( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> ImageData:
"""Read a Web Map tile from a dataset."""
return DataArrayReader(
self.get_variable(variable, drop_dim=drop_dim),
tms=self.tms,
).tile(*args, **kwargs)

def part( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> ImageData:
"""Read part of a dataset."""
return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).part(
*args, **kwargs
)

def preview( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> ImageData:
"""Return a preview of a dataset."""
return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).preview(
*args, **kwargs
)

def point( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> PointData:
"""Read a pixel value from a dataset."""
return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).point(
*args, **kwargs
)

def feature( # type: ignore
self,
*args: Any,
variable: str,
drop_dim: Optional[str] = None,
**kwargs: Any,
) -> ImageData:
"""Read part of a dataset defined by a geojson feature."""
return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).feature(
*args, **kwargs
)


# Compat
XarrayReader = DataArrayReader
Loading