Skip to content

Commit

Permalink
Make astype set new nodata by default, and data.setter check fo…
Browse files Browse the repository at this point in the history
…r unmasked nodata (GlacioHack#472)
  • Loading branch information
rhugonnet authored Feb 5, 2024
1 parent 4411f1b commit b9b5a0f
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 91 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
path: ${{ env.CONDA }}/envs
key: conda-${{ matrix.os }}-${{ matrix.python-version }}-${{ env.cache_date }}-${{ hashFiles('dev-environment.yml') }}-${{ env.CACHE_NUMBER }}
env:
CACHE_NUMBER: 1 # Increase this value to reset cache if environment.yml has not changed
CACHE_NUMBER: 0 # Increase this value to reset cache if environment.yml has not changed
id: cache

# The trick below is necessary because the generic environment file does not specify a Python version, and ONLY
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ examples/data/

# Directory where myst_nb executes jupyter code
doc/jupyter_execute/
doc/source/sg_execution_times.rst

# Files that should have been deleted by Sphinx at end of build (but can exist if build fails)
examples/io/open_save/myraster.tif
Expand Down
65 changes: 48 additions & 17 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ def __init__(
bands: int | list[int] | None = None,
load_data: bool = False,
downsample: Number = 1,
masked: bool = True,
nodata: int | float | None = None,
) -> None:
"""
Expand All @@ -441,8 +440,6 @@ def __init__(
:param downsample: Downsample the array once loaded by a round factor. Default is no downsampling.
:param masked: Whether to load the array as a NumPy masked-array, with nodata values masked. Default is True.
:param nodata: Nodata value to be used (overwrites the metadata). Default reads from metadata.
"""
self._driver: str | None = None
Expand All @@ -456,7 +453,7 @@ def __init__(
self._nodata: int | float | None = nodata
self._bands = bands
self._bands_loaded: int | tuple[int, ...] | None = None
self._masked = masked
self._masked = True
self._out_count: int | None = None
self._out_shape: tuple[int, int] | None = None
self._disk_hash: int | None = None
Expand Down Expand Up @@ -549,7 +546,7 @@ def __init__(
self.data = _load_rio(
ds,
indexes=bands,
masked=masked,
masked=self._masked,
out_shape=out_shape,
out_count=count,
) # type: ignore
Expand Down Expand Up @@ -1080,8 +1077,12 @@ def _overloading_check(
out_nodata = None
if (nodata2 is not None) and (out_dtype == dtype2):
out_nodata = nodata2
if (nodata1 is not None) and (out_dtype == dtype1):
elif (nodata1 is not None) and (out_dtype == dtype1):
out_nodata = nodata1
# In some cases the promoted dtype is neither of the two input dtypes (e.g., uint8 and int8 give int16),
# And the minimum dtype of any integer is uint8
elif (nodata1 is not None) or (nodata2 is not None):
out_nodata = nodata1 if not None else nodata2

self_data = self.data

Expand Down Expand Up @@ -1340,18 +1341,25 @@ def __ge__(self: RasterType, other: RasterType | NDArrayNum | Number) -> RasterT
return out_mask

@overload
def astype(self, dtype: DTypeLike, inplace: Literal[False] = False) -> Raster:
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False) -> Raster:
...

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None:
...

@overload
def astype(self, dtype: DTypeLike, inplace: Literal[True]) -> None:
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False) -> Raster | None:
...

def astype(self, dtype: DTypeLike, inplace: bool = False) -> Raster | None:
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False) -> Raster | None:
"""
Convert data type of the raster.
By default, converts the nodata value to the default of the new data type.
:param dtype: Any numpy dtype or string accepted by numpy.astype.
:param convert_nodata: Whether to convert the nodata value to the default of the new dtype.
:param inplace: Whether to modify the raster in-place.
:returns: Raster with updated dtype (or None if inplace).
Expand All @@ -1368,11 +1376,18 @@ def astype(self, dtype: DTypeLike, inplace: bool = False) -> Raster | None:
)

out_data = self.data.astype(dtype)

if inplace:
self._data = out_data # type: ignore
if convert_nodata:
self.set_nodata(new_nodata=_default_nodata(dtype))
return None
else:
return self.from_array(out_data, self.transform, self.crs, nodata=self.nodata)
if not convert_nodata:
nodata = self.nodata
else:
nodata = _default_nodata(dtype)
return self.from_array(out_data, self.transform, self.crs, nodata=nodata)

@property
def is_modified(self) -> bool:
Expand Down Expand Up @@ -1472,21 +1487,22 @@ def set_nodata(
if np.count_nonzero(index_new_nodatas) > 0:
if update_array and update_mask:
warnings.warn(
message="New nodata value found in the data array. Those will be masked, and the old "
"nodata cells will now take the same value. Use set_nodata() with update_array=False "
"and/or update_mask=False to change this behaviour.",
message="New nodata value cells already exist in the data array. These cells will now be "
"masked, and the old nodata value cells will update to the same new value. "
"Use set_nodata() with update_array=False or update_mask=False to change "
"this behaviour.",
category=UserWarning,
)
elif update_array:
warnings.warn(
"New nodata value found in the data array. The old nodata cells will now take the same "
"value. Use set_nodata() with update_array=False to change this behaviour.",
"New nodata value cells already exist in the data array. The old nodata cells will update to "
"the same new value. Use set_nodata() with update_array=False to change this behaviour.",
category=UserWarning,
)
elif update_mask:
warnings.warn(
"New nodata value found in the data array. Those will be masked. Use set_nodata() "
"with update_mask=False to change this behaviour.",
"New nodata value cells already exist in the data array. These cells will now be masked. "
"Use set_nodata() with update_mask=False to change this behaviour.",
category=UserWarning,
)

Expand Down Expand Up @@ -1623,6 +1639,21 @@ def data(self, new_data: NDArrayNum | MArrayNum) -> None:
else:
self._data = np.ma.masked_array(data=new_data, fill_value=self.nodata)

# Finally, mask values equal to the nodata value in case they weren't masked, but raise a warning
if np.count_nonzero(np.logical_and(~self._data.mask, self._data.data == self.nodata)) > 0:
# This can happen during a numerical operation, especially for integer values that overflow and wrap around (modulo)
# It can also happen with from_array()
warnings.warn(
category=UserWarning,
message="Unmasked values equal to the nodata value found in data array. They are now masked.\n "
"If this happened when creating or updating the array, to silence this warning, "
"convert nodata values in the array to np.nan or mask them with np.ma.masked prior "
"to creating or updating the raster.\n"
"If this happened during a numerical operation, use astype() prior to the operation "
"to convert to a data type that won't derive the nodata values (e.g., a float type).",
)
self._data[self._data.data == self.nodata] = np.ma.masked

@property
def transform(self) -> affine.Affine:
"""
Expand Down
12 changes: 9 additions & 3 deletions tests/test_geoviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def test_geoviewer_valid_1band(capsys, monkeypatch, filename, option): # type:

# The everest example will raise errors when setting a nodata value that exists
if "B4" in os.path.basename(filename) and len(option) > 0 and option[0] == "-nodata":
warnings.filterwarnings("ignore", category=UserWarning, message="New nodata value found in the data array.*")
warnings.filterwarnings(
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)

# To not get exception when testing generic functions such as --help
try:
Expand Down Expand Up @@ -89,7 +91,9 @@ def test_geoviewer_invalid_1band(capsys, monkeypatch, filename, args): # type:

# The everest example will raise errors when setting a nodata value that exists
if "B4" in os.path.basename(filename) and len(args) > 0 and args[0] == "-nodata":
warnings.filterwarnings("ignore", category=UserWarning, message="New nodata value found in the data array.*")
warnings.filterwarnings(
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)

# To not get exception when testing generic functions such as --help
option, error = args
Expand Down Expand Up @@ -120,7 +124,9 @@ def test_geoviewer_valid_3band(capsys, monkeypatch, filename, option): # type:

# The everest RGB example will raise errors when setting a nodata value that exists
if "RGB" in os.path.basename(filename) and len(option) > 0 and option[0] == "-nodata":
warnings.filterwarnings("ignore", category=UserWarning, message="New nodata value found in the data array.*")
warnings.filterwarnings(
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)

# To not get exception when testing generic functions such as --help
try:
Expand Down
8 changes: 6 additions & 2 deletions tests/test_multiraster.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def test_stack_rasters(self, rasters) -> None: # type: ignore
"""Test stack_rasters"""

# Silence the reprojection warning for default nodata value
warnings.filterwarnings("ignore", category=UserWarning, message="New nodata value found in the data array.*")
warnings.filterwarnings(
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)
warnings.filterwarnings("ignore", category=UserWarning, message="For reprojection, nodata must be set.*")

# Merge the two overlapping DEMs and check that output bounds and shape is correct
Expand Down Expand Up @@ -171,7 +173,9 @@ def test_merge_rasters(self, rasters) -> None: # type: ignore
# Merge the two overlapping DEMs and check that it closely resembles the initial DEM

# Silence the reprojection warning for default nodata value
warnings.filterwarnings("ignore", category=UserWarning, message="New nodata value found in the data array.*")
warnings.filterwarnings(
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)
warnings.filterwarnings("ignore", category=UserWarning, message="For reprojection, nodata must be set.*")

# Ignore warning already checked in test_stack_rasters
Expand Down
Loading

0 comments on commit b9b5a0f

Please sign in to comment.