Skip to content

Commit

Permalink
use Traversable rather than custom function
Browse files Browse the repository at this point in the history
  • Loading branch information
scivision committed Jan 9, 2025
1 parent 543eebc commit fcc605d
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 47 deletions.
6 changes: 3 additions & 3 deletions src/gemini3d/compare/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations
import json

from ..utils import get_pkg_file
import importlib.resources as ir


def err_pct(a, b) -> float:
Expand All @@ -26,5 +25,6 @@ def err_pct(a, b) -> float:


def load_tol() -> dict[str, float]:
tol_json = get_pkg_file("gemini3d.compare", "tolerance.json").read_text()
file = ir.files(f"{__package__}.data") / "tolerance.json"
tol_json = file.read_text()
return json.loads(tol_json)
8 changes: 2 additions & 6 deletions src/gemini3d/hpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import binascii
from pathlib import Path
import shutil
import importlib.resources
import importlib.resources as ir


def hpc_submit_job(batcher: str, job_file: Path):
Expand Down Expand Up @@ -41,11 +41,7 @@ def hpc_batch_create(batcher: str, out_dir: Path, cmd: list[str]) -> Path:
Nchar = 6 # arbitrary number of characters

if batcher == "qsub":
template = (
importlib.resources.files("gemini3d.templates")
.joinpath("qsub_template.job")
.read_text()
)
template = (ir.files("gemini3d.templates") / "qsub_template.job").read_text()
job_file = out_dir / f"{binascii.b2a_hex(os.urandom(Nchar)).decode('ascii')}.job"
print("writing job file", job_file)
text = template + "\n" + " ".join(cmd)
Expand Down
14 changes: 9 additions & 5 deletions src/gemini3d/linux_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ def parse_os_release(txt: str) -> list[str]:

C = ConfigParser(inline_comment_prefixes=("#", ";"))
C.read_string(txt)
like = C["all"].get("ID_LIKE", fallback="")
if not like:
like = C["all"].get("ID", fallback="")
like = like.strip('"').strip("'").split()

return like
try:
like = C["all"]["ID_LIKE"]
except KeyError:
try:
like = C["all"]["ID"]
except KeyError:
like = ""

return like.strip('"').strip("'").split()


def get_package_manager(like: list[str] | None = None) -> str:
Expand Down
12 changes: 6 additions & 6 deletions src/gemini3d/tests/unit/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.resources as ir
import pytest
import os
from datetime import datetime, timedelta
Expand All @@ -6,7 +7,6 @@
import gemini3d.config as config
import gemini3d.read as read
import gemini3d.model as model
from gemini3d.utils import get_pkg_file


def test_model_config(tmp_path):
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_nml_bad(tmp_path):
@pytest.mark.parametrize("group", ["base", "flags", "files", "precip", "efield"])
def test_namelist_exists(group):
assert config.namelist_exists(
get_pkg_file("gemini3d.tests.config", "config_example.nml"), group
ir.files("gemini3d.tests.config") / "config_example.nml", group
)


Expand All @@ -110,7 +110,7 @@ def test_nml_gemini_env_root(monkeypatch, tmp_path):
monkeypatch.setenv("GEMINI_CIROOT", str(tmp_path))

cfg = config.parse_namelist(
get_pkg_file("gemini3d.tests.config", "config_example.nml"), "setup"
ir.files("gemini3d.tests.config") / "config_example.nml", "setup"
)

assert isinstance(cfg["eq_dir"], Path)
Expand All @@ -122,7 +122,7 @@ def test_nml_gemini_env_root(monkeypatch, tmp_path):
@pytest.mark.parametrize("namelist", ["base", "flags", "files", "precip", "efield"])
def test_nml_namelist(namelist):
params = config.parse_namelist(
get_pkg_file("gemini3d.tests.config", "config_example.nml"), namelist
ir.files("gemini3d.tests.config") / "config_example.nml", namelist
)

if "base" in namelist:
Expand All @@ -138,7 +138,7 @@ def test_nml_namelist(namelist):
@pytest.mark.parametrize("namelist", ["neutral_BG"])
def test_msis2_namelist(namelist):
params = config.parse_namelist(
get_pkg_file("gemini3d.tests.config", "config_msis2.nml"), namelist
ir.files("gemini3d.tests.config") / "config_msis2.nml", namelist
)

if "neutral_BG" in namelist:
Expand All @@ -151,7 +151,7 @@ def test_read_config_nml(monkeypatch, tmp_path):
if not os.environ.get("GEMINI_CIROOT"):
monkeypatch.setenv("GEMINI_CIROOT", str(tmp_path))

params = read.config(get_pkg_file("gemini3d.tests.config", "config_example.nml"))
params = read.config(ir.files("gemini3d.tests.config") / "config_example.nml")

assert params["time"][0] == datetime(2013, 2, 20, 5)
assert params["dtprec"] == timedelta(seconds=5)
Expand Down
4 changes: 2 additions & 2 deletions src/gemini3d/tests/unit/test_find.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import importlib.resources as ir
import pytest
from datetime import datetime
from pathlib import Path

import gemini3d
import gemini3d.find as find
import gemini3d.web
from gemini3d.utils import get_pkg_file


def test_config(tmp_path):
Expand All @@ -15,7 +15,7 @@ def test_config(tmp_path):
with pytest.raises(FileNotFoundError):
find.config(tmp_path / "not_exist")

cfn = get_pkg_file("gemini3d.tests.config", "config_example.nml")
cfn = ir.files("gemini3d.tests.config") / "config_example.nml"
fn = find.config(cfn)
assert fn == cfn

Expand Down
25 changes: 1 addition & 24 deletions src/gemini3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import shutil
from pathlib import Path
import importlib.resources as pkgr

from datetime import datetime, timedelta
import typing as T
Expand All @@ -17,29 +16,7 @@
import numpy as np


__all__ = ["get_pkg_file", "str2func", "to_datetime", "git_meta", "datetime2stem"]


def get_pkg_file(package: str, filename: str) -> Path:
"""Get a file from a package.
Parameters
----------
package : str
Package name.
filename : str
File name.
Returns
-------
Path
Path to the file.
NOTE: this probably assumes the install is Zip safe
"""

with pkgr.as_file(pkgr.files(package).joinpath(filename)) as f:
return f
__all__ = ["str2func", "to_datetime", "git_meta", "datetime2stem"]


def str2func(name: str, path: Path | None = None) -> T.Callable:
Expand Down
2 changes: 1 addition & 1 deletion src/gemini3d/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def download_and_extract(test_name: str, data_dir: Path) -> Path:

if not ref_file.is_file():
jmeta = json.loads(
importlib.resources.files("gemini3d").joinpath("libraries.json").read_text()
(importlib.resources.files("gemini3d") / "libraries.json").read_text()
)
url_retrieve(url=jmeta["ref_data"]["url"], outfile=ref_file)

Expand Down

0 comments on commit fcc605d

Please sign in to comment.