Skip to content

Commit

Permalink
transform dataset coordinates into specified crs
Browse files Browse the repository at this point in the history
  • Loading branch information
boothmanrylan committed Nov 9, 2023
1 parent 07b8676 commit f1b63cf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"absl-py",
"pytest",
"pyink",
"rioxarray",
"rasterio",
]

examples_require = [
Expand Down
12 changes: 10 additions & 2 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ def __init__(
else:
self.mask_value = mask_value

self.coordinate_transformer = pyproj.Transformer.from_crs(
"EPSG:4326",
self.crs,
always_xy=True
)

@functools.cached_property
def get_info(self) -> dict[str, Any]:
"""Make all getInfo() calls to EE at once."""
Expand Down Expand Up @@ -555,11 +561,13 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
lon_grid = self.project((0, 0, v0.shape[1], 1))
lat_grid = self.project((0, 0, 1, v0.shape[2]))
lon = self.image_to_array(
lnglat_img, grid=lon_grid, dtype=np.float32, bandIds=['longitude']
lnglat_img, grid=lon_grid, dtype=np.float32, bandIds=['longitude', 'latitude']
)
lon = self.coordinate_transformer.transform(lon[0], lon[1])[0]
lat = self.image_to_array(
lnglat_img, grid=lat_grid, dtype=np.float32, bandIds=['latitude']
lnglat_img, grid=lat_grid, dtype=np.float32, bandIds=['longitude', 'latitude']
)
lat = self.coordinate_transformer.transform(lat[0], lat[1])[1]
width_coord = np.squeeze(lon)
height_coord = np.squeeze(lat)

Expand Down
74 changes: 74 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
import numpy as np
import xarray as xr
from xarray.core import indexing
import os
import rioxarray
import rasterio
import tempfile
import xee

import ee
Expand Down Expand Up @@ -397,6 +401,76 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_write_projected_dataset_to_raster(self):
# ensure that a projected dataset written to a raster intersects with the
# point used to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, "test.tif")

crs = "epsg:32610"
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte("CLOUDY_PIXEL_PERCENTAGE", 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=10,
crs=crs,
geometry=geom,
)

ds = ds.isel(time=0).transpose("Y", "X")
ds.rio.set_spatial_dims(x_dim="X", y_dim="Y", inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())

def test_write_dataset_to_raster(self):
# ensure that a dataset written to a raster intersects with the point used
# to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, "test.tif")

crs = "EPSG:4326"
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte("CLOUDY_PIXEL_PERCENTAGE", 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=0.0025,
geometry=geom,
)

ds = ds.isel(time=0).transpose("lat", "lon")
ds.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())

if __name__ == '__main__':
absltest.main()

0 comments on commit f1b63cf

Please sign in to comment.