From 02df71691c188fc113a6f4306a752cf30293795f Mon Sep 17 00:00:00 2001 From: vschaffn Date: Mon, 9 Dec 2024 11:28:22 +0100 Subject: [PATCH] fix: separate nmad in geoutils.stats, add tests, overload raster.get_stats() --- geoutils/raster/raster.py | 68 ++++++++++++++++++++++++++------------- geoutils/stats.py | 26 +++++++++++++++ tests/test_stats.py | 30 +++++++++++++++++ 3 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 geoutils/stats.py create mode 100644 tests/test_stats.py diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index 58b76011..7edc2e85 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -70,6 +70,7 @@ decode_sensor_metadata, parse_and_convert_metadata_from_filename, ) +from geoutils.stats import nmad from geoutils.vector.vector import Vector # If python38 or above, Literal is builtin. Otherwise, use typing_extensions @@ -1898,16 +1899,49 @@ def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]: "Sum": np.nansum(data), "Sum of squares": np.nansum(np.square(data)), "90th percentile": np.nanpercentile(data, 90), - "NMAD": self._nmad(), + "NMAD": nmad(data), "RMSE": np.sqrt(np.nanmean(np.square(data - np.nanmean(data)))), "Standard deviation": np.nanstd(data), } return stats_dict + @overload + def get_stats( + self, + stats_name: ( + Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"] + | Callable[[NDArrayNum], np.floating[Any]] + ), + band: int = 1, + ) -> np.floating[Any]: ... + + @overload + def get_stats( + self, + stats_name: ( + list[ + Literal[ + "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + ] + | Callable[[NDArrayNum], np.floating[Any]] + ] + | None + ) = None, + band: int = 1, + ) -> dict[str, np.floating[Any]]: ... + def get_stats( self, stats_name: ( - str | Callable[[NDArrayNum], np.floating[Any]] | list[str | Callable[[NDArrayNum], np.floating[Any]]] | None + Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"] + | Callable[[NDArrayNum], np.floating[Any]] + | list[ + Literal[ + "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + ] + | Callable[[NDArrayNum], np.floating[Any]] + ] + | None ) = None, band: int = 1, ) -> np.floating[Any] | dict[str, np.floating[Any]]: @@ -1916,6 +1950,10 @@ def get_stats( to calculate custom stats. :param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned. + Accepted names include: + - "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + You can also use common aliases for these names (e.g., "average", "maximum", "minimum", etc.). + Custom callables can also be provided. :param band: The index of the band for which to compute statistics. Default is 1. :returns: The requested statistic or a dictionary of statistics if multiple or all are requested. @@ -1985,21 +2023,6 @@ def _get_single_stat( logging.warning("Statistic name '%s' is not recognized", stat_name) return np.float32(np.nan) - def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]: - """ - Calculate the normalized median absolute deviation (NMAD) of an array. - Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal - distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and - e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003) - """ - if self.count == 1: - data = self.data - else: - data = self.data[band] - if isinstance(data, np.ma.MaskedArray): - data = data.compressed() - return nfact * np.nanmedian(np.abs(data - np.nanmedian(data))) - @overload def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ... @@ -2040,14 +2063,13 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str: if self.count == 1: statistics = self.get_stats() - if isinstance(statistics, dict): - # Determine the maximum length of the stat names for alignment - max_len = max(len(name) for name in statistics.keys()) + # Determine the maximum length of the stat names for alignment + max_len = max(len(name) for name in statistics.keys()) - # Format the stats with aligned names - for name, value in statistics.items(): - as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n") + # Format the stats with aligned names + for name, value in statistics.items(): + as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n") else: for b in range(self.count): # try to keep with rasterio convention. diff --git a/geoutils/stats.py b/geoutils/stats.py new file mode 100644 index 00000000..838c701c --- /dev/null +++ b/geoutils/stats.py @@ -0,0 +1,26 @@ +""" Statistical tools""" + +from typing import Any + +import numpy as np + +from geoutils._typing import NDArrayNum + + +def nmad(data: NDArrayNum, nfact: float = 1.4826) -> np.floating[Any]: + """ + Calculate the normalized median absolute deviation (NMAD) of an array. + Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal + distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and + e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003) + + :param data: Input array or raster + :param nfact: Normalization factor for the data + + :returns nmad: (normalized) median absolute deviation of data. + """ + if isinstance(data, np.ma.masked_array): + data_arr = data.compressed() + else: + data_arr = np.asarray(data) + return nfact * np.nanmedian(np.abs(data_arr - np.nanmedian(data_arr))) diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 00000000..dbfb5118 --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,30 @@ +""" +Test functions for stats +""" + +import scipy + +from geoutils import Raster, examples +from geoutils.stats import nmad + + +class TestStats: + landsat_b4_path = examples.get_path("everest_landsat_b4") + landsat_raster = Raster(landsat_b4_path) + + def test_nmad(self) -> None: + """Test NMAD functionality runs on any type of input""" + + # Check that the NMAD is computed the same with a masked array or NaN array, and is equal to scipy nmad + nmad_ma = nmad(self.landsat_raster.data) + nmad_array = nmad(self.landsat_raster.get_nanarray()) + nmad_scipy = scipy.stats.median_abs_deviation(self.landsat_raster.data, axis=None, scale="normal") + + assert nmad_ma == nmad_array + assert nmad_ma.round(2) == nmad_scipy.round(2) + + # Check that the scaling factor works + nmad_1 = nmad(self.landsat_raster.data, nfact=1) + nmad_2 = nmad(self.landsat_raster.data, nfact=2) + + assert nmad_1 * 2 == nmad_2