From 3614d7350141e09d6c512fc0fae2419cfaac50a6 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 4 May 2024 17:14:17 +1000 Subject: [PATCH 01/26] format --- src/geowombat/backends/rasterio_.py | 67 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 85991b7d..deed4549 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -12,6 +12,7 @@ import xarray as xr from affine import Affine from dask.delayed import Delayed +from dask.utils import SerializableLock from pyproj import CRS from pyproj.exceptions import CRSError from rasterio.coords import BoundingBox @@ -164,20 +165,20 @@ class RasterioStore(object): Code modified from https://github.com/dymaxionlabs/dask-rasterio """ + # https://github.com/dask/distributed/issues/780#issuecomment-270153518 + lock_ = SerializableLock() + def __init__( self, filename: T.Union[str, Path], - mode: str = 'w', tags: dict = None, **kwargs, ): self.filename = Path(filename) - self.mode = mode self.tags = tags self.kwargs = kwargs - self.dst = None - def __setitem__(self, key, item): + def __setitem__(self, key: tuple, item: np.ndarray) -> None: if len(key) == 3: index_range, y, x = key indexes = list( @@ -191,60 +192,56 @@ def __setitem__(self, key, item): indexes = 1 y, x = key - w = Window( + chunk_window = Window( col_off=x.start, row_off=y.start, width=x.stop - x.start, height=y.stop - y.start, ) - self.dst.write(item, window=w, indexes=indexes) + self._write_window(item, indexes=indexes, window=chunk_window) def __enter__(self) -> 'RasterioStore': - return self.open() + self.closed = False + + return self._open() def __exit__(self, exc_type, exc_value, traceback): - self.close() + self.closed = True - def open(self) -> 'RasterioStore': - self.dst = self.rio_open() - self.update_tags() + def _open(self) -> 'RasterioStore': + self._create_image() return self - def 
rio_open(self): - return rio.open(self.filename, mode=self.mode, **self.kwargs) - - def update_tags(self): - if self.tags is not None: - self.dst.update_tags(**self.tags) - - def write_delayed(self, data: xr.DataArray): - store = da.store( - data.transpose('band', 'y', 'x').squeeze().data, - self, - lock=True, - compute=False, - ) + def _create_image(self) -> None: + mode = 'r+' if self.filename.exists() else 'w' + with rio.open(self.filename, mode=mode, **self.kwargs) as dst: + if self.tags is not None: + dst.update_tags(**self.tags) - return self.close_delayed(store) - - @dask.delayed - def close_delayed(self, store): - return self.close() + def _write_window( + self, + data: np.ndarray, + indexes: T.Union[int, np.ndarray], + window: T.Optional[Window] = None, + ) -> None: + with rio.open(self.filename, mode='r+', **self.kwargs) as dst: + dst.write( + data, + indexes=indexes, + window=window, + ) def write(self, data: xr.DataArray, compute: bool = False) -> Delayed: if isinstance(data.data, da.Array): - return da.store(data.data, self, lock=True, compute=compute) + return da.store(data.data, self, lock=self.lock_, compute=compute) else: - self.dst.write( + self._write_window( data.data, indexes=list(range(1, data.data.shape[0] + 1)), ) - def close(self): - self.dst.close() - def check_res( res: T.Union[ From eea4778c0e3952ea9e3befaa5e39fed952efa265 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 4 May 2024 17:14:51 +1000 Subject: [PATCH 02/26] cleanup type hints --- src/geowombat/core/geoxarray.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/geowombat/core/geoxarray.py b/src/geowombat/core/geoxarray.py index 38767019..ec66b029 100644 --- a/src/geowombat/core/geoxarray.py +++ b/src/geowombat/core/geoxarray.py @@ -728,16 +728,15 @@ def to_netcdf( def save( self, filename: T.Union[str, _Path], - mode: T.Optional[str] = 'w', nodata: T.Optional[T.Union[float, int]] = None, overwrite: bool = False, client: T.Optional[_Client] = 
None, - compute: T.Optional[bool] = True, + compute: bool = True, tags: T.Optional[dict] = None, compress: T.Optional[str] = 'none', compression: T.Optional[str] = None, - num_workers: T.Optional[int] = 1, - log_progress: T.Optional[bool] = True, + num_workers: int = 1, + log_progress: bool = True, tqdm_kwargs: T.Optional[dict] = None, bigtiff: T.Optional[str] = None, ) -> None: @@ -745,7 +744,6 @@ def save( Args: filename (str | Path): The output file name to write to. - mode (Optional[str]): The file storage mode. Choices are ['w', 'r+']. nodata (Optional[float | int]): The 'no data' value. If ``None`` (default), the 'no data' value is taken from the ``DataArray`` metadata. overwrite (Optional[bool]): Whether to overwrite an existing file. Default is False. @@ -785,10 +783,9 @@ def save( ) compress = compression - return save( + save( self._obj, filename=filename, - mode=mode, nodata=nodata, overwrite=overwrite, client=client, From 6f39f6c7816c31168e4fa64a96e2a69c78d2168a Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 4 May 2024 17:15:28 +1000 Subject: [PATCH 03/26] support client --- src/geowombat/core/io.py | 71 +++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/src/geowombat/core/io.py b/src/geowombat/core/io.py index 37bf843a..47588875 100644 --- a/src/geowombat/core/io.py +++ b/src/geowombat/core/io.py @@ -683,16 +683,15 @@ def to_netcdf( def save( data: xr.DataArray, filename: T.Union[str, Path], - mode: T.Optional[str] = "w", nodata: T.Optional[T.Union[float, int]] = None, - overwrite: T.Optional[bool] = False, + overwrite: bool = False, client: T.Optional[Client] = None, - compute: T.Optional[bool] = True, + compute: bool = True, tags: T.Optional[dict] = None, compress: T.Optional[str] = "none", compression: T.Optional[str] = None, - num_workers: T.Optional[int] = 1, - log_progress: T.Optional[bool] = True, + num_workers: int = 1, + log_progress: bool = True, tqdm_kwargs: T.Optional[dict] = None, 
bigtiff: T.Optional[str] = None, ): @@ -701,7 +700,6 @@ def save( Args: filename (str | Path): The output file name to write to. overwrite (Optional[bool]): Whether to overwrite an existing file. Default is False. - mode (Optional[str]): The file storage mode. Choices are ['w', 'r+']. nodata (Optional[float | int]): The 'no data' value. If ``None`` (default), the 'no data' value is taken from the ``DataArray`` metadata. client (Optional[Client object]): A ``dask.distributed.Client`` client object to persist data. @@ -745,15 +743,9 @@ def save( ) compress = compression - if mode not in ["w", "r+"]: - raise AttributeError("The mode must be either 'w' or 'r+'.") - - if Path(filename).is_file(): + if Path(filename).exists(): if overwrite: Path(filename).unlink() - else: - logger.warning(f"The file {str(filename)} already exists.") - return if nodata is None: if hasattr(data, "_FillValue"): @@ -771,6 +763,16 @@ def save( if dtype != "float32": dtype = "float64" + if client is not None: + if compress not in ( + None, + "none", + ): + logger.warning( + " Cannot write to a compressed file with a Dask Client(). Data will be uncompressed." 
+ ) + compress = None + blockxsize = ( data.gw.check_chunksize(512, data.gw.ncols) if not data.gw.array_is_dask @@ -826,33 +828,26 @@ def save( if tqdm_kwargs is None: tqdm_kwargs = {} - if not compute: - return ( - RasterioStore(filename, mode=mode, tags=tags, **kwargs) - .open() - .write_delayed(data) - ) + with RasterioStore(filename, tags=tags, **kwargs) as rio_store: + # Store the data and return a lazy evaluator + res = rio_store.write(data) - else: - with RasterioStore( - filename, mode=mode, tags=tags, **kwargs - ) as rio_store: - # Store the data and return a lazy evaluator - res = rio_store.write(data) - - if client is not None: - results = client.persist(res) - if log_progress: - progress(results) - dask.compute(results) - else: - if log_progress: - with TqdmCallback(**tqdm_kwargs): - dask.compute(res, num_workers=num_workers) - else: - dask.compute(res, num_workers=num_workers) + if not compute: + return res - return None + if client is not None: + results = client.persist(res) + if log_progress: + progress(results) + + dask.compute(results) + + else: + if log_progress: + with TqdmCallback(**tqdm_kwargs): + dask.compute(res, num_workers=num_workers) + else: + dask.compute(res, num_workers=num_workers) def to_raster( From 5ce4f1b726726b05f961ab47c53abedf995aaa7b Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 4 May 2024 17:15:57 +1000 Subject: [PATCH 04/26] add tests --- tests/test_write.py | 104 ++++++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 32 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index 2c885e64..b7aa1358 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -2,8 +2,9 @@ import unittest from pathlib import Path -import dask +import numpy as np import rasterio as rio +from dask.distributed import Client, LocalCluster import geowombat as gw from geowombat.data import ( @@ -142,6 +143,76 @@ def test_save(self): with rio.open(out_path) as rio_src: self.assertTrue(rio_src.nodata == 
NODATA) + def test_write_numpy(self): + with tempfile.TemporaryDirectory() as tmp: + out_path = Path(tmp) / "test.tif" + with gw.open(l8_224078_20200518) as src: + data = src.gw.set_nodata(0, NODATA, dtype="uint16") + # Load data and convert from dask to numpy + data.load() + + self.assertTrue(isinstance(data.data, np.ndarray)) + + ( + data.gw.save( + filename=out_path, + overwrite=True, + tags={"TEST_METADATA": "TEST_VALUE"}, + compress="lzw", + ) + ) + with gw.open(out_path) as tmp_src: + # Compare array values + self.assertTrue(data.equals(tmp_src)) + # Compare attributes + self.assertTrue( + data.gw.nodataval == tmp_src.gw.nodataval == NODATA + ) + self.assertEqual(data.gw.dtype, tmp_src.dtype) + self.assertTrue(hasattr(tmp_src, "TEST_METADATA")) + self.assertEqual(tmp_src.TEST_METADATA, "TEST_VALUE") + + with rio.open(out_path) as rio_src: + self.assertTrue(rio_src.nodata == NODATA) + + def test_client_save(self): + + with LocalCluster( + processes=True, + n_workers=4, + threads_per_worker=1, + memory_limit="2GB", + ) as cluster: + with Client(cluster) as client: + with tempfile.TemporaryDirectory() as tmp: + out_path = Path(tmp) / "test.tif" + with gw.open(l8_224078_20200518) as src: + data = src.gw.set_nodata(0, NODATA, dtype="uint16") + data.gw.save( + filename=out_path, + overwrite=True, + tags={"TEST_METADATA": "TEST_VALUE"}, + compress="lzw", + client=client, + ) + with gw.open(out_path) as tmp_src: + # Compare array values + self.assertTrue(data.equals(tmp_src)) + # Compare attributes + self.assertTrue( + data.gw.nodataval + == tmp_src.gw.nodataval + == NODATA + ) + self.assertEqual(data.gw.dtype, tmp_src.dtype) + self.assertTrue(hasattr(tmp_src, "TEST_METADATA")) + self.assertEqual( + tmp_src.TEST_METADATA, "TEST_VALUE" + ) + + with rio.open(out_path) as rio_src: + self.assertTrue(rio_src.nodata == NODATA) + def test_config_save(self): with tempfile.TemporaryDirectory() as tmp: out_path = Path(tmp) / "test.tif" @@ -191,37 +262,6 @@ def 
test_save_small(self): except ValueError: self.fail("The small array write test failed.") - def test_delayed_save(self): - with tempfile.TemporaryDirectory() as tmp: - out_path = Path(tmp) / "test.tif" - with gw.open(l8_224078_20200518) as src: - data = src.gw.set_nodata(0, NODATA, dtype="uint16") - tasks = [ - gw.save( - data, - filename=out_path, - tags={"TEST_METADATA": "TEST_VALUE"}, - compress="lzw", - num_workers=2, - compute=False, - overwrite=True, - ) - ] - dask.compute(tasks, num_workers=2) - with gw.open(out_path) as tmp_src: - # Compare array values - self.assertTrue(data.equals(tmp_src)) - # Compare attributes - self.assertTrue( - data.gw.nodataval == tmp_src.gw.nodataval == NODATA - ) - self.assertEqual(data.gw.dtype, tmp_src.dtype) - self.assertTrue(hasattr(tmp_src, "TEST_METADATA")) - self.assertEqual(tmp_src.TEST_METADATA, "TEST_VALUE") - - with rio.open(out_path) as rio_src: - self.assertTrue(rio_src.nodata == NODATA) - def test_mosaic_save_single_band(self): filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2] From 083cb3e799fb71c9ca74192155973022b4bf3689 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 17:48:12 +1000 Subject: [PATCH 05/26] update stac tests --- src/geowombat/core/stac.py | 191 +++++++++++------------ tests/test_stac.py | 299 ++++++++++++++++++++++--------------- 2 files changed, 279 insertions(+), 211 deletions(-) diff --git a/src/geowombat/core/stac.py b/src/geowombat/core/stac.py index 969950a7..2f182e67 100644 --- a/src/geowombat/core/stac.py +++ b/src/geowombat/core/stac.py @@ -45,40 +45,82 @@ warnings.warn(e) -class STACNames(enum.Enum): +class StrEnum(str, enum.Enum): + """ + Source: + https://github.com/irgeek/StrEnum/blob/master/strenum/__init__.py + """ + + def __new__(cls, value, *args, **kwargs): + return super().__new__(cls, value, *args, **kwargs) + + def __str__(self) -> str: + return self.value + + +class STACNames(StrEnum): """STAC names.""" - element84_v0 = 'element84_v0' - element84_v1 = 
'element84_v1' - microsoft_v1 = 'microsoft_v1' + ELEMENT84_V0 = 'element84_v0' + ELEMENT84_V1 = 'element84_v1' + MICROSOFT_V1 = 'microsoft_v1' -class STACCollections(enum.Enum): - cop_dem_glo_30 = 'cop_dem_glo_30' - landsat_c2_l1 = 'landsat_c2_l1' - landsat_c2_l2 = 'landsat_c2_l2' - sentinel_s2_l2a = 'sentinel_s2_l2a' - sentinel_s2_l2a_cogs = 'sentinel_s2_l2a_cogs' - sentinel_s2_l1c = 'sentinel_s2_l1c' - sentinel_s1_l1c = 'sentinel_s1_l1c' - sentinel_3_lst = 'sentinel_3_lst' - landsat_l8_c2_l2 = 'landsat_l8_c2_l2' - usda_cdl = 'usda_cdl' - io_lulc = 'io_lulc' +class STACCollections(StrEnum): + # Copernicus DEM GLO-30 + COP_DEM_GLO_30 = 'cop_dem_glo_30' + # All Landsat, Collection 2, Level 1 + LANDSAT_C2_L1 = 'landsat_c2_l1' + # All Landsat, Collection 2, Level 2 (surface reflectance) + LANDSAT_C2_L2 = 'landsat_c2_l2' + # Sentinel-2, Level 2A (surface reflectance missing cirrus band) + SENTINEL_S2_L2A = 'sentinel_s2_l2a' + SENTINEL_S2_L2A_COGS = 'sentinel_s2_l2a_cogs' + # Sentinel-2, Level 1C (top of atmosphere with all 13 bands available) + SENTINEL_S2_L1C = 'sentinel_s2_l1c' + # Sentinel-1, Level 1C Ground Range Detected (GRD) + SENTINEL_S1_L1C = 'sentinel_s1_l1c' + SENTINEL_3_LST = 'sentinel_3_lst' + LANDSAT_L8_C2_L2 = 'landsat_l8_c2_l2' + USDA_CDL = 'usda_cdl' + IO_LULC = 'io_lulc' + NAIP = 'naip' + + +class STACCollectionURLNames(StrEnum): + # Copernicus DEM GLO-30 + COP_DEM_GLO_30 = STACCollections.COP_DEM_GLO_30.replace('_', '-') + # All Landsat, Collection 2, Level 1 + LANDSAT_C2_L1 = STACCollections.LANDSAT_C2_L1.replace('_', '-') + # All Landsat, Collection 2, Level 2 (surface reflectance) + LANDSAT_C2_L2 = STACCollections.LANDSAT_C2_L2.replace('_', '-') + # Sentinel-2, Level 2A (surface reflectance missing cirrus band) + SENTINEL_S2_L2A = 'sentinel-2-l2a' + SENTINEL_S2_L2A_COGS = STACCollections.SENTINEL_S2_L2A_COGS.replace( + '_', '-' + ) + # Sentinel-2, Level 1C (top of atmosphere with all 13 bands available) + SENTINEL_S2_L1C = 'sentinel-2-l1c' + # 
Sentinel-1, Level 1C Ground Range Detected (GRD) + SENTINEL_S1_L1C = 'sentinel-1-grd' + SENTINEL_3_LST = 'sentinel-3-slstr-lst-l2-netcdf' + LANDSAT_L8_C2_L2 = 'landsat-8-c2-l2' + USDA_CDL = STACCollections.USDA_CDL.replace('_', '-') + IO_LULC = STACCollections.IO_LULC.replace('_', '-') + NAIP = STACCollections.NAIP STAC_CATALOGS = { - STACNames.element84_v0: 'https://earth-search.aws.element84.com/v0', - STACNames.element84_v1: 'https://earth-search.aws.element84.com/v1', + STACNames.ELEMENT84_V0: 'https://earth-search.aws.element84.com/v0', + STACNames.ELEMENT84_V1: 'https://earth-search.aws.element84.com/v1', # STACNames.google: 'https://earthengine.openeo.org/v1.0', - STACNames.microsoft_v1: 'https://planetarycomputer.microsoft.com/api/stac/v1', + STACNames.MICROSOFT_V1: 'https://planetarycomputer.microsoft.com/api/stac/v1', } - STAC_SCALING = { - STACCollections.landsat_c2_l2: { + STACCollections.LANDSAT_C2_L2: { # https://planetarycomputer.microsoft.com/dataset/landsat-c2-l2 - STACNames.microsoft_v1: { + STACNames.MICROSOFT_V1: { 'gain': 0.0000275, 'offset': -0.2, 'nodata': 0, @@ -87,60 +129,26 @@ class STACCollections(enum.Enum): } STAC_COLLECTIONS = { - # Copernicus DEM GLO-30 - STACCollections.cop_dem_glo_30: { - STACNames.element84_v1: 'cop-dem-glo-30', - STACNames.microsoft_v1: 'cop-dem-glo-30', - }, - # All Landsat, Collection 2, Level 1 - STACCollections.landsat_c2_l1: { - STACNames.microsoft_v1: 'landsat-c2-l1', - }, - # All Landsat, Collection 2, Level 2 (surface reflectance) - STACCollections.landsat_c2_l2: { - STACNames.element84_v1: 'landsat-c2-l2', - # STACNames.google: [ - # 'LC09/C02/T1_L2', - # 'LC08/C02/T1_L2', - # 'LE07/C02/T1_L2', - # 'LT05/C02/T1_L2', - # ], - STACNames.microsoft_v1: 'landsat-c2-l2', - }, - # Sentinel-2, Level 2A (surface reflectance missing cirrus band) - STACCollections.sentinel_s2_l2a_cogs: { - STACNames.element84_v0: 'sentinel-s2-l2a-cogs', - }, - STACCollections.sentinel_s2_l2a: { - STACNames.element84_v1: 
'sentinel-2-l2a', - # STACNames.google: 'COPERNICUS/S2_SR', - STACNames.microsoft_v1: 'sentinel-2-l2a', - }, - # Sentinel-2, Level 1C (top of atmosphere with all 13 bands available) - STACCollections.sentinel_s2_l1c: { - STACNames.element84_v1: 'sentinel-2-l1c' - }, - # Sentinel-1, Level 1C Ground Range Detected (GRD) - STACCollections.sentinel_s1_l1c: { - STACNames.element84_v1: 'sentinel-1-grd', - STACNames.microsoft_v1: 'sentinel-1-grd', - }, - STACCollections.sentinel_3_lst: { - STACNames.microsoft_v1: 'sentinel-3-slstr-lst-l2-netcdf', - }, - # Landsat 8, Collection 2, Tier 1 (Level 2 (surface reflectance)) - STACCollections.landsat_l8_c2_l2: { - # STACNames.google: 'LC08_C02_T1_L2', - STACNames.microsoft_v1: 'landsat-8-c2-l2', - }, - # USDA CDL - STACCollections.usda_cdl: { - STACNames.microsoft_v1: 'usda-cdl', - }, - # Esri 10 m land cover - STACCollections.io_lulc: { - STACNames.microsoft_v1: 'io-lulc', - }, + STACNames.ELEMENT84_V0: (STACCollectionURLNames.SENTINEL_S2_L2A_COGS,), + STACNames.ELEMENT84_V1: ( + STACCollectionURLNames.COP_DEM_GLO_30, + STACCollectionURLNames.LANDSAT_C2_L2, + STACCollectionURLNames.SENTINEL_S2_L2A, + STACCollectionURLNames.SENTINEL_S2_L1C, + STACCollectionURLNames.SENTINEL_S1_L1C, + STACCollectionURLNames.NAIP, + ), + STACNames.MICROSOFT_V1: ( + STACCollectionURLNames.COP_DEM_GLO_30, + STACCollectionURLNames.LANDSAT_C2_L1, + STACCollectionURLNames.LANDSAT_C2_L2, + STACCollectionURLNames.SENTINEL_S2_L2A, + STACCollectionURLNames.SENTINEL_S1_L1C, + STACCollectionURLNames.SENTINEL_3_LST, + STACCollectionURLNames.LANDSAT_L8_C2_L2, + STACCollectionURLNames.USDA_CDL, + STACCollectionURLNames.IO_LULC, + ), } @@ -208,7 +216,7 @@ def _download_worker(item, extra: str, out_path: _Path) -> dict: def open_stac( - stac_catalog: str = 'microsoft_v1', + stac_catalog: str = STACNames.ELEMENT84_V1, collection: str = None, bounds: T.Union[T.Sequence[float], str, _Path, gpd.GeoDataFrame] = None, proj_bounds: T.Sequence[float] = None, @@ -244,6 
+252,7 @@ def open_stac( sentinel_s2_l2a sentinel_s2_l1c sentinel_s1_l1c + naip microsoft_v1: cop_dem_glo_30 landsat_c2_l1 @@ -341,7 +350,7 @@ def open_stac( bounds = tuple(bounds.total_bounds.flatten().tolist()) try: - stac_catalog_url = STAC_CATALOGS[STACNames(stac_catalog)] + stac_catalog_url = STAC_CATALOGS[stac_catalog] # Open the STAC catalog catalog = _Client.open(stac_catalog_url) except ValueError as e: @@ -349,20 +358,15 @@ def open_stac( f'The STAC catalog {stac_catalog} is not supported ({e}).' ) - try: - collection_dict = STAC_COLLECTIONS[STACCollections(collection)] - except ValueError as e: - raise NameError( - f'The STAC collection {collection} is not supported ({e}).' - ) + if ( + STACCollectionURLNames[STACCollections(collection).name] + not in STAC_COLLECTIONS[stac_catalog] + ): + raise NameError(f'The STAC collection {collection} is not supported.') - try: - catalog_collections = [collection_dict[STACNames(stac_catalog)]] - except KeyError as e: - raise NameError( - f'The STAC catalog {stac_catalog} does not have a collection {collection} ({e}).' 
- ) - # asset = catalog.get_collection(catalog_collections[0]).assets['geoparquet-items'] + catalog_collections = [ + STACCollectionURLNames[STACCollections(collection).name] + ] query = None if cloud_cover_perc is not None: @@ -382,7 +386,7 @@ def open_stac( raise ValueError('No items found.') if list(search.items()): - if STACNames(stac_catalog) is STACNames.microsoft_v1: + if STACNames(stac_catalog) == STACNames.MICROSOFT_V1: items = pc.sign(search) else: items = pystac.ItemCollection(items=list(search.items())) @@ -394,6 +398,7 @@ def open_stac( ) except pystac_errors.ExtensionNotImplemented: selected_item = items.items[0] + table = _Table("Asset Key", "Description") for asset_key, asset in selected_item.assets.items(): table.add_row(asset_key, asset.title) @@ -492,4 +497,6 @@ def open_stac( return data, df + warnings.warn("No asset items were found.") + return None, None diff --git a/tests/test_stac.py b/tests/test_stac.py index 591dcad7..c4c0263c 100644 --- a/tests/test_stac.py +++ b/tests/test_stac.py @@ -1,21 +1,28 @@ +# flake8: noqa + import tempfile import unittest from pathlib import Path import geopandas as gpd +import numpy as np import validators +from dask.distributed import Client, LocalCluster +from pyproj import CRS from rasterio.enums import Resampling from shapely.geometry import CAP_STYLE, JOIN_STYLE, shape +import geowombat as gw from geowombat.core.stac import ( STAC_CATALOGS, STAC_COLLECTIONS, STACCollections, + STACCollectionURLNames, STACNames, open_stac, ) -geojson = { +search_geojson = { "type": "Polygon", "coordinates": [ [ @@ -28,15 +35,47 @@ ], } -EPSG = 8857 +naip_geojson = { + "type": "Polygon", + "coordinates": [ + [ + [-86.65852429273222, 40.869286853632445], + [-86.65852429273222, 40.85430596003397], + [-86.63824797576922, 40.85430596003397], + [-86.63824797576922, 40.869286853632445], + [-86.65852429273222, 40.869286853632445], + ] + ], +} + + +def geosjon_to_df( + geojson: dict, + epsg: int, +): + df = 
gpd.GeoDataFrame(geometry=[shape(geojson)], crs=4326) + proj_df = df.to_crs(f'epsg:{epsg}') + df = ( + proj_df.buffer( + 100, cap_style=CAP_STYLE.square, join_style=JOIN_STYLE.mitre + ) + .to_crs('epsg:4326') + .to_frame(name='geometry') + ) + + return df, tuple(proj_df.total_bounds.flatten().tolist()) -DF = gpd.GeoDataFrame(geometry=[shape(geojson)], crs=4326) -SEARCH_DF = ( - DF.to_crs(f'epsg:{EPSG}') - .buffer(100, cap_style=CAP_STYLE.square, join_style=JOIN_STYLE.mitre) - .to_crs('epsg:4326') - .to_frame(name='geometry') +SEARCH_EPSG = 8857 +SEARCH_DF, SEARCH_BOUNDS = geosjon_to_df( + geojson=search_geojson, + epsg=SEARCH_EPSG, +) + +NAIP_EPSG = 8858 +NAIP_DF, NAIP_PROJ_BOUNDS = geosjon_to_df( + geojson=naip_geojson, + epsg=NAIP_EPSG, ) @@ -47,65 +86,78 @@ def url_is_valid(url: str) -> bool: return False -class TestDownloadSingleBand(unittest.TestCase): - # def test_download_sentinel_3_lst(self): - # stack = open_stac( - # stac_catalog='microsoft', - # bounds=SEARCH_DF, - # proj_bounds=tuple( - # DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - # ), - # epsg=EPSG, - # collection='sentinel_3_lst', - # bands=['lst-in'], - # cloud_cover_perc=90, - # chunksize=64, - # start_date='2022-07-01', - # end_date='2022-07-07', - # resolution=300.0, - # nodata_fill=32768, - # resampling=Resampling.nearest, - # max_items=None, - # )[0] - # self.assertTrue(stack.shape == (20, 1, 3, 4)) - # self.assertTrue(stack.crs == 'epsg:8857') - # self.assertTrue(stack.gw.celly == 300.0) - # self.assertTrue(stack.gw.cellx == 300.0) - # self.assertTrue(stack.gw.nodataval == 32768) - - # def test_download_blue_sentinel_s2_l1c(self): - # stack = open_stac( - # stac_catalog='element84_v1', - # bounds=SEARCH_DF, - # proj_bounds=tuple( - # DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - # ), - # epsg=EPSG, - # collection='sentinel_s2_l1c', - # bands=['blue'], - # cloud_cover_perc=90, - # chunksize=64, - # start_date='2022-07-01', - # end_date='2022-07-07', - # 
resolution=10.0, - # nodata_fill=32768, - # resampling=Resampling.nearest, - # max_items=None, - # )[0] - # self.assertTrue(stack.shape == (2, 1, 48, 64)) - # self.assertTrue(stack.crs == 'epsg:8857') - # self.assertTrue(stack.gw.celly == 10.0) - # self.assertTrue(stack.gw.cellx == 10.0) - # self.assertTrue(stack.gw.nodataval == 32768) - - def test_download_ms_landsat_c2_l2(self): +class TestSearchSingleBand(unittest.TestCase): + def test_search_sentinel_3_lst(self): + stack = open_stac( + stac_catalog='microsoft_v1', + bounds=SEARCH_DF, + proj_bounds=SEARCH_BOUNDS, + epsg=SEARCH_EPSG, + collection='sentinel_3_lst', + bands=['lst-in'], + cloud_cover_perc=90, + chunksize=64, + start_date='2022-07-01', + end_date='2022-07-07', + resolution=300.0, + nodata_fill=32768, + resampling=Resampling.nearest, + max_items=None, + )[0] + + self.assertTrue(stack.shape == (20, 1, 3, 4)) + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue(stack.gw.celly == 300.0) + self.assertTrue(stack.gw.cellx == 300.0) + self.assertTrue(stack.gw.nodataval == 32768) + + def test_search_blue_sentinel_s2_l1c(self): + stack = open_stac( + stac_catalog='element84_v1', + bounds=SEARCH_DF, + proj_bounds=SEARCH_BOUNDS, + epsg=SEARCH_EPSG, + collection='sentinel_s2_l1c', + bands=['blue'], + cloud_cover_perc=90, + chunksize=64, + start_date='2022-07-01', + end_date='2022-07-07', + resolution=10.0, + nodata_fill=32768, + resampling=Resampling.nearest, + max_items=None, + )[0] + self.assertTrue(stack.shape == (2, 1, 48, 64)) + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue(stack.gw.celly == 10.0) + self.assertTrue(stack.gw.cellx == 10.0) + self.assertTrue(stack.gw.nodataval == 32768) + + def test_search_sentinel_s1_l1c(self): + stack = open_stac( + stac_catalog='element84_v1', + bounds=SEARCH_DF, + proj_bounds=SEARCH_BOUNDS, + epsg=SEARCH_EPSG, + collection='sentinel_s1_l1c', + bands=['blue'], + cloud_cover_perc=90, + chunksize=64, + 
start_date='2022-07-01', + end_date='2022-07-07', + resolution=10.0, + nodata_fill=32768, + resampling=Resampling.nearest, + max_items=None, + )[0] + + def test_search_ms_landsat_c2_l2(self): stack, df = open_stac( stac_catalog='microsoft_v1', bounds=SEARCH_DF, - proj_bounds=tuple( - DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - ), - epsg=EPSG, + proj_bounds=SEARCH_BOUNDS, + epsg=SEARCH_EPSG, collection='landsat_c2_l2', bands=['red', 'nir'], cloud_cover_perc=90, @@ -126,14 +178,12 @@ def test_download_ms_landsat_c2_l2(self): stack, df = open_stac( stac_catalog='microsoft_v1', bounds=SEARCH_DF, - proj_bounds=tuple( - DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - ), - epsg=EPSG, + proj_bounds=SEARCH_BOUNDS, + epsg=SEARCH_EPSG, collection='landsat_c2_l2', bands=['red', 'nir08'], cloud_cover_perc=90, - chunksize=64, + chunksize=32, start_date='2022-07-01', end_date='2022-07-07', resolution=10.0, @@ -148,14 +198,14 @@ def test_download_ms_landsat_c2_l2(self): set(stack.band.values).difference(['red', 'nir08']) ) self.assertTrue(stack.shape == (2, 2, 48, 64)) - self.assertTrue(stack.crs == 'epsg:8857') + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) self.assertTrue(stack.gw.celly == 10.0) self.assertTrue(stack.gw.cellx == 10.0) self.assertTrue(stack.gw.nodataval == 32768) self.assertTrue(len(df.index) == 2) self.assertFalse(set(df.id.values).difference(df.id.values)) - def test_download_blue_sentinel_s2_l2a(self): + def test_search_blue_sentinel_s2_l2a(self): with tempfile.TemporaryDirectory() as tmp_path: stack, df = open_stac( stac_catalog='element84_v1', @@ -163,11 +213,11 @@ def test_download_blue_sentinel_s2_l2a(self): proj_bounds=tuple( DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() ), - epsg=EPSG, + epsg=SEARCH_EPSG, collection='sentinel_s2_l2a', bands=['blue'], cloud_cover_perc=90, - chunksize=64, + chunksize=32, start_date='2022-07-01', end_date='2022-07-07', resolution=10.0, @@ -179,65 +229,76 @@ def 
test_download_blue_sentinel_s2_l2a(self): ) self.assertTrue(stack.shape == (2, 1, 48, 64)) - self.assertTrue(stack.crs == 'epsg:8857') + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) self.assertTrue(stack.gw.celly == 10.0) self.assertTrue(stack.gw.cellx == 10.0) self.assertTrue(stack.gw.nodataval == 32768) self.assertTrue(len(df.index) == 2) self.assertFalse(set(df.id.values).difference(df.id.values)) - # def test_download_blue_sentinel_s2_l2a_cogs(self): - # with tempfile.TemporaryDirectory() as tmp_path: - # stack = open_stac( - # stac_catalog='element84_v0', - # bounds=SEARCH_DF, - # proj_bounds=tuple( - # DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - # ), - # epsg=EPSG, - # collection='sentinel_s2_l2a_cogs', - # bands=['B02'], - # cloud_cover_perc=90, - # chunksize=64, - # start_date='2022-07-01', - # end_date='2022-07-07', - # resolution=10.0, - # nodata_fill=32768, - # resampling=Resampling.nearest, - # max_items=None, - # out_path=Path(tmp_path), - # )[0] - # self.assertTrue(stack.shape == (2, 1, 48, 64)) - # self.assertTrue(stack.crs == 'epsg:8857') - # self.assertTrue(stack.gw.celly == 10.0) - # self.assertTrue(stack.gw.cellx == 10.0) - # self.assertTrue(stack.gw.nodataval == 32768) + out_path = Path(tmp_path) / 'test.tif' + time_mean = stack.mean(dim='time', keep_attrs=True) + time_mean.gw.save( + filename=out_path, + overwrite=True, + ) + with gw.open(out_path) as src: + self.assertTrue( + np.allclose( + time_mean.data.compute(), + src.data.compute(), + ) + ) + + out_path_client = Path(tmp_path) / 'test_client.tif' + with LocalCluster( + processes=True, + n_workers=2, + threads_per_worker=1, + memory_limit="1GB", + ) as cluster: + with Client(cluster) as client: + time_mean.gw.save( + filename=out_path_client, + overwrite=True, + client=client, + ) + + with gw.open(out_path_client) as src: + self.assertTrue( + np.allclose( + time_mean.data.compute(), + src.data.compute(), + ) + ) class TestSTAC(unittest.TestCase): def 
test_unsupported(self): + # Element84 v1 does have Sentinel2 + self.assertTrue( + STACCollectionURLNames[STACCollections('sentinel_s2_l2a').name] + in STAC_COLLECTIONS[STACNames.ELEMENT84_V1] + ) + # Element84 v1 does not have Sentinel2 COGs - collection = 'element84_v1' - stac_catalog = 'sentinel_s2_l2a_cogs' - collection_dict = STAC_COLLECTIONS[STACCollections(stac_catalog)] - with self.assertRaises(KeyError): - catalog_collections = [collection_dict[STACNames(collection)]] - # Element84 does not have a usda_cdl collection - collection = 'element84_v1' - stac_catalog = 'usda_cdl' - collection_dict = STAC_COLLECTIONS[STACCollections(stac_catalog)] - with self.assertRaises(KeyError): - catalog_collections = [collection_dict[STACNames(collection)]] - - def test_constants(self): - self.assertEqual(STACNames('element84_v0'), STACNames.element84_v0) - self.assertEqual(STACNames('element84_v1'), STACNames.element84_v1) - self.assertEqual(STACNames('microsoft_v1'), STACNames.microsoft_v1) + self.assertFalse( + STACCollectionURLNames[ + STACCollections('sentinel_s2_l2a_cogs').name + ] + in STAC_COLLECTIONS[STACNames.ELEMENT84_V1] + ) + + # Element84 does not have a USDA CDL collection + self.assertFalse( + STACCollectionURLNames[STACCollections('usda_cdl').name] + in STAC_COLLECTIONS[STACNames.ELEMENT84_V1] + ) def test_urls(self): - self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.element84_v0])) - self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.element84_v1])) - self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.microsoft_v1])) + self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.ELEMENT84_V0])) + self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.ELEMENT84_V1])) + self.assertTrue(url_is_valid(STAC_CATALOGS[STACNames.MICROSOFT_V1])) if __name__ == '__main__': From 6207071affc574526fb5ba77d28026608103c22f Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 18:38:00 +1000 Subject: [PATCH 06/26] add support to write time with scatter --- 
src/geowombat/backends/rasterio_.py | 221 ++++++++++++++++++++++++++-- 1 file changed, 205 insertions(+), 16 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index deed4549..4f4e074c 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -16,6 +16,7 @@ from pyproj import CRS from pyproj.exceptions import CRSError from rasterio.coords import BoundingBox +from rasterio.drivers import driver_from_extension from rasterio.enums import Resampling from rasterio.transform import array_bounds, from_bounds from rasterio.vrt import WarpedVRT @@ -29,6 +30,8 @@ import geowombat as gw +from ..config import config + try: import numcodecs import zarr @@ -170,16 +173,140 @@ class RasterioStore(object): def __init__( self, + data: xr.DataArray, filename: T.Union[str, Path], + scatter: T.Optional[str] = None, tags: dict = None, - **kwargs, + compress: T.Optional[str] = "none", + bigtiff: T.Optional[str] = None, ): + self.data = data self.filename = Path(filename) + self.scatter = scatter self.tags = tags - self.kwargs = kwargs + self.compress = compress + self.bigtiff = bigtiff + + def _setup(self): + if self.scatter is None: + if len(self.data.shape) > 3: + raise ValueError( + "Only 3-band arrays can be written when scatter=None." + ) + + if hasattr(self.data, "_FillValue"): + nodata = self.data.attrs["_FillValue"] + else: + if hasattr(self.data, "nodatavals"): + nodata = self.data.attrs["nodatavals"][0] + else: + raise AttributeError( + "The DataArray does not have any 'no data' attributes." 
+ ) + + dtype = ( + self.data.dtype.name + if isinstance(self.data.dtype, np.dtype) + else self.data.dtype + ) + if isinstance(nodata, float): + if dtype != "float32": + dtype = "float64" + + if self.scatter is None: + band_count = self.data.gw.nbands + else: + if self.scatter == 'band': + self.data = self.data.chunk( + { + 'time': 1, + 'band': -1, + 'y': self.data.gw.row_chunks, + 'x': self.data.gw.col_chunks, + } + ) + band_count = self.data.gw.ntime + elif self.scatter == 'time': + self.data = self.data.chunk( + { + 'time': -1, + 'band': 1, + 'y': self.data.gw.row_chunks, + 'x': self.data.gw.col_chunks, + } + ) + band_count = self.data.gw.nbands + + blockxsize = ( + self.data.gw.check_chunksize(512, self.data.gw.ncols) + if not self.data.gw.array_is_dask + else self.data.gw.col_chunks + ) + blockysize = ( + self.data.gw.check_chunksize(512, self.data.gw.nrows) + if not self.data.gw.array_is_dask + else self.data.gw.row_chunks + ) + + tiled = True + bigtiff = self.bigtiff + compress = self.compress + if config["with_config"]: + if config["bigtiff"] is not None: + if isinstance(config["bigtiff"], bool): + bigtiff = "YES" if config["bigtiff"] else "NO" + else: + bigtiff = config["bigtiff"].upper() + + if bigtiff not in ( + "YES", + "NO", + "IF_NEEDED", + "IF_SAFER", + ): + raise NameError( + "The GDAL BIGTIFF must be one of 'YES', 'NO', 'IF_NEEDED', or 'IF_SAFER'. See https://gdal.org/drivers/raster/gtiff.html#creation-issues for more information." 
+ ) + + if config["compress"] is not None: + compress = config["compress"] + + if config["tiled"] is not None: + tiled = config["tiled"] + + self.kwargs = dict( + driver=driver_from_extension(self.filename), + width=self.data.gw.ncols, + height=self.data.gw.nrows, + count=band_count, + dtype=dtype, + nodata=nodata, + blockxsize=blockxsize, + blockysize=blockysize, + crs=self.data.gw.crs_to_pyproj, + transform=self.data.gw.transform, + compress=compress, + tiled=tiled if max(blockxsize, blockysize) >= 16 else False, + sharing=False, + BIGTIFF=bigtiff, + ) def __setitem__(self, key: tuple, item: np.ndarray) -> None: - if len(key) == 3: + if len(key) == 4: + if self.scatter == 'band': + index_range, _, y, x = key + else: + _, index_range, y, x = key + + indexes = list( + range( + index_range.start + 1, + index_range.stop + 1, + index_range.step or 1, + ) + ) + + elif len(key) == 3: index_range, y, x = key indexes = list( range( @@ -188,6 +315,7 @@ def __setitem__(self, key: tuple, item: np.ndarray) -> None: index_range.step or 1, ) ) + else: indexes = 1 y, x = key @@ -203,6 +331,7 @@ def __setitem__(self, key: tuple, item: np.ndarray) -> None: def __enter__(self) -> 'RasterioStore': self.closed = False + self._setup() return self._open() @@ -216,9 +345,31 @@ def _open(self) -> 'RasterioStore': def _create_image(self) -> None: mode = 'r+' if self.filename.exists() else 'w' - with rio.open(self.filename, mode=mode, **self.kwargs) as dst: - if self.tags is not None: - dst.update_tags(**self.tags) + if self.scatter == 'band': + for band_name in self.data.band.values: + with rio.open( + self.get_band_filename(band_name), + mode=mode, + **self.kwargs, + ) as dst: + if self.tags is not None: + dst.update_tags(**self.tags) + elif self.scatter == 'time': + for band_name in self.data.time.values: + with rio.open( + self.get_band_filename(band_name), + mode=mode, + **self.kwargs, + ) as dst: + if self.tags is not None: + dst.update_tags(**self.tags) + else: + with 
rio.open(self.filename, mode=mode, **self.kwargs) as dst: + if self.tags is not None: + dst.update_tags(**self.tags) + + def get_band_filename(self, band_name: str) -> Path: + return self.filename.parent / f"{self.filename.stem}_{band_name}.tif" def _write_window( self, @@ -226,20 +377,58 @@ def _write_window( indexes: T.Union[int, np.ndarray], window: T.Optional[Window] = None, ) -> None: - with rio.open(self.filename, mode='r+', **self.kwargs) as dst: - dst.write( - data, - indexes=indexes, - window=window, + if self.scatter in ( + 'band', + 'time', + ): + band_name_iter = ( + self.data.band.values + if self.scatter == 'band' + else self.data.time.values ) + for i, band_name in enumerate(band_name_iter): + if self.scatter == 'band': + # Take all time and the ith band + band_slice = ( + slice(0, None), + slice(i, i + 1), + ) + data_ = data[band_slice].squeeze(axis=1) + else: + # Take all bands and the ith time + band_slice = ( + slice(i, i + 1), + slice(0, None), + ) + data_ = data[band_slice].squeeze(axis=0) + + with rio.open( + self.get_band_filename(band_name), + mode='r+', + **self.kwargs, + ) as dst: + dst.write( + data_, + indexes=indexes, + window=window, + ) + else: + with rio.open(self.filename, mode='r+', **self.kwargs) as dst: + dst.write( + data, + indexes=indexes, + window=window, + ) - def write(self, data: xr.DataArray, compute: bool = False) -> Delayed: - if isinstance(data.data, da.Array): - return da.store(data.data, self, lock=self.lock_, compute=compute) + def write(self, compute: bool = False) -> Delayed: + if isinstance(self.data.data, da.Array): + return da.store( + self.data.data, self, lock=self.lock_, compute=compute + ) else: self._write_window( - data.data, - indexes=list(range(1, data.data.shape[0] + 1)), + self.data.data, + indexes=list(range(1, self.data.data.shape[0] + 1)), ) From b958e50e99cb333ad9bb0f1fb4c1d709b2ceb383 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 18:38:47 +1000 Subject: [PATCH 07/26] add support to 
write time with scatter --- src/geowombat/core/geoxarray.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/geowombat/core/geoxarray.py b/src/geowombat/core/geoxarray.py index ec66b029..1e50181f 100644 --- a/src/geowombat/core/geoxarray.py +++ b/src/geowombat/core/geoxarray.py @@ -728,8 +728,8 @@ def to_netcdf( def save( self, filename: T.Union[str, _Path], - nodata: T.Optional[T.Union[float, int]] = None, overwrite: bool = False, + scatter: T.Optional[str] = None, client: T.Optional[_Client] = None, compute: bool = True, tags: T.Optional[dict] = None, @@ -747,6 +747,7 @@ def save( nodata (Optional[float | int]): The 'no data' value. If ``None`` (default), the 'no data' value is taken from the ``DataArray`` metadata. overwrite (Optional[bool]): Whether to overwrite an existing file. Default is False. + scatter (Optional[str]): Scatter 'band' or 'time' to separate file. Default is None. client (Optional[Client object]): A ``dask.distributed.Client`` client object to persist data. Default is None. compute (Optinoal[bool]): Whether to compute and write to ``filename``. 
Otherwise, return @@ -786,8 +787,8 @@ def save( save( self._obj, filename=filename, - nodata=nodata, overwrite=overwrite, + scatter=scatter, client=client, compute=compute, tags=tags, From d8962102ce07c2459c2c914300c985c6e1534398 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 18:39:03 +1000 Subject: [PATCH 08/26] add support to write time with scatter --- src/geowombat/core/io.py | 87 ++++++---------------------------------- 1 file changed, 12 insertions(+), 75 deletions(-) diff --git a/src/geowombat/core/io.py b/src/geowombat/core/io.py index 47588875..316509c9 100644 --- a/src/geowombat/core/io.py +++ b/src/geowombat/core/io.py @@ -21,7 +21,6 @@ from dask.distributed import Client, progress from osgeo import gdal from rasterio import shutil as rio_shutil -from rasterio.drivers import driver_from_extension from rasterio.enums import Resampling from rasterio.vrt import WarpedVRT from rasterio.windows import Window @@ -39,7 +38,6 @@ ZARR_INSTALLED = False from ..backends.rasterio_ import RasterioStore, to_gtiff -from ..config import config from ..handler import add_handler from .windows import get_window_offsets @@ -683,8 +681,8 @@ def to_netcdf( def save( data: xr.DataArray, filename: T.Union[str, Path], - nodata: T.Optional[T.Union[float, int]] = None, overwrite: bool = False, + scatter: T.Optional[str] = None, client: T.Optional[Client] = None, compute: bool = True, tags: T.Optional[dict] = None, @@ -698,10 +696,10 @@ def save( """Saves a DataArray to raster using rasterio/dask. Args: + data (xarray.DataArray): The data to write. filename (str | Path): The output file name to write to. overwrite (Optional[bool]): Whether to overwrite an existing file. Default is False. - nodata (Optional[float | int]): The 'no data' value. If ``None`` (default), the 'no data' - value is taken from the ``DataArray`` metadata. + scatter (Optional[str]): Scatter 'band' or 'time' to separate file. Default is None. 
client (Optional[Client object]): A ``dask.distributed.Client`` client object to persist data. Default is None. compute (Optinoal[bool]): Whether to compute and write to ``filename``. Otherwise, return @@ -747,22 +745,6 @@ def save( if overwrite: Path(filename).unlink() - if nodata is None: - if hasattr(data, "_FillValue"): - nodata = data.attrs["_FillValue"] - else: - if hasattr(data, "nodatavals"): - nodata = data.attrs["nodatavals"][0] - else: - raise AttributeError( - "The DataArray does not have any 'no data' attributes." - ) - - dtype = data.dtype.name if isinstance(data.dtype, np.dtype) else data.dtype - if isinstance(nodata, float): - if dtype != "float32": - dtype = "float64" - if client is not None: if compress not in ( None, @@ -773,64 +755,19 @@ def save( ) compress = None - blockxsize = ( - data.gw.check_chunksize(512, data.gw.ncols) - if not data.gw.array_is_dask - else data.gw.col_chunks - ) - blockysize = ( - data.gw.check_chunksize(512, data.gw.nrows) - if not data.gw.array_is_dask - else data.gw.row_chunks - ) - - tiled = True - if config["with_config"]: - if config["bigtiff"] is not None: - if isinstance(config["bigtiff"], bool): - bigtiff = "YES" if config["bigtiff"] else "NO" - else: - bigtiff = config["bigtiff"].upper() - - if bigtiff not in ( - "YES", - "NO", - "IF_NEEDED", - "IF_SAFER", - ): - raise NameError( - "The GDAL BIGTIFF must be one of 'YES', 'NO', 'IF_NEEDED', or 'IF_SAFER'. See https://gdal.org/drivers/raster/gtiff.html#creation-issues for more information." 
- ) - - if config["compress"] is not None: - compress = config["compress"] - - if config["tiled"] is not None: - tiled = config["tiled"] - - kwargs = dict( - driver=driver_from_extension(filename), - width=data.gw.ncols, - height=data.gw.nrows, - count=data.gw.nbands, - dtype=dtype, - nodata=nodata, - blockxsize=blockxsize, - blockysize=blockysize, - crs=data.gw.crs_to_pyproj, - transform=data.gw.transform, - compress=compress, - tiled=tiled if max(blockxsize, blockysize) >= 16 else False, - sharing=False, - BIGTIFF=bigtiff, - ) - if tqdm_kwargs is None: tqdm_kwargs = {} - with RasterioStore(filename, tags=tags, **kwargs) as rio_store: + with RasterioStore( + data=data, + filename=filename, + scatter=scatter, + tags=tags, + compress=compress, + bigtiff=bigtiff, + ) as rio_store: # Store the data and return a lazy evaluator - res = rio_store.write(data) + res = rio_store.write() if not compute: return res From c0efc0998321f0c9820f1e574d0cd78735589c0c Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 18:40:19 +1000 Subject: [PATCH 09/26] update tests --- tests/test_write.py | 150 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 122 insertions(+), 28 deletions(-) diff --git a/tests/test_write.py b/tests/test_write.py index b7aa1358..603db0d8 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -244,41 +244,135 @@ def test_config_save(self): with rio.open(out_path) as rio_src: self.assertTrue(rio_src.nodata == NODATA) - def test_save_small(self): + def test_save_scatter_band(self): with tempfile.TemporaryDirectory() as tmp: out_path = Path(tmp) / "test.tif" - with gw.open(l8_224078_20200518) as src: - data = src.gw.set_nodata(0, NODATA, dtype="uint16") - data = data[:, :1, :2] + with gw.open( + [l8_224078_20200518, l8_224078_20200518], + stack_dim='time', + ) as src: - try: - data.gw.save( - filename=out_path, - overwrite=True, - tags={"TEST_METADATA": "TEST_VALUE"}, - compress="none", - num_workers=1, + self.assertTrue(len(src.shape) == 
4) + + src.gw.save( + filename=out_path, + overwrite=True, + scatter='band', + num_workers=1, + ) + + # Each file's band count is equal to the number + # of time dimensions + for band_name in src.band.values: + self.assertTrue( + ( + out_path.parent + / f"{out_path.stem}_{band_name}.tif" + ).exists() ) - except ValueError: - self.fail("The small array write test failed.") - def test_mosaic_save_single_band(self): - filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2] + with gw.open( + out_path.parent / f"{out_path.stem}_1.tif" + ) as src_test: + self.assertTrue(src_test.gw.nbands == src.gw.ntime) + self.assertTrue( + np.allclose( + src_test.sel(band=1), + src.sel(time=1, band=1), + ) + ) + self.assertTrue( + np.allclose( + src_test.sel(band=2), + src.sel(time=2, band=1), + ) + ) + + def test_save_scatter_time(self): + with tempfile.TemporaryDirectory() as tmp: + out_path = Path(tmp) / "test.tif" + with gw.open( + [l8_224078_20200518, l8_224078_20200518], + stack_dim='time', + ) as src: + + self.assertTrue(len(src.shape) == 4) + + src.gw.save( + filename=out_path, + overwrite=True, + scatter='time', + num_workers=1, + ) + + # Each file's band count is equal to the number + # bands, and there are N time files + for band_name in src.time.values: + self.assertTrue( + ( + out_path.parent + / f"{out_path.stem}_{band_name}.tif" + ).exists() + ) - with tempfile.TemporaryDirectory() as temp_dir: - try: with gw.open( - filenames, - band_names=["blue"], - mosaic=True, - bounds_by="union", - nodata=0, - ) as src: - src.gw.save(Path(temp_dir) / "test.tif", overwrite=True) - - except Exception as e: - # If any exception is raised, fail the test with a message - self.fail(f"An error occurred during saving: {e}") + out_path.parent / f"{out_path.stem}_1.tif" + ) as src_test: + self.assertTrue(src_test.gw.nbands == src.gw.nbands) + self.assertTrue( + np.allclose( + src_test.sel(band=1), + src.sel(time=1, band=1), + ) + ) + self.assertTrue( + np.allclose( + 
src_test.sel(band=2), + src.sel(time=1, band=2), + ) + ) + self.assertTrue( + np.allclose( + src_test.sel(band=3), + src.sel(time=1, band=3), + ) + ) + + # def test_save_small(self): + # with tempfile.TemporaryDirectory() as tmp: + # out_path = Path(tmp) / "test.tif" + # with gw.open(l8_224078_20200518) as src: + # data = src.gw.set_nodata(0, NODATA, dtype="uint16") + # data = data[:, :1, :2] + + # try: + # data.gw.save( + # filename=out_path, + # overwrite=True, + # tags={"TEST_METADATA": "TEST_VALUE"}, + # compress="none", + # num_workers=1, + # ) + # except ValueError: + # self.fail("The small array write test failed.") + + # def test_mosaic_save_single_band(self): + # filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2] + + # with tempfile.TemporaryDirectory() as temp_dir: + # try: + # with gw.open( + # filenames, + # band_names=["blue"], + # mosaic=True, + # bounds_by="union", + # nodata=0, + # ) as src: + # src.gw.save(Path(temp_dir) / "test.tif", overwrite=True) + + # except Exception as e: + # # If any exception is raised, fail the test with a message + # self.fail(f"An error occurred during saving: {e}") if __name__ == "__main__": From 533f6cadd505f61470d81c4b26657dc3bc189f7e Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 5 May 2024 18:43:59 +1000 Subject: [PATCH 10/26] remove slice --- src/geowombat/backends/rasterio_.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 4f4e074c..044a1121 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -389,18 +389,10 @@ def _write_window( for i, band_name in enumerate(band_name_iter): if self.scatter == 'band': # Take all time and the ith band - band_slice = ( - slice(0, None), - slice(i, i + 1), - ) - data_ = data[band_slice].squeeze(axis=1) + data_ = data[:, i] else: # Take all bands and the ith time - band_slice = ( - slice(i, i + 1), - slice(0, None), - ) - 
data_ = data[band_slice].squeeze(axis=0) + data_ = data[i] with rio.open( self.get_band_filename(band_name), From 08992ce9b74b5d212034a528d363bc9f18e7d237 Mon Sep 17 00:00:00 2001 From: jgrss Date: Mon, 6 May 2024 18:18:08 +1000 Subject: [PATCH 11/26] fix stac tests --- tests/test_stac.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/test_stac.py b/tests/test_stac.py index c4c0263c..e19ec8ac 100644 --- a/tests/test_stac.py +++ b/tests/test_stac.py @@ -52,10 +52,10 @@ def geosjon_to_df( geojson: dict, epsg: int, -): - df = gpd.GeoDataFrame(geometry=[shape(geojson)], crs=4326) - proj_df = df.to_crs(f'epsg:{epsg}') - df = ( +) -> tuple: + latlon_df = gpd.GeoDataFrame(geometry=[shape(geojson)], crs=4326) + proj_df = latlon_df.to_crs(f'epsg:{epsg}') + buffer_df = ( proj_df.buffer( 100, cap_style=CAP_STYLE.square, join_style=JOIN_STYLE.mitre ) @@ -63,7 +63,7 @@ def geosjon_to_df( .to_frame(name='geometry') ) - return df, tuple(proj_df.total_bounds.flatten().tolist()) + return buffer_df, tuple(proj_df.total_bounds.flatten().tolist()) SEARCH_EPSG = 8857 @@ -106,7 +106,7 @@ def test_search_sentinel_3_lst(self): )[0] self.assertTrue(stack.shape == (20, 1, 3, 4)) - self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(SEARCH_EPSG)) self.assertTrue(stack.gw.celly == 300.0) self.assertTrue(stack.gw.cellx == 300.0) self.assertTrue(stack.gw.nodataval == 32768) @@ -129,7 +129,7 @@ def test_search_blue_sentinel_s2_l1c(self): max_items=None, )[0] self.assertTrue(stack.shape == (2, 1, 48, 64)) - self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(SEARCH_EPSG)) self.assertTrue(stack.gw.celly == 10.0) self.assertTrue(stack.gw.cellx == 10.0) self.assertTrue(stack.gw.nodataval == 32768) @@ -198,7 +198,9 @@ def test_search_ms_landsat_c2_l2(self): set(stack.band.values).difference(['red', 'nir08']) ) 
self.assertTrue(stack.shape == (2, 2, 48, 64)) - self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue( + stack.gw.crs_to_pyproj == CRS.from_epsg(SEARCH_EPSG) + ) self.assertTrue(stack.gw.celly == 10.0) self.assertTrue(stack.gw.cellx == 10.0) self.assertTrue(stack.gw.nodataval == 32768) @@ -210,9 +212,7 @@ def test_search_blue_sentinel_s2_l2a(self): stack, df = open_stac( stac_catalog='element84_v1', bounds=SEARCH_DF, - proj_bounds=tuple( - DF.to_crs(f'epsg:{EPSG}').total_bounds.flatten().tolist() - ), + proj_bounds=SEARCH_BOUNDS, epsg=SEARCH_EPSG, collection='sentinel_s2_l2a', bands=['blue'], @@ -229,7 +229,9 @@ def test_search_blue_sentinel_s2_l2a(self): ) self.assertTrue(stack.shape == (2, 1, 48, 64)) - self.assertTrue(stack.gw.crs_to_pyproj == CRS.from_epsg(EPSG)) + self.assertTrue( + stack.gw.crs_to_pyproj == CRS.from_epsg(SEARCH_EPSG) + ) self.assertTrue(stack.gw.celly == 10.0) self.assertTrue(stack.gw.cellx == 10.0) self.assertTrue(stack.gw.nodataval == 32768) From 33b4ee6dfadf68cf43bfede22881642febb28e97 Mon Sep 17 00:00:00 2001 From: jgrss Date: Mon, 6 May 2024 18:31:26 +1000 Subject: [PATCH 12/26] make method --- src/geowombat/backends/rasterio_.py | 30 ++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 044a1121..09a35b4b 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -343,30 +343,34 @@ def _open(self) -> 'RasterioStore': return self + def _write_file(self, filename: Path, mode: str) -> None: + with rio.open( + filename, + mode=mode, + **self.kwargs, + ) as dst: + if self.tags is not None: + dst.update_tags(**self.tags) + def _create_image(self) -> None: mode = 'r+' if self.filename.exists() else 'w' if self.scatter == 'band': for band_name in self.data.band.values: - with rio.open( + self._write_file( self.get_band_filename(band_name), mode=mode, - 
**self.kwargs, - ) as dst: - if self.tags is not None: - dst.update_tags(**self.tags) + ) elif self.scatter == 'time': for band_name in self.data.time.values: - with rio.open( + self._write_file( self.get_band_filename(band_name), mode=mode, - **self.kwargs, - ) as dst: - if self.tags is not None: - dst.update_tags(**self.tags) + ) else: - with rio.open(self.filename, mode=mode, **self.kwargs) as dst: - if self.tags is not None: - dst.update_tags(**self.tags) + self._write_file( + self.filename, + mode=mode, + ) def get_band_filename(self, band_name: str) -> Path: return self.filename.parent / f"{self.filename.stem}_{band_name}.tif" From 11252eb5d8bbe6f627251d2648eeac7d5001c215 Mon Sep 17 00:00:00 2001 From: jgrss Date: Mon, 6 May 2024 18:56:27 +1000 Subject: [PATCH 13/26] allow compression, but add docstring note --- src/geowombat/core/geoxarray.py | 6 +++ src/geowombat/core/io.py | 16 +++---- tests/test_write.py | 76 ++++++++++++++++----------------- 3 files changed, 50 insertions(+), 48 deletions(-) diff --git a/src/geowombat/core/geoxarray.py b/src/geowombat/core/geoxarray.py index 1e50181f..7d0e7186 100644 --- a/src/geowombat/core/geoxarray.py +++ b/src/geowombat/core/geoxarray.py @@ -755,6 +755,12 @@ def save( return the ``dask`` task graph. Default is ``True``. tags (Optional[dict]): Metadata tags to write to file. Default is None. compress (Optional[str]): The file compression type. Default is 'none', or no compression. + + .. note:: + When using a client, it is advised to use threading. E.g., + ``dask.distributed.LocalCluster(processes=False)``. Process-based concurrency could + result in corrupted file blocks. + compression (Optional[str]): The file compression type. Default is 'none', or no compression. .. deprecated:: 2.1.4 diff --git a/src/geowombat/core/io.py b/src/geowombat/core/io.py index 316509c9..a32b52d6 100644 --- a/src/geowombat/core/io.py +++ b/src/geowombat/core/io.py @@ -707,6 +707,12 @@ def save( return the ``dask`` task graph. 
Default is ``True``. tags (Optional[dict]): Metadata tags to write to file. Default is None. compress (Optional[str]): The file compression type. Default is 'none', or no compression. + + .. note:: + When using a client, it is advised to use threading. E.g., + ``dask.distributed.LocalCluster(processes=False)``. Process-based concurrency could + result in corrupted file blocks. + compression (Optional[str]): The file compression type. Default is 'none', or no compression. .. deprecated:: 2.1.4 @@ -745,16 +751,6 @@ def save( if overwrite: Path(filename).unlink() - if client is not None: - if compress not in ( - None, - "none", - ): - logger.warning( - " Cannot write to a compressed file with a Dask Client(). Data will be uncompressed." - ) - compress = None - if tqdm_kwargs is None: tqdm_kwargs = {} diff --git a/tests/test_write.py b/tests/test_write.py index 603db0d8..5d339965 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -178,9 +178,9 @@ def test_write_numpy(self): def test_client_save(self): with LocalCluster( - processes=True, - n_workers=4, - threads_per_worker=1, + processes=False, + n_workers=2, + threads_per_worker=2, memory_limit="2GB", ) as cluster: with Client(cluster) as client: @@ -338,41 +338,41 @@ def test_save_scatter_time(self): ) ) - # def test_save_small(self): - # with tempfile.TemporaryDirectory() as tmp: - # out_path = Path(tmp) / "test.tif" - # with gw.open(l8_224078_20200518) as src: - # data = src.gw.set_nodata(0, NODATA, dtype="uint16") - # data = data[:, :1, :2] - - # try: - # data.gw.save( - # filename=out_path, - # overwrite=True, - # tags={"TEST_METADATA": "TEST_VALUE"}, - # compress="none", - # num_workers=1, - # ) - # except ValueError: - # self.fail("The small array write test failed.") - - # def test_mosaic_save_single_band(self): - # filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2] - - # with tempfile.TemporaryDirectory() as temp_dir: - # try: - # with gw.open( - # filenames, - # band_names=["blue"], - # 
mosaic=True, - # bounds_by="union", - # nodata=0, - # ) as src: - # src.gw.save(Path(temp_dir) / "test.tif", overwrite=True) - - # except Exception as e: - # # If any exception is raised, fail the test with a message - # self.fail(f"An error occurred during saving: {e}") + def test_save_small(self): + with tempfile.TemporaryDirectory() as tmp: + out_path = Path(tmp) / "test.tif" + with gw.open(l8_224078_20200518) as src: + data = src.gw.set_nodata(0, NODATA, dtype="uint16") + data = data[:, :1, :2] + + try: + data.gw.save( + filename=out_path, + overwrite=True, + tags={"TEST_METADATA": "TEST_VALUE"}, + compress="none", + num_workers=1, + ) + except ValueError: + self.fail("The small array write test failed.") + + def test_mosaic_save_single_band(self): + filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2] + + with tempfile.TemporaryDirectory() as temp_dir: + try: + with gw.open( + filenames, + band_names=["blue"], + mosaic=True, + bounds_by="union", + nodata=0, + ) as src: + src.gw.save(Path(temp_dir) / "test.tif", overwrite=True) + + except Exception as e: + # If any exception is raised, fail the test with a message + self.fail(f"An error occurred during saving: {e}") if __name__ == "__main__": From 03f881bd81246ffc56fe2bcaac0531770383d0f1 Mon Sep 17 00:00:00 2001 From: jgrss Date: Thu, 16 May 2024 20:36:50 +1000 Subject: [PATCH 14/26] add band setting --- src/geowombat/backends/xarray_.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/geowombat/backends/xarray_.py b/src/geowombat/backends/xarray_.py index ef5b4467..ef59cc5b 100644 --- a/src/geowombat/backends/xarray_.py +++ b/src/geowombat/backends/xarray_.py @@ -476,7 +476,6 @@ def reduce_func( xr.where(left != tmp_nodata, left, right), ) - # Open all the data pointers data_arrays = [ open_rasterio( @@ -548,7 +547,6 @@ def reduce_func( attrs.update(tags) darray = darray.assign_attrs(**attrs) - if dtype is not None: attrs = darray.attrs.copy() return 
darray.astype(dtype).assign_attrs(**attrs) @@ -615,7 +613,10 @@ def concat( Returns: ``xarray.DataArray`` """ - if stack_dim.lower() not in ['band', 'time']: + if stack_dim.lower() not in ( + 'band', + 'time', + ): logger.exception(" The stack dimension should be 'band' or 'time'.") with rio_open(filenames[0]) as src_: @@ -690,6 +691,7 @@ def concat( if not concat_list[0].gw.config['ignore_warnings']: check_alignment(concat_list) + # Warp all images and concatenate along the 'time' axis into a DataArray src = xr.concat(concat_list, dim=stack_dim.lower()).assign_coords( time=new_time_names @@ -711,6 +713,7 @@ def concat( ] if not warp_list[0].gw.config['ignore_warnings']: check_alignment(warp_list) + src = xr.concat(warp_list, dim=stack_dim.lower()) src = src.assign_attrs(**{'filename': [Path(fn).name for fn in filenames]}) @@ -740,7 +743,7 @@ def concat( src.coords['time'] = parse_filename_dates(filenames) if band_names: - src.coords['band'] = band_names + src = src.assign_coords(band=band_names) else: if src.gw.sensor: if src.gw.sensor not in src.gw.avail_sensors: @@ -768,6 +771,8 @@ def concat( src = src.assign_attrs( **{'sensor': src.gw.sensor_names[src.gw.sensor]} ) + else: + src = src.assign_coords(band=range(1, src.gw.nbands + 1)) if dtype: attrs = src.attrs.copy() From e01fd54c4581acb1c7894fa3e9ba9af86821616d Mon Sep 17 00:00:00 2001 From: jgrss Date: Thu, 16 May 2024 20:37:31 +1000 Subject: [PATCH 15/26] add check for chunks type --- src/geowombat/backends/xarray_rasterio_.py | 26 +++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/geowombat/backends/xarray_rasterio_.py b/src/geowombat/backends/xarray_rasterio_.py index fc7a8855..3ffb3d7c 100644 --- a/src/geowombat/backends/xarray_rasterio_.py +++ b/src/geowombat/backends/xarray_rasterio_.py @@ -10,6 +10,7 @@ import os import typing as T import warnings +from functools import singledispatch from pathlib import Path import numpy as np @@ -174,6 +175,25 @@ def default(s): 
return parsed_meta +@singledispatch +def check_chunks(chunks: dict) -> dict: + return chunks + + +@check_chunks.register +def _(chunks: tuple) -> dict: + return dict( + zip( + ( + 'band', + 'y', + 'x', + ), + chunks, + ) + ) + + def open_rasterio( filename: T.Union[str, Path, DatasetReader, WarpedVRT], nodata: T.Optional[T.Union[float, int]] = None, @@ -424,7 +444,11 @@ def open_rasterio( mtime = None token = tokenize(filename, mtime, chunks) name_prefix = f"open_rasterio-{token}" - result = result.chunk(chunks, name_prefix=name_prefix, token=token) + result = result.chunk( + check_chunks(chunks), + name_prefix=name_prefix, + token=token, + ) # Make the file closeable result.set_close(manager.close) From 05c9623e05d0c8a1fc790d57048627fc547518fa Mon Sep 17 00:00:00 2001 From: jgrss Date: Thu, 16 May 2024 20:38:05 +1000 Subject: [PATCH 16/26] add tests --- tests/test_open.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_open.py b/tests/test_open.py index fce07234..f486a781 100755 --- a/tests/test_open.py +++ b/tests/test_open.py @@ -91,6 +91,38 @@ def test_open_multiple(self): self.assertEqual(src.gw.nbands, 2) self.assertTrue(src.gw.has_band_dim) self.assertTrue(src.gw.has_band_coord) + self.assertEqual(src.band.values.tolist(), [1, 2]) + + with gw.open( + [ + l8_224078_20200518_B2, + l8_224078_20200518_B2, + l8_224078_20200518_B2, + ], + stack_dim='band', + ) as src: + self.assertEqual(src.band.values.tolist(), [1, 2, 3]) + + with gw.open( + [ + l8_224078_20200518_B2, + l8_224078_20200518_B2, + l8_224078_20200518, + ], + stack_dim='band', + ) as src: + self.assertEqual(src.band.values.tolist(), [1, 2, 3, 4, 5]) + + with gw.open( + [ + l8_224078_20200518_B2, + l8_224078_20200518_B2, + l8_224078_20200518_B2, + ], + stack_dim='band', + band_names=['a', 'b', 'c'], + ) as src: + self.assertEqual(src.band.values.tolist(), ['a', 'b', 'c']) def test_open_multiple_same(self): with gw.open( From 
844e0507a73aeb53c0e419e9e32cc1a4c124e69a Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 13:07:47 +1000 Subject: [PATCH 17/26] add pip flag --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbd88264..f973f94a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,8 +41,8 @@ jobs: pip install -U pip setuptools wheel pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') - pip install GDAL==$GDAL_VERSION --no-cache-dir - pip install arosics + pip install GDAL==$GDAL_VERSION --no-cache-dir --no-binary=gdal + pip install arosics --no-deps - name: Install GeoWombat run: | pip install ".[stac,web,coreg,perf,tests]" From 7dfe030432cf3cb56b03e71e47c7c3fc4a2e17f8 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 13:45:55 +1000 Subject: [PATCH 18/26] reorganize --- src/geowombat/backends/rasterio_.py | 57 ++++++++++---------- src/geowombat/backends/xarray_.py | 80 ++++++++++------------------- src/geowombat/core/api.py | 54 ++++++++++--------- src/geowombat/core/util.py | 14 ++--- 4 files changed, 92 insertions(+), 113 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 85991b7d..c841b189 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -52,6 +52,8 @@ def get_dims_from_bounds( def get_file_info( src_obj: T.Union[rio.io.DatasetReader, rio.io.DatasetWriter] ) -> namedtuple: + """Gets image file information.""" + src_bounds = src_obj.bounds src_res = src_obj.res src_width = src_obj.width @@ -254,7 +256,7 @@ def check_res( int, ] ) -> T.Tuple[float, float]: - """Checks a resolution. + """Checks an image's resolution. Args: res (int | float | tuple): The resolution. 
@@ -347,10 +349,10 @@ def check_file_crs(filename: T.Union[str, Path]) -> CRS: # rasterio does not open and read metadata from NetCDF files if str(filename).lower().startswith('netcdf:'): with xr.open_dataset(filename.split(':')[1], chunks=256) as src: - src_crs = src.crs + src_crs = check_src_crs(src) else: with xr.open_dataset(filename, chunks=256) as src: - src_crs = src.crs + src_crs = check_src_crs(src) else: with rio.open(filename) as src: @@ -359,7 +361,7 @@ def check_file_crs(filename: T.Union[str, Path]) -> CRS: return check_crs(src_crs) -def unpack_bounding_box(bounds: str) -> T.Tuple[float, float, float, float]: +def unpack_bounding_box(bounds: str) -> BoundingBox: """Unpacks a BoundBox() string. Args: @@ -372,15 +374,15 @@ def unpack_bounding_box(bounds: str) -> T.Tuple[float, float, float, float]: for str_ in bounds_str: if str_.strip().startswith('left='): - left_coord = float(str_.strip().split('=')[1].replace(')', '')) + left = float(str_.strip().split('=')[1].replace(')', '')) elif str_.strip().startswith('bottom='): - bottom_coord = float(str_.strip().split('=')[1].replace(')', '')) + bottom = float(str_.strip().split('=')[1].replace(')', '')) elif str_.strip().startswith('right='): - right_coord = float(str_.strip().split('=')[1].replace(')', '')) + right = float(str_.strip().split('=')[1].replace(')', '')) elif str_.strip().startswith('top='): - top_coord = float(str_.strip().split('=')[1].replace(')', '')) + top = float(str_.strip().split('=')[1].replace(')', '')) - return left_coord, bottom_coord, right_coord, top_coord + return BoundingBox(left=left, bottom=bottom, right=right, top=top) def unpack_window(bounds: str) -> Window: @@ -408,30 +410,26 @@ def unpack_window(bounds: str) -> Window: def window_to_bounds( - filenames: T.Union[str, Path, T.Sequence[T.Union[str, Path]]], w: Window -) -> T.Tuple[float, float, float, float]: + filename: T.Union[str, Path, T.Sequence[T.Union[str, Path]]], w: Window +) -> BoundingBox: """Transforms a 
rasterio Window() object to image bounds. Args: - filenames (str or str list) + filename (str or str list) w (object) Returns: ``tuple`` """ - if isinstance(filenames, str): - src = rio.open(filenames) - else: - src = rio.open(filenames[0]) - - left, top = src.transform * (w.col_off, w.row_off) - - right = left + w.width * abs(src.res[0]) - bottom = top - w.height * abs(src.res[1]) + if isinstance(filename, (list, tuple)): + filename = filename[0] - src.close() + with rio.open(filename) as src: + left, top = src.transform * (w.col_off, w.row_off) + right = left + w.width * abs(src.res[0]) + bottom = top - w.height * abs(src.res[1]) - return left, bottom, right, top + return BoundingBox(left=left, bottom=bottom, right=right, top=top) def align_bounds( @@ -441,7 +439,7 @@ def align_bounds( maxy: float, res: T.Union[T.Tuple[float, float], T.Sequence[float], float, int], ) -> T.Tuple[Affine, int, int]: - """Aligns bounds to resolution. + """Aligns bounds to a resolution. Args: minx (float) @@ -478,7 +476,7 @@ def get_file_bounds( """Gets the union of all files. Args: - filenames (list): The file names to mosaic. + filenames (list): The file names from which to get bounds overlap. bounds_by (Optional[str]): How to concatenate the output extent. Choices are ['intersection', 'union']. crs (Optional[crs]): The CRS to warp to. res (Optional[tuple]): The cell resolution to warp to. 
@@ -504,7 +502,7 @@ def get_file_bounds( with rio.open(filenames[0]) as src: src_info = get_file_info(src) - if res: + if res is not None: dst_res = check_res(res) else: dst_res = src_info.src_res @@ -525,7 +523,10 @@ def get_file_bounds( densify_pts=21, ) - if bounds_by.lower() in ['union', 'intersection']: + if bounds_by.lower() in ( + 'union', + 'intersection', + ): for fn in filenames[1:]: src_crs = check_file_crs(fn) @@ -644,7 +645,7 @@ def warp_images( return [warp(fn, **warp_kwargs) for fn in filenames] -def get_ref_image_meta(filename): +def get_ref_image_meta(filename: T.Union[Path, str]) -> namedtuple: """Gets warping information from a reference image. Args: diff --git a/src/geowombat/backends/xarray_.py b/src/geowombat/backends/xarray_.py index ef59cc5b..3e793f00 100644 --- a/src/geowombat/backends/xarray_.py +++ b/src/geowombat/backends/xarray_.py @@ -76,12 +76,20 @@ def _check_config_globals( bounds_by (str) ref_kwargs (dict) """ + assert bounds_by.lower() in ( + "intersection", + "reference", + "union", + ), "The bounds_by argument must be 'intersection', 'reference', or 'union'." 
+ if config['nodata'] is not None: ref_kwargs = _update_kwarg(config['nodata'], ref_kwargs, 'nodata') + # Check if there is a reference image if config['ref_image']: - if isinstance(config['ref_image'], str) and os.path.isfile( - config['ref_image'] + if ( + isinstance(config['ref_image'], (Path, str)) + and Path(config['ref_image']).is_file() ): # Get the metadata from the reference image ref_meta = get_ref_image_meta(config['ref_image']) @@ -98,32 +106,25 @@ def _check_config_globals( if isinstance(config['ref_bounds'], str) and config[ 'ref_bounds' ].startswith('Window'): - ref_bounds_ = window_to_bounds( + ref_bounds = window_to_bounds( filenames, unpack_window(config['ref_bounds']) ) elif isinstance(config['ref_bounds'], str) and config[ 'ref_bounds' ].startswith('BoundingBox'): - ref_bounds_ = unpack_bounding_box(config['ref_bounds']) + ref_bounds = unpack_bounding_box(config['ref_bounds']) elif isinstance(config['ref_bounds'], Window): - ref_bounds_ = window_to_bounds(filenames, config['ref_bounds']) + ref_bounds = window_to_bounds(filenames, config['ref_bounds']) elif isinstance(config['ref_bounds'], BoundingBox): - - ref_bounds_ = ( - config['ref_bounds'].left, - config['ref_bounds'].bottom, - config['ref_bounds'].right, - config['ref_bounds'].top, - ) - + ref_bounds = config['ref_bounds'] else: - ref_bounds_ = config['ref_bounds'] + ref_bounds = config['ref_bounds'] - ref_kwargs = _update_kwarg(ref_bounds_, ref_kwargs, 'bounds') + ref_kwargs = _update_kwarg(tuple(ref_bounds), ref_kwargs, 'bounds') else: - if isinstance(filenames, str) or isinstance(filenames, Path): - # Use the bounds of the image + if isinstance(filenames, (Path, str)): + # Use the bounds of the input image ref_kwargs['bounds'] = get_file_bounds( [filenames], bounds_by='reference', @@ -133,41 +134,14 @@ def _check_config_globals( ) else: - # Replace the bounds keyword, if needed - if bounds_by.lower() == 'intersection': - # Get the intersecting bounds of all images - ref_kwargs['bounds'] 
= get_file_bounds( - filenames, - bounds_by='intersection', - crs=ref_kwargs['crs'], - res=ref_kwargs['res'], - return_bounds=True, - ) - - elif bounds_by.lower() == 'union': - # Get the union bounds of all images - ref_kwargs['bounds'] = get_file_bounds( - filenames, - bounds_by='union', - crs=ref_kwargs['crs'], - res=ref_kwargs['res'], - return_bounds=True, - ) - - elif bounds_by.lower() == 'reference': - # Use the bounds of the first image - ref_kwargs['bounds'] = get_file_bounds( - filenames, - bounds_by='reference', - crs=ref_kwargs['crs'], - res=ref_kwargs['res'], - return_bounds=True, - ) - - else: - logger.exception( - " Choose from 'intersection', 'union', or 'reference'." - ) + # Get the union bounds of all images + ref_kwargs['bounds'] = get_file_bounds( + filenames, + bounds_by=bounds_by.lower(), + crs=ref_kwargs['crs'], + res=ref_kwargs['res'], + return_bounds=True, + ) config['ref_bounds'] = ref_kwargs['bounds'] @@ -179,7 +153,7 @@ def _check_config_globals( if config['ref_tar'] is not None: if isinstance(config['ref_tar'], str): - if os.path.isfile(config['ref_tar']): + if Path(config['ref_tar']).is_file(): ref_kwargs = _update_kwarg( _get_raster_coords(config['ref_tar']), ref_kwargs, diff --git a/src/geowombat/core/api.py b/src/geowombat/core/api.py index c332fd4e..a3238e21 100644 --- a/src/geowombat/core/api.py +++ b/src/geowombat/core/api.py @@ -15,6 +15,7 @@ import typing as T import warnings from contextlib import contextmanager +from functools import singledispatch from pathlib import Path import dask @@ -71,6 +72,20 @@ def _tqdm(*args, **kwargs): yield None +@singledispatch +def get_image_chunks(filename: str) -> int: + with rio.open(filename) as src_: + w = src_.block_window(1, 0, 0) + chunks = (-1, w.height, w.width) + + return chunks + + +@get_image_chunks.register +def _(filename: list | tuple) -> int: + return get_image_chunks(filename([0])) + + def _get_attrs(src, **kwargs): cellxh = src.res[0] / 2.0 cellyh = src.res[1] / 2.0 @@ -268,7 
+283,7 @@ def read( data_ = None -class open(object): +class open: """Opens one or more raster files. Args: @@ -447,15 +462,17 @@ def __init__( num_workers: T.Optional[int] = 1, **kwargs, ): - if stack_dim not in ["band", "time"]: + if stack_dim not in ( + "band", + "time", + ): logger.exception( f" The 'stack_dim' keyword argument must be either 'band' or 'time', but not {stack_dim}" ) - raise NameError if isinstance(filename, Path): filename = str(filename) - elif isinstance(filename, list) and len(filename) == 1: + elif isinstance(filename, (list, tuple)) and (len(filename) == 1): filename = str(filename[0]) self.data = data_ @@ -466,25 +483,15 @@ def __init__( band_chunks = -1 if "chunks" in kwargs: - if kwargs["chunks"] is not None: - kwargs["chunks"] = ch.check_chunktype( - kwargs["chunks"], output="3d" - ) + kwargs["chunks"] = ch.check_chunktype( + kwargs["chunks"], output="3d" + ) - if bounds or ( - "window" in kwargs and isinstance(kwargs["window"], Window) + if (bounds is not None) or ( + ("window" in kwargs) and isinstance(kwargs["window"], Window) ): if "chunks" not in kwargs: - if isinstance(filename, list): - with rio.open(filename[0]) as src_: - w = src_.block_window(1, 0, 0) - chunks = (band_chunks, w.height, w.width) - - else: - with rio.open(filename) as src_: - w = src_.block_window(1, 0, 0) - chunks = (band_chunks, w.height, w.width) - + chunks = get_image_chunks(filename) else: chunks = kwargs["chunks"] del kwargs["chunks"] @@ -501,17 +508,14 @@ def __init__( self.__filenames = [str(filename)] else: - if (isinstance(filename, str) and "*" in filename) or isinstance( + if (isinstance(filename, str) and ("*" in filename)) or isinstance( filename, list ): - # Build the filename list if isinstance(filename, str): filename = parse_wildcard(filename) if "chunks" not in kwargs: - with rio.open(filename[0]) as src: - w = src.block_window(1, 0, 0) - kwargs["chunks"] = (band_chunks, w.height, w.width) + kwargs["chunks"] = get_image_chunks(filename) if 
mosaic: # Mosaic images over space diff --git a/src/geowombat/core/util.py b/src/geowombat/core/util.py index 78f34b44..e92db407 100644 --- a/src/geowombat/core/util.py +++ b/src/geowombat/core/util.py @@ -103,7 +103,7 @@ def parse_filename_dates( return date_filenames -def parse_wildcard(string: str) -> T.Sequence: +def parse_wildcard(string: str) -> T.List[Path]: """Parses a search wildcard from a string. @@ -115,16 +115,13 @@ def parse_wildcard(string: str) -> T.Sequence: """ if os.path.dirname(string): - d_name, wildcard = os.path.split(string) + dir_name, wildcard = os.path.split(string) else: - d_name = '.' + dir_name = '.' wildcard = string - matches = sorted(fnmatch.filter(os.listdir(d_name), wildcard)) - - if matches: - matches = [os.path.join(d_name, fn) for fn in matches] + matches = sorted(list(Path(dir_name).glob(wildcard))) if not matches: logger.exception( @@ -341,6 +338,9 @@ def get_chunk_dim(chunksize): return '{:d}d'.format(len(chunksize)) def check_chunktype(self, chunksize: int, output: str = '3d'): + if chunksize is None: + return chunksize + if isinstance(chunksize, int): chunksize = (-1, chunksize, chunksize) From fbc0d84d2728303d825f4840d6d21aa44afc8534 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 14:00:28 +1000 Subject: [PATCH 19/26] min py version not supporting types --- src/geowombat/core/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/geowombat/core/api.py b/src/geowombat/core/api.py index a3238e21..4a19ad0c 100644 --- a/src/geowombat/core/api.py +++ b/src/geowombat/core/api.py @@ -82,7 +82,12 @@ def get_image_chunks(filename: str) -> int: @get_image_chunks.register -def _(filename: list | tuple) -> int: +def _(filename: list) -> int: + return get_image_chunks(filename([0])) + + +@get_image_chunks.register +def _(filename: tuple) -> int: return get_image_chunks(filename([0])) From 9a00a9f5c2dc07f919a94b2ece27c1bb2e1ffb48 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 
14:50:54 +1000 Subject: [PATCH 20/26] use rasterio transform unpacking --- src/geowombat/backends/rasterio_.py | 82 ++++++++++++++--------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 343f39cf..b5889caf 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -18,7 +18,8 @@ from rasterio.coords import BoundingBox from rasterio.drivers import driver_from_extension from rasterio.enums import Resampling -from rasterio.transform import array_bounds, from_bounds +from rasterio.transform import array_bounds +from rasterio.transform import from_bounds as transform_from_bounds from rasterio.vrt import WarpedVRT from rasterio.warp import ( aligned_target, @@ -27,6 +28,7 @@ transform_bounds, ) from rasterio.windows import Window +from rasterio.windows import from_bounds as window_from_bounds import geowombat as gw @@ -44,13 +46,25 @@ logger = logging.getLogger(__name__) +def transform_from_corner( + bounds: BoundingBox, res: T.Sequence[float] +) -> Affine: + return Affine( + res[0], + 0.0, + bounds.left, + 0.0, + -res[1], + bounds.top, + ) + + def get_dims_from_bounds( - bounds: BoundingBox, res: T.Tuple[float, float] -) -> T.Tuple[int, int]: - width = int((bounds.right - bounds.left) / abs(res[0])) - height = int((bounds.top - bounds.bottom) / abs(res[1])) + bounds: BoundingBox, res: T.Sequence[float] +) -> Window: + transform = transform_from_corner(bounds, res) - return height, width + return window_from_bounds(*bounds, transform=transform) def get_file_info( @@ -747,7 +761,7 @@ def get_file_bounds( bounds_width = int((bounds_right - bounds_left) / abs(dst_res[0])) bounds_height = int((bounds_top - bounds_bottom) / abs(dst_res[1])) - bounds_transform = from_bounds( + bounds_transform = transform_from_bounds( bounds_left, bounds_bottom, bounds_right, @@ -894,7 +908,7 @@ def warp( # Check if the data need to be subset if (bounds is 
None) or (tuple(bounds) == tuple(src_info.src_bounds)): - if crs: + if crs is not None: ( left_coord, bottom_coord, @@ -927,23 +941,11 @@ def warp( elif isinstance(bounds, str): if bounds.startswith('BoundingBox'): - ( - left_coord, - bottom_coord, - right_coord, - top_coord, - ) = unpack_bounding_box(bounds) + dst_bounds = unpack_bounding_box(bounds) else: logger.exception(' The bounds were not accepted.') raise TypeError - dst_bounds = BoundingBox( - left=left_coord, - bottom=bottom_coord, - right=right_coord, - top=top_coord, - ) - elif isinstance(bounds, (list, np.ndarray, tuple)): dst_bounds = BoundingBox( left=bounds[0], @@ -959,15 +961,15 @@ def warp( ) raise TypeError - dst_height, dst_width = get_dims_from_bounds(dst_bounds, dst_res) + dst_window = get_dims_from_bounds(dst_bounds, dst_res) # Do all the key metadata match the reference information? if ( (tuple(src_info.src_bounds) == tuple(bounds)) and (src_info.src_res == dst_res) and (src_crs == dst_crs) - and (src_info.src_width == dst_width) - and (src_info.src_height == dst_height) + and (src_info.src_width == dst_window.width) + and (src_info.src_height == dst_window.height) and ('.nc' not in filename.lower()) ): vrt_options = { @@ -976,8 +978,8 @@ def warp( 'crs': src_crs, 'src_transform': src.transform, 'transform': src.transform, - 'height': dst_height, - 'width': dst_width, + 'height': dst_window.height, + 'width': dst_window.width, 'nodata': None, 'warp_mem_limit': warp_mem_limit, 'warp_extras': { @@ -987,21 +989,12 @@ def warp( } else: - src_transform = Affine( - src_info.src_res[0], - 0.0, - src_info.src_bounds.left, - 0.0, - -src_info.src_res[1], - src_info.src_bounds.top, + src_transform = transform_from_corner( + src_info.src_bounds, src_info.src_res ) - dst_transform = Affine( - dst_res[0], - 0.0, - dst_bounds.left, - 0.0, - -dst_res[1], - dst_bounds.top, + + dst_transform = transform_from_corner( + dst_bounds.src_bounds, dst_res.src_res ) if tac is not None: @@ -1016,7 +1009,10 @@ def 
warp( if tap: # Align the cells to the resolution dst_transform, dst_width, dst_height = aligned_target( - dst_transform, dst_width, dst_height, dst_res + dst_transform, dst_window.width, dst_window.height, dst_res + ) + dst_window = Window( + row_off=0, col_off=0, width=dst_width, height=dst_height ) vrt_options = { @@ -1025,8 +1021,8 @@ def warp( 'crs': dst_crs, 'src_transform': src_transform, 'transform': dst_transform, - 'height': dst_height, - 'width': dst_width, + 'height': dst_window.height, + 'width': dst_window.width, 'nodata': nodata, 'warp_mem_limit': warp_mem_limit, 'warp_extras': { From 4753adc890c2ca4996e7f3f09c6fe4bf72afd04d Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 15:10:00 +1000 Subject: [PATCH 21/26] update methods --- .github/workflows/ci.yml | 2 +- src/geowombat/backends/rasterio_.py | 10 +++++----- src/geowombat/core/api.py | 4 ++-- src/geowombat/core/sops.py | 11 ++++++----- tests/test_rasterio.py | 8 ++++---- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f973f94a..058b863d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,10 +42,10 @@ jobs: pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') pip install GDAL==$GDAL_VERSION --no-cache-dir --no-binary=gdal - pip install arosics --no-deps - name: Install GeoWombat run: | pip install ".[stac,web,coreg,perf,tests]" + pip install arosics --no-deps - name: Run Unittests run: | pip install testfixtures diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index b5889caf..346bad09 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -49,6 +49,7 @@ def transform_from_corner( bounds: BoundingBox, res: T.Sequence[float] ) -> Affine: + """Gets an affine transform from an upper left corner.""" return Affine( res[0], 0.0, @@ -59,9 +60,10 @@ def transform_from_corner( ) -def 
get_dims_from_bounds( +def get_window_from_bounds( bounds: BoundingBox, res: T.Sequence[float] ) -> Window: + """Gets.""" transform = transform_from_corner(bounds, res) return window_from_bounds(*bounds, transform=transform) @@ -961,7 +963,7 @@ def warp( ) raise TypeError - dst_window = get_dims_from_bounds(dst_bounds, dst_res) + dst_window = get_window_from_bounds(dst_bounds, dst_res) # Do all the key metadata match the reference information? if ( @@ -993,9 +995,7 @@ def warp( src_info.src_bounds, src_info.src_res ) - dst_transform = transform_from_corner( - dst_bounds.src_bounds, dst_res.src_res - ) + dst_transform = transform_from_corner(dst_bounds, dst_res) if tac is not None: # Align the cells to target coordinates diff --git a/src/geowombat/core/api.py b/src/geowombat/core/api.py index 4a19ad0c..cf63bfe9 100644 --- a/src/geowombat/core/api.py +++ b/src/geowombat/core/api.py @@ -83,12 +83,12 @@ def get_image_chunks(filename: str) -> int: @get_image_chunks.register def _(filename: list) -> int: - return get_image_chunks(filename([0])) + return get_image_chunks(filename[0]) @get_image_chunks.register def _(filename: tuple) -> int: - return get_image_chunks(filename([0])) + return get_image_chunks(filename[0]) def _get_attrs(src, **kwargs): diff --git a/src/geowombat/core/sops.py b/src/geowombat/core/sops.py index 44d9c73a..84314b47 100644 --- a/src/geowombat/core/sops.py +++ b/src/geowombat/core/sops.py @@ -21,6 +21,7 @@ from pyproj.exceptions import CRSError from rasterio import features from rasterio.coords import BoundingBox +from rasterio.transform import array_bounds from scipy.spatial import cKDTree from scipy.stats import mode as sci_mode from shapely.geometry import Polygon @@ -40,7 +41,7 @@ except ImportError: PYMORPH_INSTALLED = False -from ..backends.rasterio_ import array_bounds, check_crs, get_dims_from_bounds +from ..backends.rasterio_ import check_crs, get_window_from_bounds from ..handler import add_handler from .base import PropertyMixin as 
_PropertyMixin from .base import _client_dummy, _cluster_dummy @@ -956,7 +957,7 @@ def clip_by_polygon( right=right, top=top, ) - align_height, align_width = get_dims_from_bounds( + align_window = get_window_from_bounds( dst_bounds, (data.gw.cellx, data.gw.celly) ) align_transform = Affine( @@ -964,7 +965,7 @@ def clip_by_polygon( ) # Get the new bounds new_left, new_bottom, new_right, new_top = array_bounds( - align_height, align_width, align_transform + align_window.height, align_window.width, align_transform ) if expand_by > 0: @@ -1057,7 +1058,7 @@ def clip( right=right, top=top, ) - align_height, align_width = get_dims_from_bounds( + align_window = get_window_from_bounds( dst_bounds, (data.gw.cellx, data.gw.celly) ) align_transform = Affine( @@ -1065,7 +1066,7 @@ def clip( ) # Get the new bounds new_left, new_bottom, new_right, new_top = array_bounds( - align_height, align_width, align_transform + align_window.height, align_window.width, align_transform ) if expand_by > 0: diff --git a/tests/test_rasterio.py b/tests/test_rasterio.py index caf7d2f2..635e4027 100644 --- a/tests/test_rasterio.py +++ b/tests/test_rasterio.py @@ -12,8 +12,8 @@ check_crs, check_file_crs, check_res, - get_dims_from_bounds, get_file_info, + get_window_from_bounds, unpack_bounding_box, unpack_window, window_to_bounds, @@ -54,15 +54,15 @@ def test_align_bounds(self): ), ) - def test_get_dims_from_bounds(self): + def test_get_window_from_bounds(self): bounds = BoundingBox( left=-100, bottom=-100, right=100, top=100, ) - height, width = get_dims_from_bounds(bounds=bounds, res=(10, 10)) - self.assertEqual((height, width), (20, 20)) + dst_window = get_window_from_bounds(bounds=bounds, res=(10, 10)) + self.assertEqual((dst_window.height, dst_window.width), (20, 20)) def test_get_file_info(self): with rio.open(l8_224077_20200518_B2) as src: From fa16e2ef3b206d8fa7cc20d4f48d8ea69b232f53 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 15:13:22 +1000 Subject: [PATCH 22/26] ignore 
gdal --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 058b863d..6f8f38e9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,10 +42,10 @@ jobs: pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') pip install GDAL==$GDAL_VERSION --no-cache-dir --no-binary=gdal + pip install arosics --no-deps gdal - name: Install GeoWombat run: | pip install ".[stac,web,coreg,perf,tests]" - pip install arosics --no-deps - name: Run Unittests run: | pip install testfixtures From cc080874c3c7099ff7fb67281cbf0bee73dde1de Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 15:16:25 +1000 Subject: [PATCH 23/26] add geoarray to no-deps list --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6f8f38e9..f14c0153 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') pip install GDAL==$GDAL_VERSION --no-cache-dir --no-binary=gdal - pip install arosics --no-deps gdal + pip install arosics --no-deps gdal geoarray - name: Install GeoWombat run: | pip install ".[stac,web,coreg,perf,tests]" From 47d4b49a6abbdb2487ef90826e5be9640226d80d Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 15:25:36 +1000 Subject: [PATCH 24/26] add geoarray to no-deps list --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f14c0153..1cb1c0be 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,8 +41,8 @@ jobs: pip install -U pip setuptools wheel pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') - pip install GDAL==$GDAL_VERSION --no-cache-dir 
--no-binary=gdal - pip install arosics --no-deps gdal geoarray + pip install gdal==$GDAL_VERSION --no-cache-dir --no-binary=gdal + pip install arosics --no-deps - name: Install GeoWombat run: | pip install ".[stac,web,coreg,perf,tests]" From 843adaa44bb4659ff486a01b066b9067678e8ad1 Mon Sep 17 00:00:00 2001 From: jgrss Date: Sat, 18 May 2024 15:29:00 +1000 Subject: [PATCH 25/26] remove coreg tests --- .github/workflows/ci.yml | 3 +-- tests/{test_coreg.py => _test_coreg.py} | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) rename tests/{test_coreg.py => _test_coreg.py} (94%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1cb1c0be..cc26ae0d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,10 +42,9 @@ jobs: pip install numpy GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2}') pip install gdal==$GDAL_VERSION --no-cache-dir --no-binary=gdal - pip install arosics --no-deps - name: Install GeoWombat run: | - pip install ".[stac,web,coreg,perf,tests]" + pip install ".[stac,web,perf,tests]" - name: Run Unittests run: | pip install testfixtures diff --git a/tests/test_coreg.py b/tests/_test_coreg.py similarity index 94% rename from tests/test_coreg.py rename to tests/_test_coreg.py index 26d7ea28..88e564ad 100644 --- a/tests/test_coreg.py +++ b/tests/_test_coreg.py @@ -1,13 +1,13 @@ -import unittest import tempfile +import unittest from pathlib import Path -import geowombat as gw -from geowombat.data import l8_224077_20200518_B2 -from geowombat.data import l8_224077_20200518_B4 import numpy as np import xarray as xr +import geowombat as gw +from geowombat.data import l8_224077_20200518_B2, l8_224077_20200518_B4 + def shift(data: xr.DataArray, x: int, y: int) -> xr.DataArray: return ( From 15ba05f8352931ddd9a98e1465fba4ee844b6d6d Mon Sep 17 00:00:00 2001 From: jgrss Date: Sun, 19 May 2024 19:44:09 +1000 Subject: [PATCH 26/26] add docstring --- src/geowombat/backends/rasterio_.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/geowombat/backends/rasterio_.py b/src/geowombat/backends/rasterio_.py index 346bad09..eeab6aa9 100644 --- a/src/geowombat/backends/rasterio_.py +++ b/src/geowombat/backends/rasterio_.py @@ -63,7 +63,7 @@ def transform_from_corner( def get_window_from_bounds( bounds: BoundingBox, res: T.Sequence[float] ) -> Window: -    """Gets.""" +    """Gets a ``rasterio.Window`` from a bounding box.""" transform = transform_from_corner(bounds, res) return window_from_bounds(*bounds, transform=transform)