diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 46bea5c..380a22a 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -65,6 +65,7 @@ def get_setup_cfg(name, version, install_requires=None, extra_requires=None): def get_tox_ini( basepython=None, disable_light_the_torch=None, + pytorch_force_cpu=None, force_cpu=None, deps=None, skip_install=False, @@ -89,6 +90,8 @@ def get_tox_ini( lines.append("extras = extra") if disable_light_the_torch is not None: lines.append(f"disable_light_the_torch = {disable_light_the_torch}") + if pytorch_force_cpu is not None: + lines.append(f"pytorch_force_cpu = {pytorch_force_cpu}") if force_cpu is not None: lines.append(f"force_cpu = {force_cpu}") if deps is not None: @@ -107,6 +110,7 @@ def tox_ltt_initproj_( install_requires=None, extra_requires=None, disable_light_the_torch=None, + pytorch_force_cpu=None, force_cpu=None, deps=None, skip_install=False, @@ -126,6 +130,7 @@ def tox_ltt_initproj_( usedevelop=usedevelop, extra=extra_requires is not None, disable_light_the_torch=disable_light_the_torch, + pytorch_force_cpu=pytorch_force_cpu, force_cpu=force_cpu, deps=deps, pep517=pep517, @@ -146,7 +151,7 @@ def test_help_ini(cmd): result = cmd("--help-ini") result.assert_success(is_run_test_env=False) assert "disable_light_the_torch" in result.out - assert "force_cpu" in result.out + assert "pytorch_force_cpu" in result.out @pytest.mark.slow @@ -161,7 +166,24 @@ def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd): @pytest.mark.slow -def test_tox_ltt_force_cpu(patch_find_links, tox_ltt_initproj, cmd, install_mock): +def test_tox_ltt_pytorch_force_cpu( + patch_find_links, tox_ltt_initproj, cmd, install_mock +): + mock = patch_find_links() + tox_ltt_initproj(deps=("torch",), pytorch_force_cpu=True) + + result = cmd() + + result.assert_success(is_run_test_env=False) + + _, kwargs = mock.call_args + assert kwargs["computation_backend"] == CPUBackend() + + +@pytest.mark.slow +def test_tox_ltt_force_cpu_legacy( + patch_find_links, tox_ltt_initproj, cmd, install_mock +): mock = patch_find_links() tox_ltt_initproj(deps=("torch",), force_cpu=True) diff --git a/tox_ltt/plugin.py b/tox_ltt/plugin.py index 853b24b..8f3b1fb 100644 --- a/tox_ltt/plugin.py +++ b/tox_ltt/plugin.py @@ -35,9 +35,17 @@ def tox_addoption(parser: Parser) -> None: help="disable installing PyTorch distributions with light-the-torch", default=False, ) - parser.add_testenv_attribute( - name="force_cpu", type="bool", help=extract_force_cpu_help(), default=False, + name="pytorch_force_cpu", + type="bool", + help=extract_force_cpu_help(), + default=False, + ) + parser.add_testenv_attribute( + name="force_cpu", + type="bool", + help="Deprecated alias of 'pytorch_force_cpu'.", + default=False, ) @@ -102,14 +110,23 @@ def remove_extras(dists: List[str]) -> List[str]: return [dist.split(";")[0] for dist in dists] +def _resolve_force_cpu(new: bool, legacy: bool) -> bool: + if legacy: + reporter.warning("The option 'force_cpu' was renamed to 'pytorch_force_cpu'.") + return True + + return new + + def get_computation_backend(envconfig: TestenvConfig) -> Optional[CPUBackend]: - if not envconfig.force_cpu: + force_cpu = _resolve_force_cpu(envconfig.pytorch_force_cpu, envconfig.force_cpu) + if not force_cpu: return None reporter.verbosity1( ( "Using CPU as computation backend instead of auto-detecting since " - "'force_cpu = True' is configured." + "'pytorch_force_cpu = True' is configured." ), ) return CPUBackend()