Skip to content

Commit

Permalink
Fix losing initial payload because of the race.
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrey Zelenchuk committed Mar 10, 2022
1 parent e37937e commit 28632fc
Showing 1 changed file with 19 additions and 26 deletions.
45 changes: 19 additions & 26 deletions channels_graphql_ws/graphql_ws_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,6 @@ async def _register_subscription(
# `_sids_by_group` without any locks.
self._assert_thread()

# The subject we will trigger on the `broadcast` message.
trigger = rx.subjects.Subject()

# The subscription notification queue.
queue_size = notification_queue_limit
if not queue_size or queue_size < 0:
Expand All @@ -728,56 +725,41 @@ async def _register_subscription(

# Start an endless task which listens the `notification_queue`
# and invokes subscription "resolver" on new notifications.
async def notifier():
async def notifier(observer: rx.Observer):
"""Watch the notification queue and notify clients."""

# Assert we run in a proper thread.
self._assert_thread()

# Dirty hack to partially workaround the race between:
# 1) call to `result.subscribe` in `_on_gql_start`; and
# 2) call to `trigger.on_next` below in this function.
# The first call must be earlier. Otherwise, first one or more notifications
# may be lost.
await asyncio.sleep(1)

while True:
serialized_payload = await notification_queue.get()

# Run a subscription's `publish` method (invoked by the
# `trigger.on_next` function) within the threadpool used
# `observer.on_next` function) within the threadpool used
# for processing other GraphQL resolver functions.
# NOTE: it is important to run the deserialization
# in the worker thread as well.
def workload():
try:
payload = Serializer.deserialize(serialized_payload)
except Exception as ex: # pylint: disable=broad-except
trigger.on_error(f"Cannot deserialize payload. {ex}")
observer.on_error(f"Cannot deserialize payload. {ex}")
else:
trigger.on_next(payload)
observer.on_next(payload)

await self._run_in_worker(workload)

# Message processed. This allows `Queue.join` to work.
notification_queue.task_done()

# Enqueue the `publish` method execution. But do not notify
# clients when `publish` returns `SKIP`.
stream = trigger.map(publish_callback).filter( # pylint: disable=no-member
lambda publish_returned: publish_returned is not self.SKIP
)

def push_payloads(observer: rx.Observer):
# Start listening for broadcasts (subscribe to the Channels
# groups), spawn the notification processing task and put
# subscription information into the registry.
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
# atomic i.e. without `awaits` in between.
waitlist = []
for group in groups:
self._sids_by_group.setdefault(group, []).append(operation_id)
waitlist.append(self._channel_layer.group_add(group, self.channel_name))
notifier_task = self._spawn_background_task(notifier())
notifier_task = self._spawn_background_task(notifier(observer))
self._subscriptions[operation_id] = self._SubInf(
groups=groups,
sid=operation_id,
Expand All @@ -786,9 +768,20 @@ def workload():
notifier_task=notifier_task,
)

await asyncio.wait(waitlist)
await asyncio.wait(
[
self._channel_layer.group_add(group, self.channel_name)
for group in groups
]
)

return stream
# Enqueue the `publish` method execution. But do not notify
# clients when `publish` returns `SKIP`.
return (
rx.Observable.create(push_payloads) # pylint: disable=no-member
.map(publish_callback)
.filter(lambda publish_returned: publish_returned is not self.SKIP)
)

async def _on_gql_stop(self, operation_id):
"""Process the STOP message.
Expand Down

0 comments on commit 28632fc

Please sign in to comment.