From 9ec9ab7929cc79e1c92b539024ce97924079e89b Mon Sep 17 00:00:00 2001
From: Benedikt Bartscher
Date: Fri, 22 Nov 2024 09:57:18 +0100
Subject: [PATCH 1/5] test fast yielding in background task

---
 tests/integration/test_background_task.py | 30 +++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py
index d7fe208247b..2cf4d6a30ec 100644
--- a/tests/integration/test_background_task.py
+++ b/tests/integration/test_background_task.py
@@ -42,6 +42,11 @@ async def handle_event_yield_only(self):
                 yield State.increment()  # type: ignore
                 await asyncio.sleep(0.005)

+        @rx.event(background=True)
+        async def fast_yielding(self):
+            for _ in range(100):
+                yield State.increment()
+
         @rx.event
         def increment(self):
             self.counter += 1
@@ -375,3 +380,28 @@ def test_yield_in_async_with_self(
     yield_in_async_with_self_button.click()

     assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
+
+
+def test_fast_yielding(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+) -> None:
+    """Test that fast yielding works as expected.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to the button
+    fast_yielding_button = driver.find_element(By.ID, "yield-increment")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+    assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+    fast_yielding_button.click()
+    assert background_task._poll_for(lambda: counter.text == "100", timeout=5)

From 03eda2e90ecb3d7ae5ae7a6948b0da3059757f53 Mon Sep 17 00:00:00 2001
From: Benedikt Bartscher
Date: Sat, 23 Nov 2024 11:44:58 +0100
Subject: [PATCH 2/5] accidentally pushed unfinished changes

---
 tests/integration/test_background_task.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py
index 2cf4d6a30ec..70e2202a68d 100644
--- a/tests/integration/test_background_task.py
+++ b/tests/integration/test_background_task.py
@@ -44,7 +44,7 @@ async def handle_event_yield_only(self):

         @rx.event(background=True)
         async def fast_yielding(self):
-            for _ in range(100):
+            for _ in range(1000):
                 yield State.increment()

         @rx.event
@@ -174,6 +174,11 @@ def index() -> rx.Component:
             on_click=State.yield_in_async_with_self,
             id="yield-in-async-with-self",
         ),
+        rx.button(
+            "Fast Yielding",
+            on_click=State.fast_yielding,
+            id="fast-yielding",
+        ),
         rx.button("Reset", on_click=State.reset_counter, id="reset"),
     )

@@ -397,11 +402,11 @@ def test_fast_yielding(
     assert background_task.app_instance is not None

     # get a reference to the button
-    fast_yielding_button = driver.find_element(By.ID, "yield-increment")
+    fast_yielding_button = driver.find_element(By.ID, "fast-yielding")

     # get a reference to the counter
     counter = driver.find_element(By.ID, "counter")
     assert background_task._poll_for(lambda: counter.text == "0", timeout=5)

     fast_yielding_button.click()
-    assert background_task._poll_for(lambda: counter.text == "100", timeout=5)
+    assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)

From 389f4c7196784bcbdf38260c05a32a79e3114a42 Mon Sep 17 00:00:00 2001
From: Benedikt Bartscher
Date: Sat, 23 Nov 2024 20:37:39 +0100
Subject: [PATCH 3/5] fix: only open one connection/sub for each token per worker

bonus: properly cleanup StateManager connections on disconnect
---
 reflex/app.py   |  5 +++--
 reflex/state.py | 42 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/reflex/app.py b/reflex/app.py
index afc40e3b880..9c89868f451 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -1477,7 +1477,7 @@ def __init__(self, namespace: str, app: App):
         super().__init__(namespace)
         self.app = app

-    def on_connect(self, sid, environ):
+    async def on_connect(self, sid, environ):
         """Event for when the websocket is connected.

         Args:
@@ -1486,7 +1486,7 @@ def on_connect(self, sid, environ):
         """
         pass

-    def on_disconnect(self, sid):
+    async def on_disconnect(self, sid):
         """Event for when the websocket disconnects.

         Args:
@@ -1495,6 +1495,7 @@ def on_disconnect(self, sid):
         disconnect_token = self.sid_to_token.pop(sid, None)
         if disconnect_token:
             self.token_to_sid.pop(disconnect_token, None)
+            await self.app.state_manager.disconnect(disconnect_token)

     async def emit_update(self, update: StateUpdate, sid: str) -> None:
         """Emit an update to the client.
diff --git a/reflex/state.py b/reflex/state.py
index 349dc59e973..7a7d7f43eb8 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -2826,6 +2826,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """
         yield self.state()

+    async def disconnect(self, token: str) -> None:
+        """Disconnect the client with the given token.
+
+        Args:
+            token: The token to disconnect.
+        """
+        pass
+

 class StateManagerMemory(StateManager):
     """A state manager that stores states in memory."""
@@ -2895,6 +2903,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
             yield state
             await self.set_state(token, state)

+    @override
+    async def disconnect(self, token: str) -> None:
+        """Disconnect the client with the given token.
+
+        Args:
+            token: The token to disconnect.
+        """
+        if token in self.states:
+            del self.states[token]
+        if lock := self._states_locks.get(token):
+            if lock.locked():
+                lock.release()
+            del self._states_locks[token]
+

 def _default_token_expiration() -> int:
     """Get the default token expiration time.
@@ -3183,6 +3205,9 @@ class StateManagerRedis(StateManager):
         b"evicted",
     }

+    # This lock is used to ensure we only subscribe to keyspace events once per token and worker
+    _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
+
     async def _get_parent_state(
         self, token: str, state: BaseState | None = None
     ) -> BaseState | None:
@@ -3458,7 +3483,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
                 # Some redis servers only allow out-of-band configuration, so ignore errors here.
                 if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
                     raise
-        async with self.redis.pubsub() as pubsub:
+        if lock_key not in self._pubsub_locks:
+            self._pubsub_locks[lock_key] = asyncio.Lock()
+        async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
             await pubsub.psubscribe(lock_key_channel)
             while not state_is_locked:
                 # wait for the lock to be released
@@ -3475,6 +3502,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
                         break
                 state_is_locked = await self._try_get_lock(lock_key, lock_id)

+    @override
+    async def disconnect(self, token: str) -> None:
+        """Disconnect the token from the redis client.
+
+        Args:
+            token: The token to disconnect.
+        """
+        lock_key = self._lock_key(token)
+        if lock := self._pubsub_locks.get(lock_key):
+            if lock.locked():
+                lock.release()
+            del self._pubsub_locks[lock_key]
+
     @contextlib.asynccontextmanager
     async def _lock(self, token: str):
         """Obtain a redis lock for a token.

From 7fe33c9bf1917cb59f0384a1220025348deede6c Mon Sep 17 00:00:00 2001
From: Benedikt Bartscher
Date: Mon, 25 Nov 2024 00:28:10 +0100
Subject: [PATCH 4/5] run ci

From 198d02cb9ba6c22cf62dcccc69b77dfd8eb259e1 Mon Sep 17 00:00:00 2001
From: Benedikt Bartscher
Date: Mon, 25 Nov 2024 13:54:53 +0100
Subject: [PATCH 5/5] wip shared pubsub and cached _substate_key, garbage collect pubsubs tbd

---
 reflex/state.py                           | 51 ++++++++++++++++-------
 tests/integration/test_background_task.py |  4 +-
 tests/units/test_state.py                 |  4 +++-
 3 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/reflex/state.py b/reflex/state.py
index 7a7d7f43eb8..21acee67b6c 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -39,6 +39,7 @@
     get_type_hints,
 )

