Skip to content

Commit

Permalink
Allow parametrize to depend on params and marks from previous paramet…
Browse files Browse the repository at this point in the history
…rize
  • Loading branch information
Anton3 committed Feb 20, 2025
1 parent d126389 commit 6910ee1
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 51 deletions.
4 changes: 2 additions & 2 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@


if TYPE_CHECKING:
from _pytest.python import CallSpec2
from _pytest.python import CallSpec
from _pytest.python import Function
from _pytest.python import Metafunc

Expand Down Expand Up @@ -184,7 +184,7 @@ def get_parametrized_fixture_argkeys(
assert scope is not Scope.Function

try:
callspec: CallSpec2 = item.callspec # type: ignore[attr-defined]
callspec: CallSpec = item.callspec # type: ignore[attr-defined]
except AttributeError:
return

Expand Down
2 changes: 2 additions & 0 deletions src/_pytest/mark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .structures import MarkDecorator
from .structures import MarkGenerator
from .structures import ParameterSet
from .structures import RawParameterSet
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
Expand All @@ -38,6 +39,7 @@
"MarkDecorator",
"MarkGenerator",
"ParameterSet",
"RawParameterSet",
"get_empty_parameterset_mark",
]

Expand Down
16 changes: 10 additions & 6 deletions src/_pytest/mark/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@


if TYPE_CHECKING:
from typing_extensions import TypeAlias

from ..nodes import Node


