WIP - Shared redis subscribe #4432

Draft: wants to merge 5 commits into base: main
reflex/app.py (5 changes: 3 additions & 2 deletions)
@@ -1477,7 +1477,7 @@ def __init__(self, namespace: str, app: App):
super().__init__(namespace)
self.app = app

-    def on_connect(self, sid, environ):
+    async def on_connect(self, sid, environ):
"""Event for when the websocket is connected.

Args:
@@ -1486,7 +1486,7 @@ def on_connect(self, sid, environ):
"""
pass

-    def on_disconnect(self, sid):
+    async def on_disconnect(self, sid):
"""Event for when the websocket disconnects.

Args:
@@ -1495,6 +1495,7 @@ def on_disconnect(self, sid):
disconnect_token = self.sid_to_token.pop(sid, None)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
+        await self.app.state_manager.disconnect(sid)

async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
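Note on the app.py change: the socket.io namespace handlers become coroutines so that `on_disconnect` can await the new `StateManager.disconnect` cleanup hook. A minimal standalone sketch of that hand-off; the `DummyStateManager`/`DummyNamespace` names are illustrative stand-ins, not reflex APIs, and the sketch keys cleanup by client token:

```python
import asyncio


class DummyStateManager:
    """Illustrative stand-in: evicts per-token resources on disconnect."""

    def __init__(self):
        self.states = {"token-a": {"counter": 0}}

    async def disconnect(self, token: str) -> None:
        self.states.pop(token, None)


class DummyNamespace:
    """Illustrative stand-in for the websocket event namespace."""

    def __init__(self, state_manager: DummyStateManager):
        self.state_manager = state_manager
        self.sid_to_token = {"sid-1": "token-a"}
        self.token_to_sid = {"token-a": "sid-1"}

    async def on_disconnect(self, sid: str) -> None:
        token = self.sid_to_token.pop(sid, None)
        if token:
            self.token_to_sid.pop(token, None)
        # The new step: hand cleanup off to the state manager.
        await self.state_manager.disconnect(token or sid)


async def main():
    ns = DummyNamespace(DummyStateManager())
    await ns.on_disconnect("sid-1")
    assert ns.state_manager.states == {}  # token state evicted on disconnect


asyncio.run(main())
```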
reflex/state.py (91 changes: 76 additions & 15 deletions)
@@ -39,6 +39,7 @@
get_type_hints,
)

+from redis.asyncio.client import PubSub
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self

@@ -135,7 +136,7 @@


def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
state_cls: Type[BaseState], name: str, fn: Callable
) -> Callable:
"""Protect against directly chaining a background task from another event handler.

@@ -172,9 +173,10 @@ async def _no_chain_background_task_gen(*args, **kwargs):
raise TypeError(f"{fn} is marked as a background task, but is not async.")


+@functools.lru_cache()
def _substate_key(
token: str,
-    state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
+    state_cls_or_name: Type[BaseState] | str | Sequence[str],
) -> str:
"""Get the substate key.

@@ -185,9 +187,7 @@ def _substate_key(
Returns:
The substate key.
"""
-    if isinstance(state_cls_or_name, BaseState) or (
-        isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
-    ):
+    if isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState):
state_cls_or_name = state_cls_or_name.get_full_name()
elif isinstance(state_cls_or_name, (list, tuple)):
state_cls_or_name = ".".join(state_cls_or_name)
@@ -301,7 +301,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
)


-class BaseState(Base, ABC, extra=pydantic.Extra.allow):
+class HashableModelMetaclass(type(Base)):
+    """Metaclass that hashes state classes by their identity."""
+
+    def __hash__(self):
+        return id(self)
+        # return hash(f"{self.__module__}.{self.__name__}")
+        # return hash(self.get_full_name())
+
+
+class BaseState(
+    Base, ABC, extra=pydantic.Extra.allow, metaclass=HashableModelMetaclass
+):
"""The state of the app."""

# A map from the var name to the var.
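Note on `HashableModelMetaclass`: the `lru_cache` above needs its state-class argument to be hashable. Ordinarily classes hash by identity, so presumably something in the pydantic metaclass chain defines equality without hashing; in Python, defining `__eq__` without `__hash__` sets `__hash__` to `None`, and when that happens on a metaclass the *classes* built from it become unhashable. A small self-contained illustration of that rule (not reflex code):

```python
class EqOnlyMeta(type):
    # Defining __eq__ without __hash__ sets __hash__ = None on this metaclass,
    # so classes created from it cannot be hashed (or used as lru_cache keys).
    def __eq__(cls, other):
        return cls is other


class HashableMeta(EqOnlyMeta):
    # Same fix as HashableModelMetaclass in the diff: hash classes by identity.
    def __hash__(cls):
        return id(cls)


class Broken(metaclass=EqOnlyMeta):
    pass


class Fixed(metaclass=HashableMeta):
    pass


try:
    hash(Broken)
except TypeError as exc:
    print(f"unhashable: {exc}")

print(hash(Fixed) == id(Fixed))  # True
```

The commented-out lines in the diff show name-based hashing as an alternative still under consideration (the PR is marked WIP).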
@@ -2826,6 +2835,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""
yield self.state()

+    async def disconnect(self, token: str) -> None:
+        """Disconnect the client with the given token.
+
+        Args:
+            token: The token to disconnect.
+        """
+        pass


class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
@@ -2895,6 +2912,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state)

+    @override
+    async def disconnect(self, token: str) -> None:
+        """Disconnect the client with the given token.
+
+        Args:
+            token: The token to disconnect.
+        """
+        if token in self.states:
+            del self.states[token]
+        if lock := self._states_locks.get(token):
+            if lock.locked():
+                lock.release()
+            del self._states_locks[token]


def _default_token_expiration() -> int:
"""Get the default token expiration time.
@@ -3044,17 +3075,17 @@ async def populate_substates(
state: The state object to populate.
root_state: The root state object.
"""
-        for substate in state.get_substates():
-            substate_token = _substate_key(client_token, substate)
+        for substate_cls in state.get_substates():
+            substate_token = _substate_key(client_token, substate_cls)

-            fresh_instance = await root_state.get_state(substate)
+            fresh_instance = await root_state.get_state(substate_cls)
instance = await self.load_state(substate_token)
if instance is not None:
# Ensure all substates exist, even if they weren't serialized previously.
instance.substates = fresh_instance.substates
else:
instance = fresh_instance
-            state.substates[substate.get_name()] = instance
+            state.substates[substate_cls.get_name()] = instance
instance.parent_state = state

await self.populate_substates(client_token, instance, root_state)
@@ -3098,7 +3129,7 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
client_token: The client token.
substate: The substate to set.
"""
-        substate_token = _substate_key(client_token, substate)
+        substate_token = _substate_key(client_token, type(substate))

if substate._get_was_touched():
substate._was_touched = False # Reset the touched flag after serializing.
@@ -3155,6 +3186,18 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration


+PUBSUB_CLIENTS: Dict[str, PubSub] = {}
+
+
+async def cached_pubsub(redis: Redis, lock_key_channel: str) -> PubSub:
+    """Return a shared PubSub subscribed to the given channel pattern, creating it on first use."""
+    if lock_key_channel in PUBSUB_CLIENTS:
+        return PUBSUB_CLIENTS[lock_key_channel]
+    pubsub = redis.pubsub()
+    await pubsub.psubscribe(lock_key_channel)
+    PUBSUB_CLIENTS[lock_key_channel] = pubsub
+    return pubsub


class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""

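Note on `PUBSUB_CLIENTS` / `cached_pubsub`: instead of opening and closing a fresh `PubSub` for every lock wait, one subscription per keyspace channel is created lazily and reused across requests, which is the "shared redis subscribe" in the PR title. A rough standalone sketch of the same pattern with `redis.asyncio` (assumes a local redis and redis-py 5+; the channel pattern is illustrative):

```python
import asyncio
from typing import Dict

from redis.asyncio import Redis
from redis.asyncio.client import PubSub

_PUBSUBS: Dict[str, PubSub] = {}


async def shared_pubsub(redis: Redis, pattern: str) -> PubSub:
    """Return one PubSub per channel pattern, creating it on first use."""
    if pattern not in _PUBSUBS:
        pubsub = redis.pubsub()
        await pubsub.psubscribe(pattern)
        _PUBSUBS[pattern] = pubsub
    return _PUBSUBS[pattern]


async def main():
    redis = Redis()  # assumes redis on localhost:6379
    first = await shared_pubsub(redis, "__keyspace@0__:lock_*")
    second = await shared_pubsub(redis, "__keyspace@0__:lock_*")
    assert first is second  # later waiters reuse the existing subscription

    # Consume one notification, if any arrives within the timeout.
    print(await first.get_message(ignore_subscribe_messages=True, timeout=1.0))
    await redis.aclose()  # redis-py 5+


asyncio.run(main())
```

Unlike this sketch, the diff also guards the first subscription with a per-lock-key `asyncio.Lock` (`_pubsub_locks`) so concurrent waiters do not race to create it, and drops that lock again in `disconnect`.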
@@ -3183,6 +3226,9 @@ class StateManagerRedis(StateManager):
b"evicted",
}

+    # This lock ensures we only subscribe to keyspace events once per token per worker.
+    _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})

async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
@@ -3367,7 +3413,7 @@ async def set_state(
tasks.append(
asyncio.create_task(
self.set_state(
-                            token=_substate_key(client_token, substate),
+                            token=_substate_key(client_token, type(substate)),
state=substate,
lock_id=lock_id,
)
@@ -3378,7 +3424,7 @@ async def set_state(
pickle_state = state._serialize()
if pickle_state:
await self.redis.set(
-                _substate_key(client_token, state),
+                _substate_key(client_token, type(state)),
pickle_state,
ex=self.token_expiration,
)
@@ -3458,8 +3504,10 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
# Some redis servers only allow out-of-band configuration, so ignore errors here.
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
raise
-        async with self.redis.pubsub() as pubsub:
-            await pubsub.psubscribe(lock_key_channel)
+        if lock_key not in self._pubsub_locks:
+            self._pubsub_locks[lock_key] = asyncio.Lock()
+        async with self._pubsub_locks[lock_key]:
+            pubsub = await cached_pubsub(self.redis, lock_key_channel)
while not state_is_locked:
# wait for the lock to be released
while True:
@@ -3475,6 +3523,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
break
state_is_locked = await self._try_get_lock(lock_key, lock_id)

+    @override
+    async def disconnect(self, token: str):
+        """Disconnect the token from the redis client.
+
+        Args:
+            token: The token to disconnect.
+        """
+        lock_key = self._lock_key(token)
+        if lock := self._pubsub_locks.get(lock_key):
+            if lock.locked():
+                lock.release()
+            del self._pubsub_locks[lock_key]

@contextlib.asynccontextmanager
async def _lock(self, token: str):
"""Obtain a redis lock for a token.
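Putting the state.py pieces together: `_wait_lock` now reuses the shared subscription for the lock key's keyspace channel and loops on `get_message` until a delete/expire notification signals that the holder released the lock, then retries the `SET NX` acquire. A condensed, hedged sketch of that wait loop (the channel format and the unlock-event set are assumed from context; redis must have `notify-keyspace-events` configured, and for brevity this sketch opens its own subscription rather than the shared cached one):

```python
import asyncio
import uuid

from redis.asyncio import Redis

UNLOCK_EVENTS = {b"del", b"expire", b"expired", b"evicted"}  # assumed; the diff shows b"evicted"


async def try_get_lock(redis: Redis, lock_key: str, lock_id: bytes) -> bool:
    # Acquire only if nobody holds the lock; expire it so a crash cannot wedge everyone.
    return bool(await redis.set(lock_key, lock_id, nx=True, px=10_000))


async def wait_lock(redis: Redis, lock_key: str, lock_id: bytes) -> None:
    channel = f"__keyspace@0__:{lock_key}"  # assumed keyspace-notification channel format
    async with redis.pubsub() as pubsub:
        await pubsub.psubscribe(channel)
        while not await try_get_lock(redis, lock_key, lock_id):
            # Wait for the current holder to release (DEL) or the lock to expire.
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=5.0
                )
                if message is None or message["data"] in UNLOCK_EVENTS:
                    break


async def main():
    redis = Redis()  # assumes a local redis with keyspace notifications enabled
    await wait_lock(redis, "lock_demo", uuid.uuid4().hex.encode())
    print("lock acquired")
    await redis.delete("lock_demo")
    await redis.aclose()  # redis-py 5+


asyncio.run(main())
```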
tests/integration/test_background_task.py (35 changes: 35 additions & 0 deletions)
@@ -42,6 +42,11 @@ async def handle_event_yield_only(self):
yield State.increment() # type: ignore
await asyncio.sleep(0.005)

+        @rx.event(background=True)
+        async def fast_yielding(self):
+            for _ in range(10000):
+                yield State.increment()

@rx.event
def increment(self):
self.counter += 1
@@ -169,6 +174,11 @@ def index() -> rx.Component:
on_click=State.yield_in_async_with_self,
id="yield-in-async-with-self",
),
+            rx.button(
+                "Fast Yielding",
+                on_click=State.fast_yielding,
+                id="fast-yielding",
+            ),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)

@@ -375,3 +385,28 @@ def test_yield_in_async_with_self(

yield_in_async_with_self_button.click()
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)


+def test_fast_yielding(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+) -> None:
+    """Test that fast yielding works as expected.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to the fast-yielding button
+    fast_yielding_button = driver.find_element(By.ID, "fast-yielding")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+    assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+    fast_yielding_button.click()
+    assert background_task._poll_for(lambda: counter.text == "10000", timeout=50)
tests/units/test_state.py (6 changes: 4 additions & 2 deletions)
@@ -1759,7 +1759,7 @@ def substate_token_redis(state_manager_redis, token):
Returns:
Token concatenated with the state_manager's state full_name.
"""
-    return _substate_key(token, state_manager_redis.state)
+    return _substate_key(token, type(state_manager_redis.state))


@pytest.mark.asyncio
@@ -1918,7 +1918,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):

# Get the state from the state manager directly and check that the value is updated
gotten_state = await mock_app.state_manager.get_state(
-        _substate_key(grandchild_state.router.session.client_token, grandchild_state)
+        _substate_key(
+            grandchild_state.router.session.client_token, type(grandchild_state)
+        )
)
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists