Skip to content

Commit

Permalink
fix: separate nmad in geoutils.stats, add tests, overload raster.get_…
Browse files Browse the repository at this point in the history
…stats()
  • Loading branch information
vschaffn committed Dec 9, 2024
1 parent c99e5ff commit 02df716
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 23 deletions.
68 changes: 45 additions & 23 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
decode_sensor_metadata,
parse_and_convert_metadata_from_filename,
)
from geoutils.stats import nmad
from geoutils.vector.vector import Vector

# If python38 or above, Literal is builtin. Otherwise, use typing_extensions
Expand Down Expand Up @@ -1898,16 +1899,49 @@ def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]:
"Sum": np.nansum(data),
"Sum of squares": np.nansum(np.square(data)),
"90th percentile": np.nanpercentile(data, 90),
"NMAD": self._nmad(),
"NMAD": nmad(data),
"RMSE": np.sqrt(np.nanmean(np.square(data - np.nanmean(data)))),
"Standard deviation": np.nanstd(data),
}
return stats_dict

@overload
def get_stats(
self,
stats_name: (
Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
| Callable[[NDArrayNum], np.floating[Any]]
),
band: int = 1,
) -> np.floating[Any]: ...

@overload
def get_stats(
self,
stats_name: (
list[
Literal[
"mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
]
| Callable[[NDArrayNum], np.floating[Any]]
]
| None
) = None,
band: int = 1,
) -> dict[str, np.floating[Any]]: ...

def get_stats(
self,
stats_name: (
str | Callable[[NDArrayNum], np.floating[Any]] | list[str | Callable[[NDArrayNum], np.floating[Any]]] | None
Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
| Callable[[NDArrayNum], np.floating[Any]]
| list[
Literal[
"mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
]
| Callable[[NDArrayNum], np.floating[Any]]
]
| None
) = None,
band: int = 1,
) -> np.floating[Any] | dict[str, np.floating[Any]]:
Expand All @@ -1916,6 +1950,10 @@ def get_stats(
to calculate custom stats.
:param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned.
Accepted names include:
- "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
You can also use common aliases for these names (e.g., "average", "maximum", "minimum", etc.).
Custom callables can also be provided.
:param band: The index of the band for which to compute statistics. Default is 1.
:returns: The requested statistic or a dictionary of statistics if multiple or all are requested.
Expand Down Expand Up @@ -1985,21 +2023,6 @@ def _get_single_stat(
logging.warning("Statistic name '%s' is not recognized", stat_name)
return np.float32(np.nan)

def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]:
"""
Calculate the normalized median absolute deviation (NMAD) of an array.
Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal
distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and
e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003)
"""
if self.count == 1:
data = self.data
else:
data = self.data[band]
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return nfact * np.nanmedian(np.abs(data - np.nanmedian(data)))

@overload
def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ...

Expand Down Expand Up @@ -2040,14 +2063,13 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str:

if self.count == 1:
statistics = self.get_stats()
if isinstance(statistics, dict):

# Determine the maximum length of the stat names for alignment
max_len = max(len(name) for name in statistics.keys())
# Determine the maximum length of the stat names for alignment
max_len = max(len(name) for name in statistics.keys())

# Format the stats with aligned names
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")
# Format the stats with aligned names
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")
else:
for b in range(self.count):
# try to keep with rasterio convention.
Expand Down
26 changes: 26 additions & 0 deletions geoutils/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
""" Statistical tools"""

from typing import Any

import numpy as np

from geoutils._typing import NDArrayNum


def nmad(data: NDArrayNum, nfact: float = 1.4826) -> np.floating[Any]:
"""
Calculate the normalized median absolute deviation (NMAD) of an array.
Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal
distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and
e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003)
:param data: Input array or raster
:param nfact: Normalization factor for the data
:returns nmad: (normalized) median absolute deviation of data.
"""
if isinstance(data, np.ma.masked_array):
data_arr = data.compressed()
else:
data_arr = np.asarray(data)
return nfact * np.nanmedian(np.abs(data_arr - np.nanmedian(data_arr)))
30 changes: 30 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Test functions for stats
"""

import scipy

from geoutils import Raster, examples
from geoutils.stats import nmad


class TestStats:
landsat_b4_path = examples.get_path("everest_landsat_b4")
landsat_raster = Raster(landsat_b4_path)

def test_nmad(self) -> None:
"""Test NMAD functionality runs on any type of input"""

# Check that the NMAD is computed the same with a masked array or NaN array, and is equal to scipy nmad
nmad_ma = nmad(self.landsat_raster.data)
nmad_array = nmad(self.landsat_raster.get_nanarray())
nmad_scipy = scipy.stats.median_abs_deviation(self.landsat_raster.data, axis=None, scale="normal")

assert nmad_ma == nmad_array
assert nmad_ma.round(2) == nmad_scipy.round(2)

# Check that the scaling factor works
nmad_1 = nmad(self.landsat_raster.data, nfact=1)
nmad_2 = nmad(self.landsat_raster.data, nfact=2)

assert nmad_1 * 2 == nmad_2

0 comments on commit 02df716

Please sign in to comment.