Expand Down Expand Up @@ -65,6 +67,9 @@ def get_empty_parameterset_mark(
return mark


RawParameterSet: TypeAlias = "ParameterSet | Sequence[object] | object"


class ParameterSet(NamedTuple):
values: Sequence[object | NotSetType]
marks: Collection[MarkDecorator | Mark]
Expand Down Expand Up @@ -95,7 +100,7 @@ def param(
@classmethod
def extract_from(
cls,
parameterset: ParameterSet | Sequence[object] | object,
parameterset: RawParameterSet,
force_tuple: bool = False,
) -> ParameterSet:
"""Extract from an object or objects.
Expand Down Expand Up @@ -123,7 +128,6 @@ def extract_from(
@staticmethod
def _parse_parametrize_args(
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
*args,
**kwargs,
) -> tuple[Sequence[str], bool]:
Expand All @@ -136,7 +140,7 @@ def _parse_parametrize_args(

@staticmethod
def _parse_parametrize_parameters(
argvalues: Iterable[ParameterSet | Sequence[object] | object],
argvalues: Iterable[RawParameterSet],
force_tuple: bool,
) -> list[ParameterSet]:
return [
Expand All @@ -147,12 +151,12 @@ def _parse_parametrize_parameters(
def _for_parametrize(
cls,
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
argvalues: Iterable[RawParameterSet],
func,
config: Config,
nodeid: str,
) -> tuple[Sequence[str], list[ParameterSet]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
argnames, force_tuple = cls._parse_parametrize_args(argnames)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues

Expand Down Expand Up @@ -467,7 +471,7 @@ class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
argvalues: Iterable[RawParameterSet],
*,
indirect: bool | Sequence[str] = ...,
ids: Iterable[None | str | float | int | bool]
Expand Down
120 changes: 77 additions & 43 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pathlib import Path
import re
import types
import typing
from typing import Any
from typing import final
from typing import Literal
Expand Down Expand Up @@ -56,6 +57,7 @@
from _pytest.fixtures import get_scope_node
from _pytest.main import Session
from _pytest.mark import ParameterSet
from _pytest.mark import RawParameterSet
from _pytest.mark.structures import get_unpacked_marks
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
Expand Down Expand Up @@ -105,6 +107,7 @@ def pytest_addoption(parser: Parser) -> None:
)


@hookimpl(tryfirst=True)
def pytest_generate_tests(metafunc: Metafunc) -> None:
for marker in metafunc.definition.iter_markers(name="parametrize"):
metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker)
Expand Down Expand Up @@ -1022,27 +1025,34 @@ def _idval_from_argname(argname: str, idx: int) -> str:

@final
@dataclasses.dataclass(frozen=True)
class CallSpec2:
class CallSpec:
"""A planned parameterized invocation of a test function.
Calculated during collection for a given test function's Metafunc.
Once collection is over, each callspec is turned into a single Item
and stored in item.callspec.
Calculated during collection for a given test function's ``Metafunc``.
Once collection is over, each callspec is turned into a single ``Item``
and stored in ``item.callspec``.
"""

# arg name -> arg value which will be passed to a fixture or pseudo-fixture
# of the same name. (indirect or direct parametrization respectively)
params: dict[str, object] = dataclasses.field(default_factory=dict)
# arg name -> arg index.
indices: dict[str, int] = dataclasses.field(default_factory=dict)
#: arg name -> arg value which will be passed to a fixture or pseudo-fixture
#: of the same name. (indirect or direct parametrization respectively)
params: Mapping[str, object] = dataclasses.field(default_factory=dict)
#: arg name -> arg index.
indices: Mapping[str, int] = dataclasses.field(default_factory=dict)
#: Marks which will be applied to the item.
marks: Sequence[Mark] = dataclasses.field(default_factory=list)

# Used for sorting parametrized resources.
_arg2scope: Mapping[str, Scope] = dataclasses.field(default_factory=dict)
# Parts which will be added to the item's name in `[..]` separated by "-".
_idlist: Sequence[str] = dataclasses.field(default_factory=tuple)
# Marks which will be applied to the item.
marks: list[Mark] = dataclasses.field(default_factory=list)
# Make __init__ internal.
_ispytest: dataclasses.InitVar[bool] = False

def __post_init__(self, _ispytest: bool):
""":meta private:"""
check_ispytest(_ispytest)

def setmulti(
def _setmulti(
self,
*,
argnames: Iterable[str],
Expand All @@ -1051,32 +1061,35 @@ def setmulti(
marks: Iterable[Mark | MarkDecorator],
scope: Scope,
param_index: int,
) -> CallSpec2:
params = self.params.copy()
indices = self.indices.copy()
) -> CallSpec:
params = dict(self.params)
indices = dict(self.indices)
arg2scope = dict(self._arg2scope)
for arg, val in zip(argnames, valset):
if arg in params:
raise ValueError(f"duplicate parametrization of {arg!r}")
params[arg] = val
indices[arg] = param_index
arg2scope[arg] = scope
return CallSpec2(
return CallSpec(
params=params,
indices=indices,
marks=[*self.marks, *normalize_mark_list(marks)],
_arg2scope=arg2scope,
_idlist=[*self._idlist, id],
marks=[*self.marks, *normalize_mark_list(marks)],
_ispytest=True,
)

def getparam(self, name: str) -> object:
""":meta private:"""
try:
return self.params[name]
except KeyError as e:
raise ValueError(name) from e

@property
def id(self) -> str:
"""The combined display name of ``params``."""
return "-".join(self._idlist)


Expand Down Expand Up @@ -1130,14 +1143,15 @@ def __init__(
self._arg2fixturedefs = fixtureinfo.name2fixturedefs

# Result of parametrize().
self._calls: list[CallSpec2] = []
self._calls: list[CallSpec] = []

self._params_directness: dict[str, Literal["indirect", "direct"]] = {}

def parametrize(
self,
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
argvalues: Iterable[RawParameterSet]
| Callable[[CallSpec], Iterable[RawParameterSet]],
indirect: bool | Sequence[str] = False,
ids: Iterable[object | None] | Callable[[Any], object | None] | None = None,
scope: _ScopeName | None = None,
Expand Down Expand Up @@ -1171,7 +1185,7 @@ def parametrize(
If N argnames were specified, argvalues must be a list of
N-tuples, where each tuple-element specifies a value for its
respective argname.
:type argvalues: Iterable[_pytest.mark.structures.ParameterSet | Sequence[object] | object]
:type argvalues: Iterable[_pytest.mark.structures.ParameterSet | Sequence[object] | object] | Callable
:param indirect:
A list of arguments' names (subset of argnames) or a boolean.
If True the list contains all names from the argnames. Each
Expand Down Expand Up @@ -1206,13 +1220,19 @@ def parametrize(
It will also override any fixture-function defined scope, allowing
to set a dynamic scope using test context or configuration.
"""
argnames, parametersets = ParameterSet._for_parametrize(
argnames,
argvalues,
self.function,
self.config,
nodeid=self.definition.nodeid,
)
if callable(argvalues):
raw_argnames = argnames
param_factory = argvalues
argnames, _ = ParameterSet._parse_parametrize_args(raw_argnames)
else:
param_factory = None
argnames, parametersets = ParameterSet._for_parametrize(
argnames,
argvalues,
self.function,
self.config,
nodeid=self.definition.nodeid,
)
del argvalues

if "request" in argnames:
Expand All @@ -1230,19 +1250,22 @@ def parametrize(

self._validate_if_using_arg_names(argnames, indirect)

# Use any already (possibly) generated ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from:
generated_ids = _param_mark._param_ids_from._param_ids_generated
if generated_ids is not None:
ids = generated_ids
if param_factory is None:
# Use any already (possibly) generated ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from:
generated_ids = _param_mark._param_ids_from._param_ids_generated
if generated_ids is not None:
ids = generated_ids

ids = self._resolve_parameter_set_ids(
argnames, ids, parametersets, nodeid=self.definition.nodeid
)
ids_ = self._resolve_parameter_set_ids(
argnames, ids, parametersets, nodeid=self.definition.nodeid
)

# Store used (possibly generated) ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from and generated_ids is None:
object.__setattr__(_param_mark._param_ids_from, "_param_ids_generated", ids)
# Store used (possibly generated) ids with parametrize Marks.
if _param_mark and _param_mark._param_ids_from and generated_ids is None:
object.__setattr__(
_param_mark._param_ids_from, "_param_ids_generated", ids_
)

# Add funcargs as fixturedefs to fixtureinfo.arg2fixturedefs by registering
# artificial "pseudo" FixtureDef's so that later at test execution time we can
Expand Down Expand Up @@ -1301,11 +1324,22 @@ def parametrize(
# more than once) then we accumulate those calls generating the cartesian product
# of all calls.
newcalls = []
for callspec in self._calls or [CallSpec2()]:
for callspec in self._calls or [CallSpec(_ispytest=True)]:
if param_factory:
_, parametersets = ParameterSet._for_parametrize(
raw_argnames,
param_factory(callspec),
self.function,
self.config,
nodeid=self.definition.nodeid,
)
ids_ = self._resolve_parameter_set_ids(
argnames, ids, parametersets, nodeid=self.definition.nodeid
)
for param_index, (param_id, param_set) in enumerate(
zip(ids, parametersets)
zip(ids_, parametersets)
):
newcallspec = callspec.setmulti(
newcallspec = callspec._setmulti(
argnames=argnames,
valset=param_set.values,
id=param_id,
Expand Down Expand Up @@ -1453,7 +1487,7 @@ def _recompute_direct_params_indices(self) -> None:
for argname, param_type in self._params_directness.items():
if param_type == "direct":
for i, callspec in enumerate(self._calls):
callspec.indices[argname] = i
typing.cast(dict[str, int], callspec.indices)[argname] = i


def _find_parametrized_scope(
Expand Down Expand Up @@ -1538,7 +1572,7 @@ def __init__(
name: str,
parent,
config: Config | None = None,
callspec: CallSpec2 | None = None,
callspec: CallSpec | None = None,
callobj=NOTSET,
keywords: Mapping[str, Any] | None = None,
session: Session | None = None,
Expand Down
2 changes: 2 additions & 0 deletions src/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from _pytest.pytester import Pytester
from _pytest.pytester import RecordedHookCall
from _pytest.pytester import RunResult
from _pytest.python import CallSpec
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Metafunc
Expand Down Expand Up @@ -91,6 +92,7 @@
__all__ = [
"Cache",
"CallInfo",
"CallSpec",
"CaptureFixture",
"Class",
"CollectReport",
Expand Down
Loading

0 comments on commit 6910ee1

Please sign in to comment.