Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persistence using NetCDF groups #461

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions .github/workflows/full_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ jobs:
- name: Test with pytest
run: |
pytest --cov=modelskill --ignore tests/notebooks/
- name: Static type check
run: make typecheck
#- name: Static type check
# run: make typecheck
- name: Test docstrings with pytest
run: |
pytest ./modelskill/metrics.py --doctest-modules
2 changes: 1 addition & 1 deletion .github/workflows/notebooks_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: '3.13'
- name: Install modelskill
run: |
pip install .[test,notebooks]
Expand Down
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ dmypy.json
# Pyre type checker
.pyre/

# ignore data except for .msk files
data/*
!data/*.msk
tests/testdata/tmp
/tmp/

Expand All @@ -148,7 +146,8 @@ streamlit_app.py
notes.md

notebooks/Untitled.ipynb
.envrc

docs/_site/
docs/_extensions/
docs/api/*.qmd
docs/api/*.qmd
4 changes: 2 additions & 2 deletions docs/user-guide/workflow.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ This method allow filtering of the data in several ways:
It can be useful to save the comparer collection for later use. This can be done using the `save()` method:

```python
cc.save("my_comparer_collection.msk")
cc.save("my_comparer_collection.nc")
```

The comparer collection can be loaded again from disk, using the `load()` method:

```python
cc = ms.load("my_comparer_collection.msk")
cc = ms.load("my_comparer_collection.nc")
```


Expand Down
8 changes: 4 additions & 4 deletions modelskill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@


def load(filename: Union[str, Path]) -> ComparerCollection:
"""Load a ComparerCollection from a zip file.
"""Load a ComparerCollection from a netcdf file.

Parameters
----------
filename : str or Path
Filename of the zip file.
Filename of the nc file.

Returns
-------
Expand All @@ -60,8 +60,8 @@ def load(filename: Union[str, Path]) -> ComparerCollection:
Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc2 = ms.load("my_comparer_collection.msk")"""
>>> cc.save("my_comparer_collection.nc")
>>> cc2 = ms.load("my_comparer_collection.nc")"""

return ComparerCollection.load(filename)

Expand Down
70 changes: 37 additions & 33 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations
from copy import deepcopy
import os
from pathlib import Path
import tempfile
from typing import (
Any,
Callable,
Expand All @@ -18,9 +16,9 @@
Tuple,
)
import warnings
import zipfile
import numpy as np
import pandas as pd
import xarray as xr


from .. import metrics as mtr
Expand Down Expand Up @@ -917,42 +915,35 @@ def score(
return score

def save(self, filename: Union[str, Path]) -> None:
"""Save the ComparerCollection to a zip file.
"""Save the ComparerCollection to a hierarchical NetCDF file.

Each comparer is stored as a netcdf file in the zip file.
Each comparer is stored as a netcdf group.

Parameters
----------
filename : str or Path
Filename of the zip file.
Filename of the nc file.

Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc.save("my_comparer_collection.nc")
"""

files = []
no = 0
dt = xr.DataTree()
for name, cmp in self._comparers.items():
cmp_fn = f"{no}_{name}.nc"
cmp.save(cmp_fn)
files.append(cmp_fn)
no += 1
dtc = cmp._save()
dt[name] = dtc

with zipfile.ZipFile(filename, "w") as zip:
for f in files:
zip.write(f)
os.remove(f)
dt.to_netcdf(filename)

@staticmethod
def load(filename: Union[str, Path]) -> "ComparerCollection":
"""Load a ComparerCollection from a zip file.
def load(filename: Union[str, Path], method: str = "tree") -> "ComparerCollection":
"""Load a ComparerCollection from a NetCDF file.

Parameters
----------
filename : str or Path
Filename of the zip file.
Filename of the nc file.

Returns
-------
Expand All @@ -962,25 +953,38 @@ def load(filename: Union[str, Path]) -> "ComparerCollection":
Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc2 = ms.ComparerCollection.load("my_comparer_collection.msk")
>>> cc.save("my_comparer_collection.nc")
>>> cc2 = ms.ComparerCollection.load("my_comparer_collection.nc")
"""

folder = tempfile.TemporaryDirectory().name
if method == "tree":
dt = xr.open_datatree(filename)
groups = [x for x in dt.children]
comparers = [Comparer._load(dt[group]) for group in groups]

return ComparerCollection(comparers)
else:
import tempfile
import os
import zipfile

with zipfile.ZipFile(filename, "r") as zip:
for f in zip.namelist():
if f.endswith(".nc"):
zip.extract(f, path=folder)
folder = tempfile.TemporaryDirectory().name

comparers = [
ComparerCollection._load_comparer(folder, f)
for f in sorted(os.listdir(folder))
]
return ComparerCollection(comparers)
with zipfile.ZipFile(filename, "r") as zip:
for f in zip.namelist():
if f.endswith(".nc"):
zip.extract(f, path=folder)

comparers = [
ComparerCollection._load_comparer(folder, f)
for f in sorted(os.listdir(folder))
]
return ComparerCollection(comparers)

@staticmethod
def _load_comparer(folder: str, f: str) -> Comparer:
import os

f = os.path.join(folder, f)
cmp = Comparer.load(f)
os.remove(f)
Expand Down
70 changes: 53 additions & 17 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,29 @@ def to_dataframe(self) -> pd.DataFrame:
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def _save(self) -> xr.DataTree:
ds = self.data

if self.gtype == "point":
dt = xr.DataTree()
dt["matched"] = ds
dt["raw"] = xr.DataTree()

for key, ts_mod in self.raw_mod_data.items():
ts_mod = ts_mod.copy()
dt["raw"][key] = ts_mod.data

dt.attrs["gtype"] = "point"
return dt
elif self.gtype == "track":
# There is no need to save raw data for track data, since it is identical to the matched data
dt = xr.DataTree()
dt.attrs["gtype"] = "track"
dt["matched"] = ds
return dt

raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def save(self, filename: Union[str, Path]) -> None:
"""Save to netcdf file

Expand All @@ -1257,24 +1280,31 @@ def save(self, filename: Union[str, Path]) -> None:
filename : str or Path
filename
"""
ds = self.data
dt = self._save()

# add self.raw_mod_data to ds with prefix 'raw_' to avoid name conflicts
# an alternative strategy would be to use NetCDF groups
# https://docs.xarray.dev/en/stable/user-guide/io.html#groups
dt.to_netcdf(filename)

# There is no need to save raw data for track data, since it is identical to the matched data
if self.gtype == "point":
ds = self.data.copy() # copy needed to avoid modifying self.data
@staticmethod
def _load(data: xr.DataTree | xr.DataArray) -> "Comparer":
if data.gtype == "track":
return Comparer(matched_data=data["matched"].to_dataset())

for key, ts_mod in self.raw_mod_data.items():
ts_mod = ts_mod.copy()
# rename time to unique name
ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
# da = ds_mod.to_xarray()[key]
ds["_raw_" + key] = ts_mod.data[key]
if data.gtype == "point":
raw_mod_data: Dict[str, PointModelResult] = {}

names = [x for x in data["raw"].children]
for var in names:
ds = data["raw"][var].to_dataset()
ts = PointModelResult(data=ds, name=var)

ds.to_netcdf(filename)
raw_mod_data[var] = ts

return Comparer(
matched_data=data["matched"].to_dataset(), raw_mod_data=raw_mod_data
)

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")

@staticmethod
def load(filename: Union[str, Path]) -> "Comparer":
Expand All @@ -1289,6 +1319,15 @@ def load(filename: Union[str, Path]) -> "Comparer":
-------
Comparer
"""
try:
with xr.open_datatree(filename) as dt:
data = dt.load()
return Comparer._load(data)
except KeyError:
return Comparer._load_legacy(filename)

@staticmethod
def _load_legacy(filename: str | Path):
with xr.open_dataset(filename) as ds:
data = ds.load()

Expand All @@ -1313,6 +1352,3 @@ def load(filename: Union[str, Path]) -> "Comparer":
data = data[[v for v in data.data_vars if "time" in data[v].dims]]

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")
4 changes: 2 additions & 2 deletions modelskill/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def vistula() -> ComparerCollection:
-------
ComparerCollection
"""
fn = str(files("modelskill.data") / "vistula.msk")
fn = str(files("modelskill.data") / "vistula.nc")
return ms.load(fn)


Expand All @@ -49,5 +49,5 @@ def oresund() -> ComparerCollection:
-------
ComparerCollection
"""
fn = str(files("modelskill.data") / "oresund.msk")
fn = str(files("modelskill.data") / "oresund.nc")
return ms.load(fn)
Binary file removed modelskill/data/oresund.msk
Binary file not shown.
Binary file added modelskill/data/oresund.nc
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion modelskill/plotting/_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def _plot_summary_table(
## Render, and get width
# plt.draw() # TOOO this causes an error and I have no idea why it is here
dx = (
dx
dx # type: ignore
+ figure_transform.inverted().transform(
[text_col_i.get_window_extent().bounds[2], 0]
)[0]
Expand Down
14,153 changes: 12,622 additions & 1,531 deletions notebooks/Metocean_MIKE21SW_DutchCoast.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"pandas >= 1.4",
"mikeio >= 1.2",
"matplotlib",
"xarray",
"xarray>=2024.10.0",
"netCDF4",
"scipy",
"jinja2",
Expand Down
10 changes: 10 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,16 @@ def test_from_matched_dfs0():
) == pytest.approx(0.0476569069177831)


def test_save_and_load(pc: Comparer, tmp_path) -> None:
filename = tmp_path / "test.nc"
pc.save(filename)

pc2 = Comparer.load(filename)

assert pc2.name == pc.name
assert pc2.gtype == pc.gtype


def test_from_matched_x_or_x_item_not_both():
with pytest.raises(ValueError, match="x and x_item cannot both be specified"):
ms.from_matched(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_save_and_load_preserves_order_of_comparers(tmp_path):
assert cc[1].name == "alpha"
assert cc[2].name == "bravo"

fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
cc.save(fn)

cc2 = modelskill.load(fn)
Expand All @@ -422,7 +422,7 @@ def test_save_and_load_preserves_order_of_comparers(tmp_path):


def test_save(cc: modelskill.ComparerCollection, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
assert cc[0].data.attrs["modelskill_version"] == modelskill.__version__
cc.save(fn)

Expand All @@ -434,15 +434,15 @@ def test_save(cc: modelskill.ComparerCollection, tmp_path):


def test_load_from_root_module(cc, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
cc.save(fn)

cc2 = modelskill.load(fn)
assert len(cc2) == 2


def test_save_and_load_preserves_raw_model_data(cc, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
assert len(cc["fake point obs"].raw_mod_data["m1"]) == 6
cc.save(fn)

Expand Down
Loading