Skip to content

Commit

Permalink
enable pytorch_channel option (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Apr 9, 2021
1 parent 6407fbe commit 9df80db
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ packages = find:
include_package_data = True
python_requires = >=3.6
install_requires =
light-the-torch>=0.2
light-the-torch>=0.3
tox

[options.packages.find]
Expand Down
21 changes: 21 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_setup_cfg(name, version, install_requires=None, extra_requires=None):
def get_tox_ini(
basepython=None,
disable_light_the_torch=None,
pytorch_channel=None,
pytorch_force_cpu=None,
force_cpu=None,
deps=None,
Expand All @@ -90,6 +91,8 @@ def get_tox_ini(
lines.append("extras = extra")
if disable_light_the_torch is not None:
lines.append(f"disable_light_the_torch = {disable_light_the_torch}")
if pytorch_channel is not None:
lines.append(f"pytorch_channel = {pytorch_channel}")
if pytorch_force_cpu is not None:
lines.append(f"pytorch_force_cpu = {pytorch_force_cpu}")
if force_cpu is not None:
Expand All @@ -110,6 +113,7 @@ def tox_ltt_initproj_(
install_requires=None,
extra_requires=None,
disable_light_the_torch=None,
pytorch_channel=None,
pytorch_force_cpu=None,
force_cpu=None,
deps=None,
Expand All @@ -130,6 +134,7 @@ def tox_ltt_initproj_(
usedevelop=usedevelop,
extra=extra_requires is not None,
disable_light_the_torch=disable_light_the_torch,
pytorch_channel=pytorch_channel,
pytorch_force_cpu=pytorch_force_cpu,
force_cpu=force_cpu,
deps=deps,
Expand All @@ -151,6 +156,7 @@ def test_help_ini(cmd):
result = cmd("--help-ini")
result.assert_success(is_run_test_env=False)
assert "disable_light_the_torch" in result.out
assert "pytorch_channel" in result.out
assert "pytorch_force_cpu" in result.out


Expand All @@ -165,6 +171,21 @@ def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd):
mock.assert_not_called()


@pytest.mark.slow
def test_tox_ltt_pytorch_channel(patch_find_links, tox_ltt_initproj, cmd, install_mock):
channel = "channel"

mock = patch_find_links()
tox_ltt_initproj(deps=("torch",), pytorch_channel=channel)

result = cmd()

result.assert_success(is_run_test_env=False)

_, kwargs = mock.call_args
assert kwargs["channel"] == channel


@pytest.mark.slow
def test_tox_ltt_pytorch_force_cpu(
patch_find_links, tox_ltt_initproj, cmd, install_mock
Expand Down
16 changes: 11 additions & 5 deletions tox_ltt/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from light_the_torch.computation_backend import CPUBackend


def extract_force_cpu_help() -> str:
def extract_ltt_option_help(subcommand: str, option: str) -> str:
def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:
reduced_seq = [item for item in seq if getattr(item, attr) == eq_cond]
assert len(reduced_seq) == 1
Expand All @@ -22,9 +22,8 @@ def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:

argument_group = extract(ltt_parser._action_groups, "title", "subcommands")
sub_parsers = extract(argument_group._actions, "dest", "subcommand")
install_parser = sub_parsers.choices["install"]
force_cpu = extract(install_parser._actions, "dest", "force_cpu")
return cast(str, force_cpu.help)
subcommand_parser = sub_parsers.choices[subcommand]
return cast(str, extract(subcommand_parser._actions, "dest", option).help)


@tox.hookimpl
Expand All @@ -35,10 +34,16 @@ def tox_addoption(parser: Parser) -> None:
help="disable installing PyTorch distributions with light-the-torch",
default=False,
)
parser.add_testenv_attribute(
name="pytorch_channel",
type="string",
help=extract_ltt_option_help("install", "channel"),
default="stable",
)
parser.add_testenv_attribute(
name="pytorch_force_cpu",
type="bool",
help=extract_force_cpu_help(),
help=extract_ltt_option_help("install", "force_cpu"),
default=False,
)
parser.add_testenv_attribute(
Expand Down Expand Up @@ -98,6 +103,7 @@ def tox_testenv_install_deps(venv: VirtualEnv, action: Action) -> None:
links = ltt.find_links(
dists,
computation_backend=get_computation_backend(envconfig),
channel=envconfig.pytorch_channel,
python_version=get_python_version(envconfig),
)

Expand Down

0 comments on commit 9df80db

Please sign in to comment.