diff --git a/pyproject.toml b/pyproject.toml index 4927be2..50e3d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,140 +1,134 @@ [build-system] -requires = [ - "setuptools>=64", - "setuptools-scm[toml]>=6.2" -] -build-backend = "setuptools.build_meta" + build-backend = "setuptools.build_meta" + requires = ["setuptools-scm[toml]>=6.2", "setuptools>=64"] [project] -name = "xbatcher" -description = "Batch generation from Xarray objects" -readme = "README.rst" -license = {text = "Apache"} -authors = [{name = "xbatcher Developers", email = "rpa@ldeo.columbia.edu"}] -requires-python = ">=3.10" -classifiers = [ - "Development Status :: 4 - Beta", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Intended Audience :: Science/Research", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Topic :: Scientific/Engineering", -] -dynamic = ["version"] -dependencies = [ - "dask", - "numpy", - "xarray", -] + authors = [ + { name = "xbatcher Developers", email = "rpa@ldeo.columbia.edu" }, + ] + classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python", + "Topic :: Scientific/Engineering", + ] + dependencies = ["dask", "numpy", "xarray"] + description = "Batch generation from Xarray objects" + dynamic = ["version"] + license = { text = "Apache" } + name = "xbatcher" + readme = "README.rst" + requires-python = ">=3.10" [project.optional-dependencies] -torch = [ - "torch", -] -tensorflow = [ - "tensorflow", -] -dev = [ - "adlfs", - 
"asv", - "coverage", - "pytest", - "pytest-cov", - "tensorflow", - "torch", - "zarr", -] + dev = [ + "adlfs", + "asv", + "coverage", + "pytest", + "pytest-cov", + "tensorflow", + "torch", + "zarr<3.0", + ] + tensorflow = ["tensorflow"] + torch = ["torch"] [project.urls] -documentation = "https://xbatcher.readthedocs.io/en/latest/" -repository = "https://github.com/xarray-contrib/xbatcher" + documentation = "https://xbatcher.readthedocs.io/en/latest/" + repository = "https://github.com/xarray-contrib/xbatcher" [tool.setuptools.packages.find] -include = ["xbatcher*"] + include = ["xbatcher*"] [tool.setuptools_scm] -local_scheme = "node-and-date" -fallback_version = "999" - - + fallback_version = "999" + local_scheme = "node-and-date" [tool.ruff] -target-version = "py310" -extend-include = ["*.ipynb"] - - -builtins = ["ellipsis"] -# Exclude a variety of commonly ignored directories. -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".ipynb_checkpoints", - ".mypy_cache", - ".nox", - ".pants.d", - ".pyenv", - ".pytest_cache", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - ".vscode", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "site-packages", - "venv", -] + extend-include = ["*.ipynb"] + target-version = "py310" + + builtins = ["ellipsis"] + # Exclude a variety of commonly ignored directories. 
+ exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + ] [tool.ruff.lint] -per-file-ignores = {} -ignore = [ - "E721", # Comparing types instead of isinstance - "E741", # Ambiguous variable names - "E501", # Conflicts with ruff format -] -select = [ - # Pyflakes - "F", - # Pycodestyle - "E", - "W", - # isort - "I", - # Pyupgrade - "UP", -] - + ignore = [ + "E501", # Conflicts with ruff format + "E721", # Comparing types instead of isinstance + "E741", # Ambiguous variable names + ] + per-file-ignores = {} + select = [ + # Pyflakes + "F", + # Pycodestyle + "E", + "W", + # isort + "I", + # Pyupgrade + "UP", + ] [tool.ruff.lint.mccabe] -max-complexity = 18 + max-complexity = 18 [tool.ruff.lint.isort] -known-first-party = ["xbatcher"] -known-third-party = ["numpy", "pandas", "pytest", "sphinx_autosummary_accessors", "torch", "xarray"] - -combine-as-imports = true + known-first-party = ["xbatcher"] + known-third-party = [ + "numpy", + "pandas", + "pytest", + "sphinx_autosummary_accessors", + "torch", + "xarray", + ] + + combine-as-imports = true [tool.ruff.format] -quote-style = "single" -docstring-code-format = true + docstring-code-format = true + quote-style = "single" [tool.ruff.lint.pydocstyle] -convention = "numpy" + convention = "numpy" [tool.ruff.lint.pyupgrade] -# Preserve types, even if a file imports `from __future__ import annotations`. -keep-runtime-typing = true + # Preserve types, even if a file imports `from __future__ import annotations`. 
+ keep-runtime-typing = true [tool.pytest.ini_options] -log_cli = true -log_level = "INFO" + log_cli = true + log_level = "INFO" diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 412d09a..2616ab3 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -268,24 +268,26 @@ def to_json(self): out_json: str The JSON representation of the BatchSchema """ - out_dict = {} - out_dict['input_dims'] = self.input_dims - out_dict['input_overlap'] = self.input_overlap - out_dict['batch_dims'] = self.batch_dims - out_dict['concat_input_dims'] = self.input_dims - out_dict['preload_batch'] = self.preload_batch + out_dict = { + 'input_dims': self.input_dims, + 'input_overlap': self.input_overlap, + 'batch_dims': self.batch_dims, + 'concat_input_dims': self.concat_input_dims, + 'preload_batch': self.preload_batch, + } batch_selector_dict = {} for i in self.selectors.keys(): batch_selector_dict[i] = self.selectors[i] for member in batch_selector_dict[i]: - out_member_dict = {} - member_keys = [x for x in member.keys()] - for member_key in member_keys: - out_member_dict[member_key] = { + member_keys = list(member.keys()) + out_member_dict = { + member_key: { 'start': member[member_key].start, 'stop': member[member_key].stop, 'step': member[member_key].step, } + for member_key in member_keys + } out_dict['selector'] = out_member_dict return json.dumps(out_dict) @@ -323,12 +325,13 @@ def _iterate_through_dimensions( ds: xr.Dataset | xr.DataArray, *, dims: dict[Hashable, int], - overlap: dict[Hashable, int] = {}, + overlap: dict[Hashable, int] | None = None, ) -> Iterator[dict[Hashable, slice]]: + if overlap is None: + overlap = {} dim_slices = [] - for dim in dims: + for dim, slice_size in dims.items(): dim_size = ds.sizes[dim] - slice_size = dims[dim] slice_overlap = overlap.get(dim, 0) if slice_size > dim_size: raise ValueError( @@ -353,7 +356,7 @@ def _drop_input_dims( # remove input_dims coordinates from datasets, rename the dimensions # then put intput_dims 
back in as coordinates out = ds.copy() - for dim in input_dims.keys(): + for dim in input_dims: newdim = f'{dim}{suffix}' out = out.rename({dim: newdim}) # extra steps needed if there is a coordinate @@ -422,13 +425,17 @@ def __init__( self, ds: xr.Dataset | xr.DataArray, input_dims: dict[Hashable, int], - input_overlap: dict[Hashable, int] = {}, - batch_dims: dict[Hashable, int] = {}, + input_overlap: dict[Hashable, int] | None = None, + batch_dims: dict[Hashable, int] | None = None, concat_input_dims: bool = False, preload_batch: bool = True, cache: dict[str, Any] | None = None, cache_preprocess: Callable | None = None, ): + if input_overlap is None: + input_overlap = {} + if batch_dims is None: + batch_dims = {} self.ds = ds self.cache = cache self.cache_preprocess = cache_preprocess @@ -481,43 +488,40 @@ def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray: if self.cache and self._batch_in_cache(idx): return self._get_cached_batch(idx) - if idx in self._batch_selectors.selectors: - if self.concat_input_dims: - new_dim_suffix = '_input' - all_dsets: list = [] - batch_selector = {} - for dim in self._batch_selectors.batch_dims.keys(): - starts = [ - x[dim].start for x in self._batch_selectors.selectors[idx] - ] - stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]] - batch_selector[dim] = slice(min(starts), max(stops)) - batch_ds = self.ds.isel(batch_selector) - if self.preload_batch: - batch_ds.load() - for selector in self._batch_selectors.selectors[idx]: - patch_ds = self.ds.isel(selector) - all_dsets.append( - _drop_input_dims( - patch_ds, - self.input_dims, - suffix=new_dim_suffix, - ) + if idx not in self._batch_selectors.selectors: + raise IndexError('list index out of range') + + if self.concat_input_dims: + new_dim_suffix = '_input' + all_dsets: list = [] + batch_selector = {} + for dim in self._batch_selectors.batch_dims.keys(): + starts = [x[dim].start for x in self._batch_selectors.selectors[idx]] + stops = [x[dim].stop for x in 
self._batch_selectors.selectors[idx]] + batch_selector[dim] = slice(min(starts), max(stops)) + batch_ds = self.ds.isel(batch_selector) + if self.preload_batch: + batch_ds.load() + for selector in self._batch_selectors.selectors[idx]: + patch_ds = self.ds.isel(selector) + all_dsets.append( + _drop_input_dims( + patch_ds, + self.input_dims, + suffix=new_dim_suffix, ) - dsc = xr.concat(all_dsets, dim='input_batch') - new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] - batch = _maybe_stack_batch_dims(dsc, new_input_dims) - else: - batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0]) - if self.preload_batch: - batch_ds.load() - batch = _maybe_stack_batch_dims( - batch_ds, - list(self.input_dims), ) + dsc = xr.concat(all_dsets, dim='input_batch') + new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] + batch = _maybe_stack_batch_dims(dsc, new_input_dims) else: - raise IndexError('list index out of range') - + batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0]) + if self.preload_batch: + batch_ds.load() + batch = _maybe_stack_batch_dims( + batch_ds, + list(self.input_dims), + ) if self.cache is not None and self.cache_preprocess is not None: batch = self.cache_preprocess(batch) if self.cache is not None: