From 6910ee1bfa952386820333da3e501f74fa649c0e Mon Sep 17 00:00:00 2001 From: Anton Zhilin Date: Thu, 20 Feb 2025 14:11:22 +0300 Subject: [PATCH] Allow parametrize to depend on params and marks from previous parametrize --- src/_pytest/fixtures.py | 4 +- src/_pytest/mark/__init__.py | 2 + src/_pytest/mark/structures.py | 16 ++-- src/_pytest/python.py | 120 ++++++++++++++++--------- src/pytest/__init__.py | 2 + testing/python/metafunc.py | 154 +++++++++++++++++++++++++++++++++ 6 files changed, 247 insertions(+), 51 deletions(-) diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index dcd06c3b40a..db0a18d2683 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -77,7 +77,7 @@ if TYPE_CHECKING: - from _pytest.python import CallSpec2 + from _pytest.python import CallSpec from _pytest.python import Function from _pytest.python import Metafunc @@ -184,7 +184,7 @@ def get_parametrized_fixture_argkeys( assert scope is not Scope.Function try: - callspec: CallSpec2 = item.callspec # type: ignore[attr-defined] + callspec: CallSpec = item.callspec # type: ignore[attr-defined] except AttributeError: return diff --git a/src/_pytest/mark/__init__.py b/src/_pytest/mark/__init__.py index efb966c09aa..453f238c1f8 100644 --- a/src/_pytest/mark/__init__.py +++ b/src/_pytest/mark/__init__.py @@ -19,6 +19,7 @@ from .structures import MarkDecorator from .structures import MarkGenerator from .structures import ParameterSet +from .structures import RawParameterSet from _pytest.config import Config from _pytest.config import ExitCode from _pytest.config import hookimpl @@ -38,6 +39,7 @@ "MarkDecorator", "MarkGenerator", "ParameterSet", + "RawParameterSet", "get_empty_parameterset_mark", ] diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index 1a0b3c5b5b8..b1051a0e547 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -32,6 +32,8 @@ if TYPE_CHECKING: + from typing_extensions import TypeAlias + from ..nodes import Node @@ -65,6 +67,9 @@ def get_empty_parameterset_mark( return mark +RawParameterSet: TypeAlias = "ParameterSet | Sequence[object] | object" + + class ParameterSet(NamedTuple): values: Sequence[object | NotSetType] marks: Collection[MarkDecorator | Mark] @@ -95,7 +100,7 @@ def param( @classmethod def extract_from( cls, - parameterset: ParameterSet | Sequence[object] | object, + parameterset: RawParameterSet, force_tuple: bool = False, ) -> ParameterSet: """Extract from an object or objects. @@ -123,7 +128,6 @@ def extract_from( @staticmethod def _parse_parametrize_args( argnames: str | Sequence[str], - argvalues: Iterable[ParameterSet | Sequence[object] | object], *args, **kwargs, ) -> tuple[Sequence[str], bool]: @@ -136,7 +140,7 @@ def _parse_parametrize_args( @staticmethod def _parse_parametrize_parameters( - argvalues: Iterable[ParameterSet | Sequence[object] | object], + argvalues: Iterable[RawParameterSet], force_tuple: bool, ) -> list[ParameterSet]: return [ @@ -147,12 +151,12 @@ def _parse_parametrize_parameters( def _for_parametrize( cls, argnames: str | Sequence[str], - argvalues: Iterable[ParameterSet | Sequence[object] | object], + argvalues: Iterable[RawParameterSet], func, config: Config, nodeid: str, ) -> tuple[Sequence[str], list[ParameterSet]]: - argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues) + argnames, force_tuple = cls._parse_parametrize_args(argnames) parameters = cls._parse_parametrize_parameters(argvalues, force_tuple) del argvalues @@ -467,7 +471,7 @@ class _ParametrizeMarkDecorator(MarkDecorator): def __call__( # type: ignore[override] self, argnames: str | Sequence[str], - argvalues: Iterable[ParameterSet | Sequence[object] | object], + argvalues: Iterable[RawParameterSet], *, indirect: bool | Sequence[str] = ..., ids: Iterable[None | str | float | int | bool] diff --git a/src/_pytest/python.py b/src/_pytest/python.py index ef8a5f02b53..9b6a092ac9d 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -22,6 +22,7 @@ from pathlib import Path import re import types +import typing from typing import Any from typing import final from typing import Literal @@ -56,6 +57,7 @@ from _pytest.fixtures import get_scope_node from _pytest.main import Session from _pytest.mark import ParameterSet +from _pytest.mark import RawParameterSet from _pytest.mark.structures import get_unpacked_marks from _pytest.mark.structures import Mark from _pytest.mark.structures import MarkDecorator @@ -105,6 +107,7 @@ def pytest_addoption(parser: Parser) -> None: ) +@hookimpl(tryfirst=True) def pytest_generate_tests(metafunc: Metafunc) -> None: for marker in metafunc.definition.iter_markers(name="parametrize"): metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) @@ -1022,27 +1025,34 @@ def _idval_from_argname(argname: str, idx: int) -> str: @final @dataclasses.dataclass(frozen=True) -class CallSpec2: +class CallSpec: """A planned parameterized invocation of a test function. - Calculated during collection for a given test function's Metafunc. - Once collection is over, each callspec is turned into a single Item - and stored in item.callspec. + Calculated during collection for a given test function's ``Metafunc``. + Once collection is over, each callspec is turned into a single ``Item`` + and stored in ``item.callspec``. """ - # arg name -> arg value which will be passed to a fixture or pseudo-fixture - # of the same name. (indirect or direct parametrization respectively) - params: dict[str, object] = dataclasses.field(default_factory=dict) - # arg name -> arg index. - indices: dict[str, int] = dataclasses.field(default_factory=dict) + #: arg name -> arg value which will be passed to a fixture or pseudo-fixture + #: of the same name. (indirect or direct parametrization respectively) + params: Mapping[str, object] = dataclasses.field(default_factory=dict) + #: arg name -> arg index. + indices: Mapping[str, int] = dataclasses.field(default_factory=dict) + #: Marks which will be applied to the item. + marks: Sequence[Mark] = dataclasses.field(default_factory=list) + # Used for sorting parametrized resources. _arg2scope: Mapping[str, Scope] = dataclasses.field(default_factory=dict) # Parts which will be added to the item's name in `[..]` separated by "-". _idlist: Sequence[str] = dataclasses.field(default_factory=tuple) - # Marks which will be applied to the item. - marks: list[Mark] = dataclasses.field(default_factory=list) + # Make __init__ internal. + _ispytest: dataclasses.InitVar[bool] = False + + def __post_init__(self, _ispytest: bool): + """:meta private:""" + check_ispytest(_ispytest) - def setmulti( + def _setmulti( self, *, argnames: Iterable[str], @@ -1051,9 +1061,9 @@ def setmulti( marks: Iterable[Mark | MarkDecorator], scope: Scope, param_index: int, - ) -> CallSpec2: - params = self.params.copy() - indices = self.indices.copy() + ) -> CallSpec: + params = dict(self.params) + indices = dict(self.indices) arg2scope = dict(self._arg2scope) for arg, val in zip(argnames, valset): if arg in params: @@ -1061,15 +1071,17 @@ def setmulti( params[arg] = val indices[arg] = param_index arg2scope[arg] = scope - return CallSpec2( + return CallSpec( params=params, indices=indices, + marks=[*self.marks, *normalize_mark_list(marks)], _arg2scope=arg2scope, _idlist=[*self._idlist, id], - marks=[*self.marks, *normalize_mark_list(marks)], + _ispytest=True, ) def getparam(self, name: str) -> object: + """:meta private:""" try: return self.params[name] except KeyError as e: @@ -1077,6 +1089,7 @@ def getparam(self, name: str) -> object: @property def id(self) -> str: + """The combined display name of ``params``.""" return "-".join(self._idlist) @@ -1130,14 +1143,15 @@ def __init__( self._arg2fixturedefs = fixtureinfo.name2fixturedefs # Result of parametrize(). - self._calls: list[CallSpec2] = [] + self._calls: list[CallSpec] = [] self._params_directness: dict[str, Literal["indirect", "direct"]] = {} def parametrize( self, argnames: str | Sequence[str], - argvalues: Iterable[ParameterSet | Sequence[object] | object], + argvalues: Iterable[RawParameterSet] + | Callable[[CallSpec], Iterable[RawParameterSet]], indirect: bool | Sequence[str] = False, ids: Iterable[object | None] | Callable[[Any], object | None] | None = None, scope: _ScopeName | None = None, @@ -1171,7 +1185,7 @@ def parametrize( If N argnames were specified, argvalues must be a list of N-tuples, where each tuple-element specifies a value for its respective argname. - :type argvalues: Iterable[_pytest.mark.structures.ParameterSet | Sequence[object] | object] + :type argvalues: Iterable[_pytest.mark.structures.ParameterSet | Sequence[object] | object] | Callable :param indirect: A list of arguments' names (subset of argnames) or a boolean. If True the list contains all names from the argnames. Each @@ -1206,13 +1220,19 @@ def parametrize( It will also override any fixture-function defined scope, allowing to set a dynamic scope using test context or configuration. """ - argnames, parametersets = ParameterSet._for_parametrize( - argnames, - argvalues, - self.function, - self.config, - nodeid=self.definition.nodeid, - ) + if callable(argvalues): + raw_argnames = argnames + param_factory = argvalues + argnames, _ = ParameterSet._parse_parametrize_args(raw_argnames) + else: + param_factory = None + argnames, parametersets = ParameterSet._for_parametrize( + argnames, + argvalues, + self.function, + self.config, + nodeid=self.definition.nodeid, + ) del argvalues if "request" in argnames: @@ -1230,19 +1250,22 @@ def parametrize( self._validate_if_using_arg_names(argnames, indirect) - # Use any already (possibly) generated ids with parametrize Marks. - if _param_mark and _param_mark._param_ids_from: - generated_ids = _param_mark._param_ids_from._param_ids_generated - if generated_ids is not None: - ids = generated_ids + if param_factory is None: + # Use any already (possibly) generated ids with parametrize Marks. + if _param_mark and _param_mark._param_ids_from: + generated_ids = _param_mark._param_ids_from._param_ids_generated + if generated_ids is not None: + ids = generated_ids - ids = self._resolve_parameter_set_ids( - argnames, ids, parametersets, nodeid=self.definition.nodeid - ) + ids_ = self._resolve_parameter_set_ids( + argnames, ids, parametersets, nodeid=self.definition.nodeid + ) - # Store used (possibly generated) ids with parametrize Marks. - if _param_mark and _param_mark._param_ids_from and generated_ids is None: - object.__setattr__(_param_mark._param_ids_from, "_param_ids_generated", ids) + # Store used (possibly generated) ids with parametrize Marks. + if _param_mark and _param_mark._param_ids_from and generated_ids is None: + object.__setattr__( + _param_mark._param_ids_from, "_param_ids_generated", ids_ + ) # Add funcargs as fixturedefs to fixtureinfo.arg2fixturedefs by registering # artificial "pseudo" FixtureDef's so that later at test execution time we can @@ -1301,11 +1324,22 @@ def parametrize( # more than once) then we accumulate those calls generating the cartesian product # of all calls. newcalls = [] - for callspec in self._calls or [CallSpec2()]: + for callspec in self._calls or [CallSpec(_ispytest=True)]: + if param_factory: + _, parametersets = ParameterSet._for_parametrize( + raw_argnames, + param_factory(callspec), + self.function, + self.config, + nodeid=self.definition.nodeid, + ) + ids_ = self._resolve_parameter_set_ids( + argnames, ids, parametersets, nodeid=self.definition.nodeid + ) for param_index, (param_id, param_set) in enumerate( - zip(ids, parametersets) + zip(ids_, parametersets) ): - newcallspec = callspec.setmulti( + newcallspec = callspec._setmulti( argnames=argnames, valset=param_set.values, id=param_id, @@ -1453,7 +1487,7 @@ def _recompute_direct_params_indices(self) -> None: for argname, param_type in self._params_directness.items(): if param_type == "direct": for i, callspec in enumerate(self._calls): - callspec.indices[argname] = i + typing.cast(dict[str, int], callspec.indices)[argname] = i def _find_parametrized_scope( @@ -1538,7 +1572,7 @@ def __init__( name: str, parent, config: Config | None = None, - callspec: CallSpec2 | None = None, + callspec: CallSpec | None = None, callobj=NOTSET, keywords: Mapping[str, Any] | None = None, session: Session | None = None, diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index 70096d6593e..1139aa092a5 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -53,6 +53,7 @@ from _pytest.pytester import Pytester from _pytest.pytester import RecordedHookCall from _pytest.pytester import RunResult +from _pytest.python import CallSpec from _pytest.python import Class from _pytest.python import Function from _pytest.python import Metafunc @@ -91,6 +92,7 @@ __all__ = [ "Cache", "CallInfo", + "CallSpec", "CaptureFixture", "Class", "CollectReport", diff --git a/testing/python/metafunc.py b/testing/python/metafunc.py index 4e7e441768c..bb863b045a1 100644 --- a/testing/python/metafunc.py +++ b/testing/python/metafunc.py @@ -2143,3 +2143,157 @@ def test_converted_to_str(a, b): "*= 6 passed in *", ] ) + + +class TestCovariant: + """Tests related to parametrize with callable argvalues.""" + + def test_basic(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + def bar_values(callspec: pytest.CallSpec): + return [ + callspec.params["foo"] * 3, + callspec.params["foo"] * 4, + ] + + @pytest.mark.parametrize("bar", bar_values) + @pytest.mark.parametrize("foo", ["a", "b"]) + def test_function(foo, bar): + pass + """ + ) + result = pytester.runpytest("-vv", "-s") + result.stdout.fnmatch_lines( + [ + "test_basic.py::test_function[a-aaa] PASSED", + "test_basic.py::test_function[a-aaaa] PASSED", + "test_basic.py::test_function[b-bbb] PASSED", + "test_basic.py::test_function[b-bbbb] PASSED", + "*= 4 passed in *", + ] + ) + + def test_depend_on_marks(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + def pytest_generate_tests(metafunc: pytest.Metafunc): + if "bar" in metafunc.fixturenames: + base_bar_marks = list(metafunc.definition.iter_markers("bar_params")) + + def gen_params(callspec: pytest.CallSpec): + bar_marks = base_bar_marks + [ + mark + for mark in callspec.marks + if mark.name == "bar_params" + ] + return [arg for mark in bar_marks for arg in mark.args] + + metafunc.parametrize("bar", gen_params) + + @pytest.mark.bar_params("x") + @pytest.mark.parametrize( + "foo", + [ + "a", + pytest.param("b", marks=[pytest.mark.bar_params("y", "z")]), + pytest.param("c", marks=[pytest.mark.bar_params("w")]), + ], + ) + def test_function(foo, bar): + pass + """ + ) + result = pytester.runpytest("-vv", "-s") + result.stdout.fnmatch_lines( + [ + "test_depend_on_marks.py::test_function[a-x] PASSED", + "test_depend_on_marks.py::test_function[b-x] PASSED", + "test_depend_on_marks.py::test_function[b-y] PASSED", + "test_depend_on_marks.py::test_function[b-z] PASSED", + "test_depend_on_marks.py::test_function[c-x] PASSED", + "test_depend_on_marks.py::test_function[c-w] PASSED", + "*= 6 passed in *", + ] + ) + + def test_id_and_marks(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + def gen_params(callspec: pytest.CallSpec): + return [ + pytest.param("a", id="aparam", marks=[pytest.mark.foo_value("a")]), + pytest.param("b", id="bparam", marks=[pytest.mark.foo_value("b")]), + ] + + @pytest.mark.parametrize("foo", gen_params) + def test_function(request, foo): + assert request.node.get_closest_marker("foo_value").args[0] == foo + """ + ) + result = pytester.runpytest("-vv", "-s") + result.stdout.fnmatch_lines( + [ + "test_id_and_marks.py::test_function[aparam] PASSED", + "test_id_and_marks.py::test_function[bparam] PASSED", + "*= 2 passed in *", + ] + ) + + def test_invalid_arg_name(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + def gen_params(callspec: pytest.CallSpec): + assert False, "This function does not need to be called to detect the mistake" + + @pytest.mark.parametrize("foo", gen_params) + def test_function(): + pass + """ + ) + result = pytester.runpytest("--collect-only") + result.stdout.fnmatch_lines( + [ + "collected 0 items / 1 error", + "", + "*= ERRORS =*", + "*_ ERROR collecting test_invalid_arg_name.py _*", + "*In test_function: function uses no argument 'foo'", + "*! Interrupted: 1 error during collection !*", + "*= no tests collected, 1 error in *", + ] + ) + + def test_no_parameter_sets(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + import pytest + + def gen_params(callspec: pytest.CallSpec): + return range(1, callspec.params["foo"] + 1) + + @pytest.mark.parametrize("bar", gen_params) + @pytest.mark.parametrize("foo", [3, 1, 0]) + def test_function(foo, bar): + pass + """ + ) + result = pytester.runpytest("-vv", "-s") + result.stdout.fnmatch_lines( + [ + "test_no_parameter_sets.py::test_function[[]3-1] PASSED", + "test_no_parameter_sets.py::test_function[[]3-2] PASSED", + "test_no_parameter_sets.py::test_function[[]3-3] PASSED", + "test_no_parameter_sets.py::test_function[[]1-1] PASSED", + "test_no_parameter_sets.py::test_function[[]0-NOTSET] SKIPPED *", + "*= 4 passed, 1 skipped in *", + ] + )