+from redis.asyncio.client import PubSub
 from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self

@@ -135,7 +136,7 @@


 def _no_chain_background_task(
-    state_cls: Type["BaseState"], name: str, fn: Callable
+    state_cls: Type[BaseState], name: str, fn: Callable
 ) -> Callable:
     """Protect against directly chaining a background task from another event handler.

@@ -172,9 +173,10 @@ async def _no_chain_background_task_gen(*args, **kwargs):
     raise TypeError(f"{fn} is marked as a background task, but is not async.")


+@functools.lru_cache()
 def _substate_key(
     token: str,
-    state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
+    state_cls_or_name: Type[BaseState] | str | Sequence[str],
 ) -> str:
     """Get the substate key.

@@ -185,9 +187,7 @@ def _substate_key(
     Returns:
         The substate key.
     """
-    if isinstance(state_cls_or_name, BaseState) or (
-        isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
-    ):
+    if isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState):
         state_cls_or_name = state_cls_or_name.get_full_name()
     elif isinstance(state_cls_or_name, (list, tuple)):
         state_cls_or_name = ".".join(state_cls_or_name)
@@ -301,7 +301,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     )


-class BaseState(Base, ABC, extra=pydantic.Extra.allow):
+class HashableModelMetaclass(type(Base)):
+    """Hash state classes by identity so they can be used as lru_cache keys."""
+
+    def __hash__(self):
+        return id(self)
+
+
+class BaseState(
+    Base, ABC, extra=pydantic.Extra.allow, metaclass=HashableModelMetaclass
+):
     """The state of the app."""

     # A map from the var name to the var.
@@ -3066,17 +3075,17 @@ async def populate_substates(
             state: The state object to populate.
             root_state: The root state object.
         """
-        for substate in state.get_substates():
-            substate_token = _substate_key(client_token, substate)
+        for substate_cls in state.get_substates():
+            substate_token = _substate_key(client_token, substate_cls)

-            fresh_instance = await root_state.get_state(substate)
+            fresh_instance = await root_state.get_state(substate_cls)
             instance = await self.load_state(substate_token)
             if instance is not None:
                 # Ensure all substates exist, even if they weren't serialized previously.
                 instance.substates = fresh_instance.substates
             else:
                 instance = fresh_instance
-            state.substates[substate.get_name()] = instance
+            state.substates[substate_cls.get_name()] = instance
             instance.parent_state = state
             await self.populate_substates(client_token, instance, root_state)

@@ -3120,7 +3129,7 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
             client_token: The client token.
             substate: The substate to set.
         """
-        substate_token = _substate_key(client_token, substate)
+        substate_token = _substate_key(client_token, type(substate))

         if substate._get_was_touched():
             substate._was_touched = False  # Reset the touched flag after serializing.
@@ -3177,6 +3186,19 @@ def _default_lock_expiration() -> int:
     return get_config().redis_lock_expiration


+PUBSUB_CLIENTS: Dict[str, PubSub] = {}
+
+
+async def cached_pubsub(redis: Redis, lock_key_channel: str) -> PubSub:
+    """Return the per-worker cached pubsub subscription for the given channel."""
+    if lock_key_channel in PUBSUB_CLIENTS:
+        return PUBSUB_CLIENTS[lock_key_channel]
+    pubsub = redis.pubsub()
+    await pubsub.psubscribe(lock_key_channel)
+    PUBSUB_CLIENTS[lock_key_channel] = pubsub
+    return pubsub
+
+
 class StateManagerRedis(StateManager):
     """A state manager that stores states in redis."""

@@ -3392,7 +3413,7 @@ async def set_state(
                 tasks.append(
                     asyncio.create_task(
                         self.set_state(
-                            token=_substate_key(client_token, substate),
+                            token=_substate_key(client_token, type(substate)),
                             state=substate,
                             lock_id=lock_id,
                         )
@@ -3403,7 +3424,7 @@ async def set_state(
         pickle_state = state._serialize()
         if pickle_state:
             await self.redis.set(
-                _substate_key(client_token, state),
+                _substate_key(client_token, type(state)),
                 pickle_state,
                 ex=self.token_expiration,
             )
@@ -3485,8 +3506,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
                     raise
         if lock_key not in self._pubsub_locks:
             self._pubsub_locks[lock_key] = asyncio.Lock()
-        async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
-            await pubsub.psubscribe(lock_key_channel)
+        async with self._pubsub_locks[lock_key]:
+            pubsub = await cached_pubsub(self.redis, lock_key_channel)
             while not state_is_locked:
                 # wait for the lock to be released
                 while True:
diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py
index 70e2202a68d..fa8b55d03f5 100644
--- a/tests/integration/test_background_task.py
+++ b/tests/integration/test_background_task.py
@@ -44,7 +44,7 @@ async def handle_event_yield_only(self):

         @rx.event(background=True)
         async def fast_yielding(self):
-            for _ in range(1000):
+            for _ in range(10000):
                 yield State.increment()

         @rx.event
@@ -409,4 +409,4 @@ def test_fast_yielding(
     assert background_task._poll_for(lambda: counter.text == "0", timeout=5)

     fast_yielding_button.click()
-    assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)
+    assert background_task._poll_for(lambda: counter.text == "10000", timeout=50)
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index c8a52e6c0b7..2c18699bac7 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
""" - return _substate_key(token, state_manager_redis.state) + return _substate_key(token, type(state_manager_redis.state)) @pytest.mark.asyncio @@ -1918,7 +1918,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # Get the state from the state manager directly and check that the value is updated gotten_state = await mock_app.state_manager.get_state( - _substate_key(grandchild_state.router.session.client_token, grandchild_state) + _substate_key( + grandchild_state.router.session.client_token, type(grandchild_state) + ) ) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists