Skip to content

Commit

Permalink
Schema to json (#159)
Browse files Browse the repository at this point in the history
Co-authored-by: Robert Jackson <[email protected]>
Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
5 people authored Aug 13, 2024
1 parent 8e226e8 commit fd0244b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
44 changes: 44 additions & 0 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Classes for iterating through xarray datarrays / datasets in batches."""

import itertools
import json
import warnings
from collections.abc import Hashable, Iterator, Sequence
from operator import itemgetter
Expand Down Expand Up @@ -262,6 +263,49 @@ def _get_batch_in_range_per_batch(self, batch_multi_index):
batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
return batch_in_range_per_patch

def to_json(self):
    """
    Serialize the BatchSchema properties to a JSON string.

    Slice objects held in ``self.selectors`` are not JSON serializable, so
    each one is expanded into a ``{'start', 'stop', 'step'}`` mapping.

    Returns
    -------
    out_json: str
        The JSON representation of the BatchSchema. In addition to the
        configuration attributes it contains:

        ``'selectors'``
            Every batch index mapped to the list of its patch selectors
            (integer keys are written as strings by ``json.dumps``).
        ``'selector'``
            The last patch selector only (an empty mapping when there are
            no selectors). Retained for backward compatibility with
            earlier output, which exposed only this entry.
    """

    def _slices_to_dict(member):
        # Expand each dimension's slice into plain, serializable values.
        return {
            dim: {
                'start': dim_slice.start,
                'stop': dim_slice.stop,
                'step': dim_slice.step,
            }
            for dim, dim_slice in member.items()
        }

    selectors = {
        batch_index: [_slices_to_dict(member) for member in members]
        for batch_index, members in self.selectors.items()
    }

    # Legacy entry: the most recently visited patch selector. Starting from
    # an empty dict avoids an unbound name when there are no selectors.
    last_selector = {}
    for members in selectors.values():
        if members:
            last_selector = members[-1]

    out_dict = {
        'input_dims': self.input_dims,
        'input_overlap': self.input_overlap,
        'batch_dims': self.batch_dims,
        # Previously this was populated from ``self.input_dims`` by mistake.
        'concat_input_dims': self.concat_input_dims,
        'preload_batch': self.preload_batch,
        'selector': last_selector,
        'selectors': selectors,
    }
    return json.dumps(out_dict)

def to_file(self, out_file_name: str):
    """
    Write the JSON representation of the BatchSchema object to a file.

    The schema is serialized with ``to_json`` before the target file is
    opened, so a serialization failure leaves the destination untouched.

    Parameters
    ----------
    out_file_name: str
        The path to the json file to write to. The file is opened in
        write mode.
    """
    serialized = self.to_json()
    with open(out_file_name, mode='w') as handle:
        handle.write(serialized)


def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:
# return a list of slices to chop up a single dimension
Expand Down
19 changes: 18 additions & 1 deletion xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import tempfile
from typing import Any

import numpy as np
import pytest
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher import BatchGenerator, BatchSchema
from xbatcher.testing import (
get_batch_dimensions,
validate_batch_dimensions,
Expand Down Expand Up @@ -360,6 +362,21 @@ def test_input_overlap_exceptions(sample_ds_1d):
assert len(e) == 1


@pytest.mark.parametrize('input_size', [5, 10])
def test_to_json(sample_ds_3d, input_size):
    """Round-trip a BatchSchema through ``to_file`` and check ``input_dims``."""
    x_input_size = 20
    schema = BatchSchema(
        sample_ds_3d,
        input_dims={'time': input_size, 'x': x_input_size},
    )
    # Write into a temporary directory rather than re-opening a still-open
    # NamedTemporaryFile by name (which fails on Windows); the context
    # manager also guarantees cleanup when an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_file_name = f'{tmp_dir}/schema.json'
        schema.to_file(out_file_name)
        with open(out_file_name) as in_file:
            in_dict = json.load(in_file)
    assert in_dict['input_dims']['time'] == input_size
    assert in_dict['input_dims']['x'] == x_input_size


@pytest.mark.parametrize('preload', [True, False])
def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
pytest.importorskip('zarr')
Expand Down

0 comments on commit fd0244b

Please sign in to comment.