diff --git a/benchmarks/benchmark_state_manager_redis.py b/benchmarks/benchmark_state_manager_redis.py new file mode 100644 index 00000000000..7690b7001a7 --- /dev/null +++ b/benchmarks/benchmark_state_manager_redis.py @@ -0,0 +1,309 @@ +"""Benchmark for the state manager redis.""" + +import asyncio +from uuid import uuid4 + +import pytest +from pytest_benchmark.fixture import BenchmarkFixture + +from reflex.state import State, StateManagerRedis +from reflex.utils.prerequisites import get_redis +from reflex.vars.base import computed_var + + +class RootState(State): + """Root state class for testing.""" + + counter: int = 0 + int_dict: dict[str, int] = {} + + +class ChildState(RootState): + """Child state class for testing.""" + + child_counter: int = 0 + + @computed_var + def str_dict(self): + """Convert the int dict to a string dict. + + Returns: + A dictionary with string keys and integer values. + """ + return {str(k): v for k, v in self.int_dict.items()} + + +class ChildState2(RootState): + """Child state 2 class for testing.""" + + child2_counter: int = 0 + + +class GrandChildState(ChildState): + """Grandchild state class for testing.""" + + grand_child_counter: int = 0 + float_dict: dict[str, float] = {} + + @computed_var + def double_counter(self): + """Double the counter. + + Returns: + The counter value multiplied by 2. + """ + return self.counter * 2 + + +@pytest.fixture +def state_manager() -> StateManagerRedis: + """Fixture for the redis state manager. + + Returns: + An instance of StateManagerRedis. + """ + redis = get_redis() + if redis is None: + pytest.skip("Redis is not available") + return StateManagerRedis(redis=redis, state=State) + + +@pytest.fixture +def token() -> str: + """Fixture for the token. + + Returns: + A unique token string. + """ + return str(uuid4()) + + +@pytest.fixture +def grand_child_state_token(token: str) -> str: + """Fixture for the grand child state token. + + Args: + token: The token fixture. + + Returns: + A string combining the token and the grandchild state name. + """ + return f"{token}_{GrandChildState.get_full_name()}" + + +@pytest.fixture +def state_token(token: str) -> str: + """Fixture for the base state token. + + Args: + token: The token fixture. + + Returns: + A string combining the token and the base state name. + """ + return f"{token}_{State.get_full_name()}" + + +@pytest.fixture +def grand_child_state() -> GrandChildState: + """Fixture for the grand child state. + + Returns: + An instance of GrandChildState. + """ + state = State() + + root = RootState() + root.parent_state = state + state.substates[root.get_name()] = root + + child = ChildState() + child.parent_state = root + root.substates[child.get_name()] = child + + child2 = ChildState2() + child2.parent_state = root + root.substates[child2.get_name()] = child2 + + gcs = GrandChildState() + gcs.parent_state = child + child.substates[gcs.get_name()] = gcs + + return gcs + + +@pytest.fixture +def grand_child_state_big(grand_child_state: GrandChildState) -> GrandChildState: + """Fixture for the grand child state with big data. + + Args: + grand_child_state: The grand child state fixture. + + Returns: + An instance of GrandChildState with large data. + """ + grand_child_state.counter = 100 + grand_child_state.child_counter = 200 + grand_child_state.grand_child_counter = 300 + grand_child_state.int_dict = {str(i): i for i in range(10000)} + grand_child_state.float_dict = {str(i): i + 0.5 for i in range(10000)} + return grand_child_state + + +def test_set_state( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + token: str, +) -> None: + """Benchmark setting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + token: The token fixture. + """ + state = State() + + def func(): + event_loop.run_until_complete(state_manager.set_state(token=token, state=state)) + + benchmark(func) + + +def test_get_state( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + state_token: str, +) -> None: + """Benchmark getting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + state_token: The base state token fixture. + """ + state = State() + event_loop.run_until_complete( + state_manager.set_state(token=state_token, state=state) + ) + + def func(): + _ = event_loop.run_until_complete(state_manager.get_state(token=state_token)) + + benchmark(func) + + +def test_set_state_tree_minimal( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + grand_child_state_token: str, + grand_child_state: GrandChildState, +) -> None: + """Benchmark setting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + grand_child_state_token: The grand child state token fixture. + grand_child_state: The grand child state fixture. + """ + + def func(): + event_loop.run_until_complete( + state_manager.set_state( + token=grand_child_state_token, state=grand_child_state + ) + ) + + benchmark(func) + + +def test_get_state_tree_minimal( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + grand_child_state_token: str, + grand_child_state: GrandChildState, +) -> None: + """Benchmark getting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + grand_child_state_token: The grand child state token fixture. + grand_child_state: The grand child state fixture. + """ + event_loop.run_until_complete( + state_manager.set_state(token=grand_child_state_token, state=grand_child_state) + ) + + def func(): + _ = event_loop.run_until_complete( + state_manager.get_state(token=grand_child_state_token) + ) + + benchmark(func) + + +def test_set_state_tree_big( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + grand_child_state_token: str, + grand_child_state_big: GrandChildState, +) -> None: + """Benchmark setting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + grand_child_state_token: The grand child state token fixture. + grand_child_state_big: The grand child state fixture. + """ + + def func(): + event_loop.run_until_complete( + state_manager.set_state( + token=grand_child_state_token, state=grand_child_state_big + ) + ) + + benchmark(func) + + +def test_get_state_tree_big( + benchmark: BenchmarkFixture, + state_manager: StateManagerRedis, + event_loop: asyncio.AbstractEventLoop, + grand_child_state_token: str, + grand_child_state_big: GrandChildState, +) -> None: + """Benchmark getting state with minimal data. + + Args: + benchmark: The benchmark fixture. + state_manager: The state manager fixture. + event_loop: The event loop fixture. + grand_child_state_token: The grand child state token fixture. + grand_child_state_big: The grand child state fixture. + """ + event_loop.run_until_complete( + state_manager.set_state( + token=grand_child_state_token, state=grand_child_state_big + ) + ) + + def func(): + _ = event_loop.run_until_complete( + state_manager.get_state(token=grand_child_state_token) + ) + + benchmark(func) diff --git a/poetry.lock b/poetry.lock index cc778d19b3b..b45ef0b910d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1147,6 +1147,7 @@ files = [ {file = "nh3-0.2.19-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:00810cd5275f5c3f44b9eb0e521d1a841ee2f8023622de39ffc7d88bd533d8e0"}, {file = "nh3-0.2.19-cp38-abi3-win32.whl", hash = "sha256:7e98621856b0a911c21faa5eef8f8ea3e691526c2433f9afc2be713cb6fbdb48"}, {file = "nh3-0.2.19-cp38-abi3-win_amd64.whl", hash = "sha256:75c7cafb840f24430b009f7368945cb5ca88b2b54bb384ebfba495f16bc9c121"}, + {file = "nh3-0.2.19.tar.gz", hash = "sha256:790056b54c068ff8dceb443eaefb696b84beff58cca6c07afd754d17692a4804"}, ] [[package]] @@ -3076,4 +3077,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d62cd1897d8f73e9aad9e907beb82be509dc5e33d8f37b36ebf26ad1f3075a9f" +content-hash = "4ef559dcc4b3fd0d88c908cb4df4d7a14e3d021498d3034ad1b9481131abe686" diff --git a/pyproject.toml b/pyproject.toml index 57d49e3f073..1500ab6f850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ psutil = ">=5.9.4,<7.0" pydantic = ">=1.10.2,<3.0" python-multipart = ">=0.0.5,<0.1" python-socketio = ">=5.7.0,<6.0" -redis = ">=4.3.5,<6.0" +redis = ">=5.1.0,<6.0" rich = ">=13.0.0,<14.0" sqlmodel = ">=0.0.14,<0.1" typer = ">=0.4.2,<1.0" diff --git a/reflex/state.py b/reflex/state.py index b181090da17..800b13b79a8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -956,7 +956,19 @@ def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]: for substate in cls.get_substates(): if path[0] == substate.get_name(): return substate.get_class_substate(path[1:]) - raise ValueError(f"Invalid path: {path}") + raise ValueError(f"Invalid path: {cls.get_full_name()=} {path=}") + + @classmethod + def get_all_substate_classes(cls) -> set[Type[BaseState]]: + """Get all substate classes of the state. + + Returns: + The set of all substate classes. + """ + substates = set(cls.get_substates()) + for substate in cls.get_substates(): + substates.update(substate.get_all_substate_classes()) + return substates @classmethod def get_class_var(cls, path: Sequence[str]) -> Any: @@ -1415,7 +1427,9 @@ def get_substate(self, path: Sequence[str]) -> BaseState: return self path = path[1:] if path[0] not in self.substates: - raise ValueError(f"Invalid path: {path}") + raise ValueError( + f"Invalid path: {path=} {self.get_full_name()=} {self.substates.keys()=}" + ) return self.substates[path[0]].get_substate(path[1:]) @classmethod @@ -1477,6 +1491,63 @@ def _get_parent_states(self) -> list[tuple[str, BaseState]]: parent_states_with_name.append((parent_state.get_full_name(), parent_state)) return parent_states_with_name + def _get_loaded_states(self) -> dict[str, BaseState]: + """Get all loaded states in the state tree. + + Returns: + A list of all loaded states in the state tree. + """ + root_state = self._get_root_state() + d = {root_state.get_full_name(): root_state} + root_state._get_loaded_substates(d) + return d + + def _get_loaded_substates( + self, + loaded_substates: dict[str, BaseState], + ) -> None: + """Get all loaded substates of this state. + + Args: + loaded_substates: A dictionary of loaded substates which will be updated with the substates of this state. + """ + for substate in self.substates.values(): + loaded_substates[substate.get_full_name()] = substate + substate._get_loaded_substates(loaded_substates) + + def _serialize_touched_states(self) -> dict[str, bytes]: + """Serialize all touched states in the state tree. + + Returns: + The serialized states. + """ + root_state = self._get_root_state() + d = {} + if root_state._get_was_touched(): + serialized = root_state._serialize() + if serialized: + d[root_state.get_full_name()] = serialized + root_state._serialize_touched_substates(d) + return d + + def _serialize_touched_substates( + self, + touched_substates: dict[str, bytes], + ) -> None: + """Serialize all touched substates of this state. + + Args: + touched_substates: A dictionary of touched substates which will be updated with the substates of this state. + """ + for substate in self.substates.values(): + substate._serialize_touched_substates(touched_substates) + if not substate._get_was_touched(): + continue + serialized = substate._serialize() + if not serialized: + continue + touched_substates[substate.get_full_name()] = serialized + def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -1884,26 +1955,48 @@ def _dirty_computed_vars( } @classmethod - def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: + def _potentially_dirty_substates(cls) -> set[str]: """Determine substates which could be affected by dirty vars in this state. Returns: - Set of State classes that may need to be fetched to recalc computed vars. + Set of State full names that may need to be fetched to recalc computed vars. """ # _always_dirty_substates need to be fetched to recalc computed vars. fetch_substates = { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) + f"{cls.get_full_name()}.{substate_name}" for substate_name in cls._always_dirty_substates } for dependent_substates in cls._substate_var_dependencies.values(): fetch_substates.update( { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) + f"{cls.get_full_name()}.{substate_name}" for substate_name in dependent_substates } ) return fetch_substates + @classmethod + def _recursive_potentially_dirty_substates( + cls, + already_selected: Type[BaseState] | None = None, + ) -> set[str]: + """Recursively determine substates which could be affected by dirty vars in this state. + + Args: + already_selected: The class of the state that has already been selected and needs no further processing. + + Returns: + Set of full state names that may need to be fetched to recalc computed vars. + """ + if already_selected is not None and already_selected == cls: + return set() + fetch_substates = cls._potentially_dirty_substates() + for substate_cls in cls.get_substates(): + fetch_substates.update( + substate_cls._recursive_potentially_dirty_substates(already_selected) + ) + return fetch_substates + def get_delta(self) -> Delta: """Get the delta for the state. @@ -3232,6 +3325,9 @@ class StateManagerRedis(StateManager): default_factory=_default_lock_warning_threshold ) + # If HEXPIRE is not supported, use EXPIRE instead. + _hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None) + # The keyspace subscription string when redis is waiting for lock to be released _redis_notify_keyspace_events: str = ( "K" # Enable keyspace notifications (target a particular key) @@ -3248,77 +3344,6 @@ class StateManagerRedis(StateManager): b"evicted", } - async def _get_parent_state( - self, token: str, state: BaseState | None = None - ) -> BaseState | None: - """Get the parent state for the state requested in the token. - - Args: - token: The token to get the state for (_substate_key). - state: The state instance to get parent state for. - - Returns: - The parent state for the state requested by the token or None if there is no such parent. - """ - parent_state = None - client_token, state_path = _split_substate_key(token) - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - cached_substates = None - if state is not None: - cached_substates = [state] - # Retrieve the parent state to populate event handlers onto this substate. - parent_state = await self.get_state( - token=_substate_key(client_token, parent_state_name), - top_level=False, - get_substates=False, - cached_substates=cached_substates, - ) - return parent_state - - async def _populate_substates( - self, - token: str, - state: BaseState, - all_substates: bool = False, - ): - """Fetch and link substates for the given state instance. - - There is no return value; the side-effect is that `state` will have `substates` populated, - and each substate will have its `parent_state` set to `state`. - - Args: - token: The token to get the state for. - state: The state instance to populate substates for. - all_substates: Whether to fetch all substates or just required substates. - """ - client_token, _ = _split_substate_key(token) - - if all_substates: - # All substates are requested. - fetch_substates = state.get_substates() - else: - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._potentially_dirty_substates() - - tasks = {} - # Retrieve the necessary substates from redis. - for substate_cls in fetch_substates: - if substate_cls.get_name() in state.substates: - continue - substate_name = substate_cls.get_name() - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, - ) - ) - - for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task - @override async def get_state( self, @@ -3326,7 +3351,6 @@ async def get_state( top_level: bool = True, get_substates: bool = True, parent_state: BaseState | None = None, - cached_substates: list[BaseState] | None = None, ) -> BaseState: """Get the state for a token. @@ -3335,7 +3359,6 @@ async def get_state( top_level: If true, return an instance of the top-level state (self.state). get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. - cached_substates: If provided, attach these substates to the state. Returns: The state for the token. @@ -3343,8 +3366,8 @@ async def get_state( Raises: RuntimeError: when the state_cls is not specified in the token """ - # Split the actual token from the fully qualified substate name. - _, state_path = _split_substate_key(token) + # new impl from top to bottomA + client_token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) @@ -3353,44 +3376,94 @@ async def get_state( f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) - # The deserialized or newly created (sub)state instance. - state = None - - # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) - - if redis_state is not None: - # Deserialize the substate. - with contextlib.suppress(StateSchemaMismatchError): - state = BaseState._deserialize(data=redis_state) - if state is None: - # Key didn't exist or schema mismatch so create a new instance for this token. - state = state_cls( - init_substates=False, - _reflex_internal_init=True, + state_tokens = {state_path} + + # walk up the state path + walk_state_path = state_path + while "." in walk_state_path: + walk_state_path = walk_state_path.rpartition(".")[0] + state_tokens.add(walk_state_path) + + if get_substates: + state_tokens.update( + { + substate.get_full_name() + for substate in state_cls.get_all_substate_classes() + } + ) + state_tokens.update( + self.state._recursive_potentially_dirty_substates( + already_selected=state_cls, + ) ) - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token, state) - # Set up Bidirectional linkage between this state and its parent. + else: + state_tokens.update(self.state._recursive_potentially_dirty_substates()) + + loaded_states = {} if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Avoid fetching substates multiple times. - if cached_substates: - for substate in cached_substates: - state.substates[substate.get_name()] = substate - if substate.parent_state is None: - substate.parent_state = state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) - - # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. + loaded_states = parent_state._get_loaded_states() + # remove all states that are already loaded + state_tokens = state_tokens.difference(loaded_states.keys()) + + redis_states = await self.hmget(name=client_token, keys=list(state_tokens)) + redis_states.update(loaded_states) + root_state = redis_states[self.state.get_full_name()] + self.recursive_link_substates(state=root_state, substates=redis_states) + if top_level: - return state._get_root_state() + return root_state + + state = redis_states[state_path] return state + def recursive_link_substates( + self, + state: BaseState, + substates: dict[str, BaseState], + ): + """Recursively link substates to a state. + + Args: + state: The state to link substates to. + substates: The substates to link. + """ + for substate_cls in state.get_substates(): + if substate_cls.get_full_name() not in substates: + continue + substate = substates[substate_cls.get_full_name()] + state.substates[substate.get_name()] = substate + substate.parent_state = state + self.recursive_link_substates( + state=substate, + substates=substates, + ) + + async def hmget(self, name: str, keys: List[str]) -> dict[str, BaseState]: + """Get multiple values from a hash. + + Args: + name: The name of the hash. + keys: The keys to get. + + Returns: + The values. + """ + d = {} + for redis_state in await self.redis.hmget(name=name, keys=keys): # type: ignore + key = keys.pop(0) + state = None + if redis_state is not None: + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=redis_state) + if state is None: + state_cls = self.state.get_class_substate(key) + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + d[state.get_full_name()] = state + return d + @override async def set_state( self, @@ -3408,6 +3481,7 @@ async def set_state( Raises: LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. RuntimeError: If the state instance doesn't match the state name in the token. + ResponseError: If the redis command fails. """ # Check that we're holding the lock. if ( @@ -3437,30 +3511,38 @@ async def set_state( f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." ) - # Recursively set_state on all known substates. - tasks = [ - asyncio.create_task( - self.set_state( - _substate_key(client_token, substate), - substate, - lock_id, - ) - ) - for substate in state.substates.values() - ] - # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). - if state._get_was_touched(): - pickle_state = state._serialize() - if pickle_state: - await self.redis.set( - _substate_key(client_token, state), - pickle_state, - ex=self.token_expiration, - ) + redis_hashset = state._serialize_touched_states() + + if not redis_hashset: + return - # Wait for substates to be persisted. - for t in tasks: - await t + try: + await self._hset_pipeline(client_token, redis_hashset) + except ResponseError as re: + if "unknown command 'HEXPIRE'" not in str(re): + raise + # HEXPIRE not supported, try again with fallback expire. + self._hexpire_not_supported = True + await self._hset_pipeline(client_token, redis_hashset) + + async def _hset_pipeline(self, client_token: str, redis_hashset: dict[str, bytes]): + """Set multiple fields in a hash with expiration. + + Args: + client_token: The name of the hash. + redis_hashset: The keys and values to set. + """ + pipe = self.redis.pipeline(transaction=False) + pipe.hset(name=client_token, mapping=redis_hashset) + if self._hexpire_not_supported: + pipe.expire(client_token, self.token_expiration) + else: + pipe.hexpire( + client_token, + self.token_expiration, + *redis_hashset.keys(), + ) + await pipe.execute() @override @contextlib.asynccontextmanager diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 25e753d093c..42ec94da25e 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -333,10 +333,9 @@ def get_redis() -> Redis | None: Returns: The asynchronous redis client. """ - if isinstance((redis_url_or_options := parse_redis_url()), str): - return Redis.from_url(redis_url_or_options) - elif isinstance(redis_url_or_options, dict): - return Redis(**redis_url_or_options) + redis_url = parse_redis_url() + if redis_url is not None: + return Redis.from_url(redis_url) return None @@ -346,14 +345,13 @@ def get_redis_sync() -> RedisSync | None: Returns: The synchronous redis client. """ - if isinstance((redis_url_or_options := parse_redis_url()), str): - return RedisSync.from_url(redis_url_or_options) - elif isinstance(redis_url_or_options, dict): - return RedisSync(**redis_url_or_options) + redis_url = parse_redis_url() + if redis_url is not None: + return RedisSync.from_url(redis_url) return None -def parse_redis_url() -> str | dict | None: +def parse_redis_url() -> str | None: """Parse the REDIS_URL in config if applicable. Returns: diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 2652d6ccb3a..0fd52395745 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.remote.webdriver import WebDriver from reflex.state import ( - State, StateManagerDisk, StateManagerMemory, StateManagerRedis, @@ -278,6 +277,7 @@ def set_sub_sub(var: str, value: str): set_sub_sub_state_button.click() token = poll_for_token() + assert token is not None # get a reference to all cookie and local storage elements c1 = driver.find_element(By.ID, "c1") @@ -613,16 +613,7 @@ def set_sub_sub(var: str, value: str): # Simulate state expiration if isinstance(client_side.state_manager, StateManagerRedis): - await client_side.state_manager.redis.delete( - _substate_key(token, State.get_full_name()) - ) - await client_side.state_manager.redis.delete(_substate_key(token, state_name)) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_state_name) - ) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_sub_state_name) - ) + await client_side.state_manager.redis.delete(token) elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)): del client_side.state_manager.states[token] if isinstance(client_side.state_manager, StateManagerDisk): @@ -678,9 +669,8 @@ async def poll_for_not_hydrated(): # Get the backend state and ensure the values are still set async def get_sub_state(): - root_state = await client_side.get_state( - _substate_key(token or "", sub_state_name) - ) + assert token is not None + root_state = await client_side.get_state(_substate_key(token, sub_state_name)) state = root_state.substates[client_side.get_state_name("_client_side_state")] sub_state = state.substates[ client_side.get_state_name("_client_side_sub_state") diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 912d72f4f15..84a97efb69c 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -31,6 +31,7 @@ import reflex as rx import reflex.config +import reflex.utils.console from reflex import constants from reflex.app import App from reflex.base import Base @@ -1857,7 +1858,7 @@ async def test_state_manager_lock_warning_threshold_contend( substate_token_redis: A token + substate name for looking up in state manager. mocker: Pytest mocker object. """ - console_warn = mocker.patch("reflex.utils.console.warn") + console_warn = mocker.spy(reflex.utils.console, "warn") state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD @@ -1875,7 +1876,7 @@ async def _coro_blocker(): await tasks[0] console_warn.assert_called() - assert console_warn.call_count == 7 + assert console_warn.call_count == 1 class CopyingAsyncMock(AsyncMock): @@ -3192,10 +3193,17 @@ class C1(State): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == {State} - assert State._potentially_dirty_substates() == {C1} + assert RxState._potentially_dirty_substates() == {State.get_full_name()} + assert State._potentially_dirty_substates() == {C1.get_full_name()} assert C1._potentially_dirty_substates() == set() + assert RxState._recursive_potentially_dirty_substates() == { + State.get_full_name(), + C1.get_full_name(), + } + assert State._recursive_potentially_dirty_substates() == {C1.get_full_name()} + assert C1._recursive_potentially_dirty_substates() == set() + def test_router_var_dep() -> None: """Test that router var dependencies are correctly tracked.""" @@ -3216,7 +3224,9 @@ def foo(self) -> str: State._init_var_dependency_dicts() assert foo._deps(objclass=RouterVarDepState) == {"router"} - assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} + assert RouterVarParentState._potentially_dirty_substates() == { + RouterVarDepState.get_full_name() + } assert RouterVarParentState._substate_var_dependencies == { "router": {RouterVarDepState.get_name()} } diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index ebdd877de2d..f1d100ff284 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -354,11 +354,11 @@ async def state_manager_redis( ], ) async def test_get_state_tree( - state_manager_redis, - token, - substate_cls, - exp_root_substates, - exp_root_dict_keys, + state_manager_redis: StateManagerRedis, + token: str, + substate_cls: type[BaseState], + exp_root_substates: list[str], + exp_root_dict_keys: list[str], ): """Test getting state trees and assert on which branches are retrieved.