diff --git a/setup.cfg b/setup.cfg index eea2257..7317d42 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,7 +25,7 @@ packages = find: include_package_data = True python_requires = >=3.6 install_requires = - light-the-torch>=0.2 + light-the-torch>=0.3 tox [options.packages.find] diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 380a22a..73b7e02 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -65,6 +65,7 @@ def get_setup_cfg(name, version, install_requires=None, extra_requires=None): def get_tox_ini( basepython=None, disable_light_the_torch=None, + pytorch_channel=None, pytorch_force_cpu=None, force_cpu=None, deps=None, @@ -90,6 +91,8 @@ def get_tox_ini( lines.append("extras = extra") if disable_light_the_torch is not None: lines.append(f"disable_light_the_torch = {disable_light_the_torch}") + if pytorch_channel is not None: + lines.append(f"pytorch_channel = {pytorch_channel}") if pytorch_force_cpu is not None: lines.append(f"pytorch_force_cpu = {pytorch_force_cpu}") if force_cpu is not None: @@ -110,6 +113,7 @@ def tox_ltt_initproj_( install_requires=None, extra_requires=None, disable_light_the_torch=None, + pytorch_channel=None, pytorch_force_cpu=None, force_cpu=None, deps=None, @@ -130,6 +134,7 @@ def tox_ltt_initproj_( usedevelop=usedevelop, extra=extra_requires is not None, disable_light_the_torch=disable_light_the_torch, + pytorch_channel=pytorch_channel, pytorch_force_cpu=pytorch_force_cpu, force_cpu=force_cpu, deps=deps, @@ -151,6 +156,7 @@ def test_help_ini(cmd): result = cmd("--help-ini") result.assert_success(is_run_test_env=False) assert "disable_light_the_torch" in result.out + assert "pytorch_channel" in result.out assert "pytorch_force_cpu" in result.out @@ -165,6 +171,21 @@ def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd): mock.assert_not_called() +@pytest.mark.slow +def test_tox_ltt_pytorch_channel(patch_find_links, tox_ltt_initproj, cmd, install_mock): + channel = "channel" + + mock = patch_find_links() + tox_ltt_initproj(deps=("torch",), pytorch_channel=channel) + + result = cmd() + + result.assert_success(is_run_test_env=False) + + _, kwargs = mock.call_args + assert kwargs["channel"] == channel + + @pytest.mark.slow def test_tox_ltt_pytorch_force_cpu( patch_find_links, tox_ltt_initproj, cmd, install_mock diff --git a/tox_ltt/plugin.py b/tox_ltt/plugin.py index 8f3b1fb..c8b6ecd 100644 --- a/tox_ltt/plugin.py +++ b/tox_ltt/plugin.py @@ -12,7 +12,7 @@ from light_the_torch.computation_backend import CPUBackend -def extract_force_cpu_help() -> str: +def extract_ltt_option_help(subcommand: str, option: str) -> str: def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any: reduced_seq = [item for item in seq if getattr(item, attr) == eq_cond] assert len(reduced_seq) == 1 @@ -22,9 +22,8 @@ def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any: argument_group = extract(ltt_parser._action_groups, "title", "subcommands") sub_parsers = extract(argument_group._actions, "dest", "subcommand") - install_parser = sub_parsers.choices["install"] - force_cpu = extract(install_parser._actions, "dest", "force_cpu") - return cast(str, force_cpu.help) + subcommand_parser = sub_parsers.choices[subcommand] + return cast(str, extract(subcommand_parser._actions, "dest", option).help) @tox.hookimpl @@ -35,10 +34,16 @@ def tox_addoption(parser: Parser) -> None: help="disable installing PyTorch distributions with light-the-torch", default=False, ) + parser.add_testenv_attribute( + name="pytorch_channel", + type="string", + help=extract_ltt_option_help("install", "channel"), + default="stable", + ) parser.add_testenv_attribute( name="pytorch_force_cpu", type="bool", - help=extract_force_cpu_help(), + help=extract_ltt_option_help("install", "force_cpu"), default=False, ) parser.add_testenv_attribute( @@ -98,6 +103,7 @@ def tox_testenv_install_deps(venv: VirtualEnv, action: Action) -> None: links = ltt.find_links( dists, computation_backend=get_computation_backend(envconfig), + channel=envconfig.pytorch_channel, python_version=get_python_version(envconfig), )