diff --git a/modelskill/comparison/_collection_plotter.py b/modelskill/comparison/_collection_plotter.py index 0138b9ef3..87958e77f 100644 --- a/modelskill/comparison/_collection_plotter.py +++ b/modelskill/comparison/_collection_plotter.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any, List, Union, Optional, Tuple, Sequence, TYPE_CHECKING from matplotlib.axes import Axes # type: ignore +import warnings if TYPE_CHECKING: from ._collection import ComparerCollection @@ -44,7 +45,7 @@ def scatter( xlabel: Optional[str] = None, ylabel: Optional[str] = None, skill_table: Optional[Union[str, List[str], bool]] = None, - ax: Optional[Axes] = None, + ax=None, **kwargs, ): """Scatter plot showing compared data: observation vs modelled @@ -113,11 +114,72 @@ def scatter( >>> cc.plot.scatter(observations=['c2','HKNA']) """ - # select model - mod_id = _get_idx(model, self.cc.mod_names) - mod_name = self.cc.mod_names[mod_id] + cc = self.cc + if model is None: + mod_names = cc.mod_names + else: + warnings.warn( + "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.scatter()", + FutureWarning, + ) + + model_list = [model] if isinstance(model, (str, int)) else model + mod_names = [ + self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list + ] + + axes = [] + for mod_name in mod_names: + ax_mod = self._scatter_one_model( + mod_name=mod_name, + bins=bins, + quantiles=quantiles, + fit_to_quantiles=fit_to_quantiles, + show_points=show_points, + show_hist=show_hist, + show_density=show_density, + backend=backend, + figsize=figsize, + xlim=xlim, + ylim=ylim, + reg_method=reg_method, + title=title, + xlabel=xlabel, + ylabel=ylabel, + skill_table=skill_table, + ax=ax, + **kwargs, + ) + axes.append(ax_mod) + return axes[0] if len(axes) == 1 else axes - cmp = self.cc + def _scatter_one_model( + self, + *, + mod_name: str, + bins: int | float, + quantiles: int | Sequence[float] | None, + fit_to_quantiles: bool, + show_points: bool | int | float | None, + show_hist: Optional[bool], + show_density: Optional[bool], + backend: str, + figsize: Tuple[float, float], + xlim: Optional[Tuple[float, float]], + ylim: Optional[Tuple[float, float]], + reg_method: str | bool, + title: Optional[str], + xlabel: Optional[str], + ylabel: Optional[str], + skill_table: Optional[Union[str, List[str], bool]], + ax, + **kwargs, + ): + assert ( + mod_name in self.cc.mod_names + ), f"Model {mod_name} not found in collection {self.cc.mod_names}" + + cmp = self.cc.sel(model=mod_name) if cmp.n_points == 0: raise ValueError("No data found in selection") @@ -183,7 +245,7 @@ def scatter( return ax - def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes: + def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: """Plot kernel density estimate of observation and model data. Parameters @@ -247,10 +309,11 @@ def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes: def hist( self, - model=None, - bins=100, + bins: int | Sequence = 100, + *, + model: str | int | None = None, title: Optional[str] = None, - density=True, + density: bool = True, alpha: float = 0.5, ax=None, figsize: Optional[Tuple[float, float]] = None, @@ -262,8 +325,6 @@ def hist( Parameters ---------- - model : str, optional - model name, by default None, i.e. the first model bins : int, optional number of bins, by default 100 title : str, optional @@ -292,12 +353,53 @@ def hist( pandas.Series.hist matplotlib.axes.Axes.hist """ + if model is None: + mod_names = self.cc.mod_names + else: + warnings.warn( + "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.hist()", + FutureWarning, + ) + model_list = [model] if isinstance(model, (str, int)) else model + mod_names = [ + self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list + ] + + axes = [] + for mod_name in mod_names: + ax_mod = self._hist_one_model( + mod_name=mod_name, + bins=bins, + title=title, + density=density, + alpha=alpha, + ax=ax, + figsize=figsize, + **kwargs, + ) + axes.append(ax_mod) + return axes[0] if len(axes) == 1 else axes + + def _hist_one_model( + self, + *, + mod_name: str, + bins: int | Sequence, + title: Optional[str], + density: bool, + alpha: float, + ax, + figsize: Optional[Tuple[float, float]], + **kwargs, + ): from ._comparison import MOD_COLORS _, ax = _get_fig_ax(ax, figsize) - mod_id = _get_idx(model, self.cc.mod_names) - mod_name = self.cc.mod_names[mod_id] + assert ( + mod_name in self.cc.mod_names + ), f"Model {mod_name} not found in collection" + mod_id = _get_idx(mod_name, self.cc.mod_names) title = ( _default_univarate_title("Histogram", self.cc) if title is None else title @@ -331,6 +433,7 @@ def hist( def taylor( self, + *, normalize_std: bool = False, aggregate_observations: bool = True, figsize: Tuple[float, float] = (7, 7), diff --git a/modelskill/comparison/_comparer_plotter.py b/modelskill/comparison/_comparer_plotter.py index 8889fb2c3..305ea526c 100644 --- a/modelskill/comparison/_comparer_plotter.py +++ b/modelskill/comparison/_comparer_plotter.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Union, List, Optional, Tuple, Sequence, TYPE_CHECKING +import warnings if TYPE_CHECKING: import matplotlib.figure @@ -34,12 +35,12 @@ def __call__(self, *args, **kwargs): def timeseries( self, - title=None, *, - ylim=None, + title: str | None = None, + ylim: Tuple[float, float] | None = None, ax=None, - figsize=None, - backend="matplotlib", + figsize: Tuple[float, float] | None = None, + backend: str = "matplotlib", **kwargs, ): """Timeseries plot showing compared data: observation vs modelled @@ -48,7 +49,7 @@ def timeseries( ---------- title : str, optional plot title, by default None - ylim : tuple, optional + ylim : (float, float), optional plot range for the model (ymin, ymax), by default None ax : matplotlib.axes.Axes, optional axes to plot on, by default None @@ -127,14 +128,14 @@ def timeseries( def hist( self, + bins: int | Sequence = 100, *, - model=None, - bins=100, - title=None, + model: str | int | None = None, + title: str | None = None, ax=None, - figsize=None, - density=True, - alpha=0.5, + figsize: Tuple[float, float] | None = None, + density: bool = True, + alpha: float = 0.5, **kwargs, ): """Plot histogram of model data and observations. @@ -143,8 +144,6 @@ def hist( Parameters ---------- - model : (str, int), optional - name or id of model to be plotted, by default 0 bins : int, optional number of bins, by default 100 title : str, optional @@ -168,11 +167,51 @@ def hist( pandas.Series.plot.hist matplotlib.axes.Axes.hist """ + cmp = self.comparer + + if model is None: + mod_names = cmp.mod_names + else: + warnings.warn( + "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.hist()", + FutureWarning, + ) + model_list = [model] if isinstance(model, (str, int)) else model + mod_names = [cmp.mod_names[_get_idx(m, cmp.mod_names)] for m in model_list] + + axes = [] + for mod_name in mod_names: + ax_mod = self._hist_one_model( + mod_name=mod_name, + bins=bins, + title=title, + ax=ax, + figsize=figsize, + density=density, + alpha=alpha, + **kwargs, + ) + axes.append(ax_mod) + + return axes[0] if len(axes) == 1 else axes + + def _hist_one_model( + self, + *, + mod_name: str, + bins: int | Sequence | None, + title: str | None, + ax, + figsize: Tuple[float, float] | None, + density: bool | None, + alpha: float | None, + **kwargs, + ): from ._comparison import MOD_COLORS # TODO move to here cmp = self.comparer - mod_id = _get_idx(model, cmp.mod_names) - mod_name = cmp.mod_names[mod_id] + assert mod_name in cmp.mod_names, f"Model {mod_name} not found in comparer" + mod_id = _get_idx(mod_name, cmp.mod_names) title = f"{mod_name} vs {cmp.name}" if title is None else title @@ -269,6 +308,7 @@ def kde(self, ax=None, title=None, figsize=None, **kwargs) -> matplotlib.axes.Ax def qq( self, quantiles: int | Sequence[float] | None = None, + *, title=None, ax=None, figsize=None, @@ -352,7 +392,7 @@ def qq( return ax - def box(self, ax=None, title=None, figsize=None, **kwargs): + def box(self, *, ax=None, title=None, figsize=None, **kwargs): """Make a box plot of model data and observations. Wraps pandas.DataFrame boxplot() method. @@ -424,7 +464,7 @@ def scatter( Parameters ---------- - model : (str, int), optional + model : (str, int), optional, DEPRECATED name or id of model to be plotted, by default 0 bins: (int, float, sequence), optional bins for the 2D histogram on the background. By default 20 bins. @@ -488,8 +528,67 @@ def scatter( """ cmp = self.comparer - mod_id = _get_idx(model, cmp.mod_names) - mod_name = cmp.mod_names[mod_id] + if model is None: + mod_names = cmp.mod_names + else: + warnings.warn( + "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.scatter()", + FutureWarning, + ) + model_list = [model] if isinstance(model, (str, int)) else model + mod_names = [cmp.mod_names[_get_idx(m, cmp.mod_names)] for m in model_list] + + axes = [] + for mod_name in mod_names: + ax_mod = self._scatter_one_model( + mod_name=mod_name, + bins=bins, + quantiles=quantiles, + fit_to_quantiles=fit_to_quantiles, + show_points=show_points, + show_hist=show_hist, + show_density=show_density, + norm=norm, + backend=backend, + figsize=figsize, + xlim=xlim, + ylim=ylim, + reg_method=reg_method, + title=title, + xlabel=xlabel, + ylabel=ylabel, + skill_table=skill_table, + **kwargs, + ) + axes.append(ax_mod) + return axes[0] if len(axes) == 1 else axes + + def _scatter_one_model( + self, + *, + mod_name: str, + bins: int | float, + quantiles: int | Sequence[float] | None, + fit_to_quantiles: bool, + show_points: bool | int | float | None, + show_hist: Optional[bool], + show_density: Optional[bool], + norm: Optional[colors.Normalize], + backend: str, + figsize: Tuple[float, float], + xlim: Optional[Tuple[float, float]], + ylim: Optional[Tuple[float, float]], + reg_method: str | bool, + title: Optional[str], + xlabel: Optional[str], + ylabel: Optional[str], + skill_table: Optional[Union[str, List[str], bool]], + **kwargs, + ): + """Scatter plot for one model only""" + + cmp = self.comparer + assert mod_name in cmp.mod_names, f"Model {mod_name} not found in comparer" if cmp.n_points == 0: raise ValueError("No data found in selection") @@ -552,6 +651,7 @@ def scatter( def taylor( self, + *, normalize_std: bool = False, figsize: Tuple[float, float] = (7, 7), marker: str = "o", diff --git a/modelskill/plotting/_misc.py b/modelskill/plotting/_misc.py index 00374475c..9d6ec0491 100644 --- a/modelskill/plotting/_misc.py +++ b/modelskill/plotting/_misc.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from ..metrics import metric_has_units +from ..metrics import metric_has_units, defined_metrics from ..observation import unit_display_name @@ -16,7 +16,7 @@ def _get_ax(ax=None, figsize=None): return ax -def _get_fig_ax(ax=None, figsize=None): +def _get_fig_ax(ax: plt.Axes | None = None, figsize=None): if ax is None: fig, ax = plt.subplots(figsize=figsize) else: @@ -149,8 +149,10 @@ def quantiles_xy( def format_skill_df(df: pd.DataFrame, units: str, precision: int = 2): - # remove model and variable columns if present, i.e. keep all other columns - df.drop(["model", "variable"], axis=1, errors="ignore", inplace=True) + # select metrics columns + accepted_columns = defined_metrics | {"n"} + + df = df.loc[:, df.columns.isin(accepted_columns)] # loop over series in dataframe, (columns) lines = [_format_skill_line(df[col], units, precision) for col in list(df.columns)]