diff --git a/setup.py b/setup.py index ab6ce73..8af797d 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,8 @@ "absl-py", "pytest", "pyink", + "rioxarray", + "rasterio", ] examples_require = [ diff --git a/xee/ext.py b/xee/ext.py index 99a826f..db5ff66 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -253,6 +253,12 @@ def __init__( else: self.mask_value = mask_value + self.coordinate_transformer = pyproj.Transformer.from_crs( + "EPSG:4326", + self.crs, + always_xy=True + ) + @functools.cached_property def get_info(self) -> dict[str, Any]: """Make all getInfo() calls to EE at once.""" @@ -555,11 +561,13 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]: lon_grid = self.project((0, 0, v0.shape[1], 1)) lat_grid = self.project((0, 0, 1, v0.shape[2])) lon = self.image_to_array( - lnglat_img, grid=lon_grid, dtype=np.float32, bandIds=['longitude'] + lnglat_img, grid=lon_grid, dtype=np.float32, bandIds=['longitude', 'latitude'] ) + lon = self.coordinate_transformer.transform(lon[0], lon[1])[0] lat = self.image_to_array( - lnglat_img, grid=lat_grid, dtype=np.float32, bandIds=['latitude'] + lnglat_img, grid=lat_grid, dtype=np.float32, bandIds=['longitude', 'latitude'] ) + lat = self.coordinate_transformer.transform(lat[0], lat[1])[1] width_coord = np.squeeze(lon) height_coord = np.squeeze(lat) diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index e340f52..7b4b95c 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -22,6 +22,10 @@ import numpy as np import xarray as xr from xarray.core import indexing +import os +import rioxarray +import rasterio +import tempfile import xee import ee @@ -397,6 +401,76 @@ def test_validate_band_attrs(self): for _, value in variable.attrs.items(): self.assertIsInstance(value, valid_types) + def test_write_projected_dataset_to_raster(self): + # ensure that a projected dataset written to a raster intersects with the + # point used to create the initial image collection + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = os.path.join(temp_dir, "test.tif") + + crs = "epsg:32610" + proj = ee.Projection(crs) + point = ee.Geometry.Point([-122.44, 37.78]) + geom = point.buffer(1024).bounds() + + col = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") + col = col.filterBounds(point) + col = col.filter(ee.Filter.lte("CLOUDY_PIXEL_PERCENTAGE", 5)) + col = col.limit(10) + + ds = xr.open_dataset( + col, + engine=xee.EarthEngineBackendEntrypoint, + scale=10, + crs=crs, + geometry=geom, + ) + + ds = ds.isel(time=0).transpose("Y", "X") + ds.rio.set_spatial_dims(x_dim="X", y_dim="Y", inplace=True) + ds.rio.write_crs(crs, inplace=True) + ds.rio.reproject(crs, inplace=True) + ds.rio.to_raster(temp_file) + + with rasterio.open(temp_file) as raster: + # see https://gis.stackexchange.com/a/407755 for evenOdd explanation + bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) + intersects = bbox.intersects(point, 1, proj=proj) + self.assertTrue(intersects.getInfo()) + + def test_write_dataset_to_raster(self): + # ensure that a dataset written to a raster intersects with the point used + # to create the initial image collection + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = os.path.join(temp_dir, "test.tif") + + crs = "EPSG:4326" + proj = ee.Projection(crs) + point = ee.Geometry.Point([-122.44, 37.78]) + geom = point.buffer(1024).bounds() + + col = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") + col = col.filterBounds(point) + col = col.filter(ee.Filter.lte("CLOUDY_PIXEL_PERCENTAGE", 5)) + col = col.limit(10) + + ds = xr.open_dataset( + col, + engine=xee.EarthEngineBackendEntrypoint, + scale=0.0025, + geometry=geom, + ) + + ds = ds.isel(time=0).transpose("lat", "lon") + ds.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True) + ds.rio.write_crs(crs, inplace=True) + ds.rio.reproject(crs, inplace=True) + ds.rio.to_raster(temp_file) + + with rasterio.open(temp_file) as raster: + # see https://gis.stackexchange.com/a/407755 for evenOdd explanation + bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) + intersects = bbox.intersects(point, 1, proj=proj) + self.assertTrue(intersects.getInfo()) if __name__ == '__main__': absltest.main()