diff --git a/injection/main.py b/injection/main.py index 028c8e2..08b9b6e 100644 --- a/injection/main.py +++ b/injection/main.py @@ -2,8 +2,8 @@ from contextlib import suppress from dataclasses import dataclass -from threading import RLock, get_ident -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload +from threading import Lock, RLock, get_ident +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload from injection.compat import get_frame @@ -21,6 +21,8 @@ "Injection", "ObjectState", "inject", + "lenient_recursion_guard", + "strict_recursion_guard", ) @@ -62,7 +64,11 @@ def __hash__(self) -> int: return self.hash -def default_recursion_guard(early: EarlyObject[object]) -> Never: +def lenient_recursion_guard(early: EarlyObject[object]) -> Never: + pass + + +def strict_recursion_guard(early: EarlyObject[object]) -> Never: msg = f"{early} requested itself" raise RecursionError(msg) @@ -73,9 +79,11 @@ class Injection(Generic[Object_co]): pass_scope: bool = False cache: bool = False cache_per_alias: bool = False - recursion_guard: Callable[[EarlyObject[Any]], object] = default_recursion_guard + recursion_guard: Callable[[EarlyObject[Any]], object] = lenient_recursion_guard debug_info: str | None = None + _reassignment_lock: ClassVar[Lock] = Lock() + def _call_factory(self, scope: Locals) -> Object_co: if self.pass_scope: return self.factory(scope) @@ -118,7 +126,10 @@ def assign_to(self, *aliases: str, scope: Locals) -> None: debug_info=debug_info, ) key = InjectionKey(alias, early) - scope[key] = early + + with self._reassignment_lock: + scope.pop(key, None) + scope[key] = early SENTINEL = object() @@ -247,7 +258,7 @@ def inject( # noqa: PLR0913 pass_scope: bool = False, cache: bool = False, cache_per_alias: bool = False, - recursion_guard: Callable[[EarlyObject[Any]], object] = default_recursion_guard, + recursion_guard: Callable[[EarlyObject[Any]], object] = strict_recursion_guard, debug_info: str | None = None, ) -> None: """ diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index e164e3b..e63d017 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -5,7 +5,7 @@ import pytest -from injection import Injection, inject +from injection import Injection, inject, lenient_recursion_guard def test_injection_basic() -> None: @@ -211,11 +211,18 @@ def test_injection_recursive_guard() -> None: def factory() -> str: return scope.get("my_alias", "default_value") - inject("my_alias", into=scope, factory=factory) + inject( + "my_alias", into=scope, factory=factory, recursion_guard=lenient_recursion_guard + ) obj = scope["my_alias"] assert obj == "default_value" + inject("my_alias", into=scope, factory=factory) # strict + + with pytest.raises(RecursionError, match="requested itself"): + obj = scope["my_alias"] + def test_injection_with_no_aliases() -> None: """Test that injection with no aliases raises an error."""