From a33a3bb997f1d5949f5c0dc04b3778e9aba958a4 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Wed, 3 Jan 2024 06:08:37 +0100 Subject: [PATCH] Comparer/ComparerCollection isel --- modelskill/comparison/_collection.py | 19 +++++++++++++++++++ modelskill/comparison/_comparison.py | 15 +++++++++++++++ tests/test_comparer.py | 7 +++++++ tests/test_comparercollection.py | 11 +++++++++++ tests/test_multimodelcompare.py | 5 +++++ 5 files changed, 57 insertions(+) diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py index 32c43355b..38360dd06 100644 --- a/modelskill/comparison/_collection.py +++ b/modelskill/comparison/_collection.py @@ -427,6 +427,25 @@ def sel( return cc + def isel( + self, model: int | None = None, observation: int | None = None + ) -> "ComparerCollection": + """Select data based on model and observation index. + + Parameters + ---------- + model : int, optional + Model index. If None, all models are selected. + observation : int, optional + Observation index. If None, all observations are selected. + + Returns + ------- + ComparerCollection + New ComparerCollection with selected data. + """ + return self.sel(model=model, observation=observation) + def filter_by_attrs(self, **kwargs) -> "ComparerCollection": """Filter by comparer attrs similar to xarray.Dataset.filter_by_attrs diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index 7e8b8d2e2..4a3382e07 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -931,6 +931,21 @@ def sel( d = d.isel(time=mask) return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data) + def isel(self, model: int) -> "Comparer": + """Select data based on model index. + + Parameters + ---------- + model : int + Model index. + + Returns + ------- + Comparer + New Comparer with selected data. + """ + return self.sel(model=model) + def where( self, cond: Union[bool, np.ndarray, xr.DataArray], diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 04ad4fa80..afaadaec5 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -444,6 +444,13 @@ def test_pc_sel_model_first(pc): assert np.all(pc2.data.m1 == pc.data.m1) +def test_pc_isel_model_first(pc): + pc2 = pc.isel(model=0) + assert pc2.n_points == 5 + assert pc2.n_models == 1 + assert np.all(pc2.data.m1 == pc.data.m1) + + def test_pc_sel_model_last(pc): pc2 = pc.sel(model=-1) assert pc2.n_points == 5 diff --git a/tests/test_comparercollection.py b/tests/test_comparercollection.py index 3538fa1a7..5d3d64ef7 100644 --- a/tests/test_comparercollection.py +++ b/tests/test_comparercollection.py @@ -129,6 +129,17 @@ def test_cc_sel_model_last(cc): assert cc2.mod_names == ["m3"] +def test_cc_isel_model_last(cc): + cc2 = cc.isel(model=-1) + assert len(cc2) == 1 + assert cc2.n_models == 1 + assert cc2.n_points == 5 + assert cc2.start_time == pd.Timestamp("2019-01-03") + assert cc2.end_time == pd.Timestamp("2019-01-07") + assert cc2.obs_names == ["fake track obs"] + assert cc2.mod_names == ["m3"] + + # TODO: FAILS # def test_cc_sel_time_single(cc): # cc1 = cc.sel(time="2019-01-03") diff --git a/tests/test_multimodelcompare.py b/tests/test_multimodelcompare.py index b12e9472a..680c4bcd7 100644 --- a/tests/test_multimodelcompare.py +++ b/tests/test_multimodelcompare.py @@ -125,6 +125,11 @@ def test_mm_skill_obs(cc): assert s.loc["SW_2"].bias == s2.loc["SW_2"].bias +def test_mm_isel(cc): + s2 = cc.isel(observation=-1).skill() + assert s2.loc["SW_2"].bias == pytest.approx(0.081431053) + + def test_mm_mean_skill_obs(cc): df = cc.sel(model=0, observation=[0, "c2"]).mean_skill().to_dataframe() assert pytest.approx(df.iloc[0].si) == 0.11113215