From 3743b34a70c1270c3b630a2e5eb05924d4e9df1d Mon Sep 17 00:00:00 2001 From: Zhongsheng Ji Date: Fri, 18 Oct 2024 16:09:33 +0800 Subject: [PATCH] Intro callback function for processing result and exception (#42) * Intro callback function for processing result and exception * Add testcase and docsting for callback function * Add example for add callback * Fix readme issue --- README.md | 3 ++- brq/consumer.py | 30 +++++++++++++++++++++++-- brq/decorator.py | 23 +++++++++++++++++-- examples/decorator/consumer.py | 7 +++++- tests/test_brq.py | 41 ++++++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index b9e0000..7c2e290 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,13 @@ Redis >= 6.2, tested with latest redis 6/7 docker image. Recommended to use redi ## Feature -> See [examples](%22./examples%22) for running examples. +> See [examples](./examples) for running examples. - Defer job and automatic retry error job - Dead queue for unprocessable job, you can process it later - Multiple consumers in one consumer group - No scheduler needed, consumer handles itself +- Using callback function to process job result or exception ## Configuration diff --git a/brq/consumer.py b/brq/consumer.py index e5bd4e8..f801cda 100644 --- a/brq/consumer.py +++ b/brq/consumer.py @@ -14,6 +14,10 @@ CONSUMER_IDENTIFIER_ENV = "BRQ_CONSUMER_IDENTIFIER" +class CannotProcessError(RuntimeError): + pass + + class RunnableMixin: def __init__(self): self._stop_event = asyncio.Event() @@ -93,6 +97,7 @@ class Consumer(BrqOperator, RunnableMixin): max_message_len(int, default=1000): The maximum length of a message. Follow redis stream `maxlen`. delete_message_after_process(bool, default=False): Whether to delete message after process. If many consumer groups are used, this should be set to False. run_parallel(bool, default=False): Whether to run in parallel. + awaitable_function_callback(Callable[[Job, Exception | Any, "Consumer"], Awaitable[Any]] | None, default=None): The callback function for awaitable_function. """ def __init__( @@ -116,12 +121,17 @@ def __init__( max_message_len: int = 1000, delete_message_after_process: bool = False, run_parallel: bool = False, + awaitable_function_callback: ( + Callable[[Job, Exception | Any, "Consumer"], Awaitable[Any]] | None + ) = None, ): + super().__init__(redis, redis_prefix, redis_seperator) self._stop_event = asyncio.Event() self._task: None | asyncio.Task = None self.awaitable_function = awaitable_function + self.awaitable_function_callback = awaitable_function_callback self.register_function_name = register_function_name or awaitable_function.__name__ self.group_name = group_name self.consumer_identifier = consumer_identifier @@ -166,6 +176,15 @@ def deferred_key(self) -> str: def dead_key(self) -> str: return self.get_dead_message_key(self.register_function_name) + async def callback(self, job: Job, exception_or_result: Exception | Any): + if not self.awaitable_function_callback: + return + try: + await self.awaitable_function_callback(job, exception_or_result, self) + except Exception as e: + logger.warning(f"Callback error: {e}") + logger.exception(e) + async def _is_retry_cooldown(self) -> bool: return await self.redis.exists(self.retry_cooldown_key) @@ -234,6 +253,7 @@ async def _move_expired_jobs(self): logger.info(f"Put expired job {job} to dead queue") await self.redis.xdel(self.stream_name, message_id) logger.debug(f"{job} expired") + await self.callback(job, CannotProcessError("Expired")) async def _process_unacked_job(self): if not self.enable_reprocess_timeout_job: @@ -263,15 +283,17 @@ async def _process_unacked_job(self): continue job = Job.from_message(serialized_job) try: - await self.awaitable_function(*job.args, **job.kwargs) + r = await self.awaitable_function(*job.args, **job.kwargs) except Exception as e: logger.exception(e) + await self.callback(job, e) else: logger.info(f"Retry {job} successfully") await self.redis.xack(self.stream_name, self.group_name, message_id) if self.delete_message_after_process: await self.redis.xdel(self.stream_name, message_id) + await self.callback(job, r) async def _pool_job(self): poll_result = await self.redis.xreadgroup( @@ -288,14 +310,16 @@ async def _pool_job(self): for message_id, serialized_job in messages: job = Job.from_message(serialized_job) try: - await self.awaitable_function(*job.args, **job.kwargs) + r = await self.awaitable_function(*job.args, **job.kwargs) except Exception as e: logger.exception(e) + await self.callback(job, e) else: await self.redis.xack(self.stream_name, self.group_name, message_id) if self.delete_message_after_process: await self.redis.xdel(self.stream_name, message_id) + await self.callback(job, r) async def _pool_job_prallel(self): pool_result = await self.redis.xreadgroup( @@ -325,11 +349,13 @@ async def _job_wrap(message_id, *args, **kwargs): for message_id, result in zip(jobs, results): if isinstance(result, Exception): logger.exception(result) + await self.callback(jobs[message_id], result) continue await self.redis.xack(self.stream_name, self.group_name, message_id) if self.delete_message_after_process: await self.redis.xdel(self.stream_name, message_id) + await self.callback(jobs[message_id], result) async def _acquire_retry_lock(self) -> bool: return await self.redis.get( diff --git a/brq/decorator.py b/brq/decorator.py index 8a835c6..43d35a2 100644 --- a/brq/decorator.py +++ b/brq/decorator.py @@ -17,8 +17,10 @@ def __init__( func: Callable | Awaitable, config: BrqConfig, register_function_name: str | None = None, + callback_func: Callable | Awaitable | None = None, ): self._func = func + self._callback_func = callback_func self.config = config self.register_function_name = register_function_name or func.__name__ @@ -34,6 +36,14 @@ async def _serve(self): self._func, limiter=CapacityLimiter(total_tokens=self.config.daemon_concurrency), ) + callback_func = ( + ensure_awaitable( + self._callback_func, + limiter=CapacityLimiter(total_tokens=self.config.daemon_concurrency), + ) + if self._callback_func + else None + ) consumer_builder = partial( Consumer, redis_prefix=self.config.redis_key_prefix, @@ -55,6 +65,7 @@ async def _serve(self): max_message_len=self.config.consumer_max_message_len, delete_message_after_process=self.config.consumer_delete_message_after_process, run_parallel=self.config.consumer_run_parallel, + awaitable_function_callback=callback_func, ) daemon = Daemon(*[consumer_builder() for _ in range(self.config.daemon_concurrency)]) await daemon.run_forever() @@ -91,6 +102,7 @@ def task( max_message_len: int | None = None, delete_message_after_process: bool | None = None, run_parallel: bool | None = None, + callback_func: Callable | Awaitable | None = None, # Daemon daemon_concurrency: int | None = None, ): @@ -99,9 +111,10 @@ def _wrapper( func, config: BrqConfig | None = None, register_function_name: str | None = None, + callback_func: Callable | Awaitable | None = None, ): - return BrqTaskWrapper(func, config, register_function_name) + return BrqTaskWrapper(func, config, register_function_name, callback_func) kwargs = { k: v @@ -153,6 +166,12 @@ def _wrapper( _wrapper, config=config, register_function_name=register_function_name, + callback_func=callback_func, ) else: - return _wrapper(_func, config=config, register_function_name=register_function_name) + return _wrapper( + _func, + config=config, + register_function_name=register_function_name, + callback_func=callback_func, + ) diff --git a/examples/decorator/consumer.py b/examples/decorator/consumer.py index 9a6c246..1df9381 100644 --- a/examples/decorator/consumer.py +++ b/examples/decorator/consumer.py @@ -1,9 +1,14 @@ from brq import task -@task +def callback(job, exception_or_result, consumer): + print(f"Callback for {job} with {exception_or_result}") + + +@task(callback_func=callback) def echo(message): print(f"Received message: {message}") + return "processed" if __name__ == "__main__": diff --git a/tests/test_brq.py b/tests/test_brq.py index 0981981..aeca57c 100644 --- a/tests/test_brq.py +++ b/tests/test_brq.py @@ -33,6 +33,16 @@ async def mock_always_consume_raise_exception(echo_str: str): raise Exception("raise exception") +callback_called = False + + +async def mock_callback(job, exception_or_result, consumer): + global callback_called + if callback_called: + raise RuntimeError("Callback already called. Please reset it.") + callback_called = True + + @pytest.mark.parametrize("run_parallel", [False, True]) async def test_consume_function(async_redis_client, capfd, run_parallel): producer = Producer(async_redis_client) @@ -53,6 +63,37 @@ async def test_consume_function(async_redis_client, capfd, run_parallel): await consumer.cleanup() +@pytest.mark.parametrize("run_parallel", [False, True]) +async def test_consume_with_callback_function(async_redis_client, capfd, run_parallel): + global callback_called + callback_called = False + assert not callback_called + + producer = Producer(async_redis_client) + consumer = Consumer( + async_redis_client, + mock_consume, + run_parallel=run_parallel, + awaitable_function_callback=mock_callback, + ) + browser = Browser(async_redis_client) + await browser.status() + await producer.run_job("mock_consume", ["hello"]) + jobs = [job async for job in producer.walk_jobs("mock_consume")] + assert len(jobs) == 1 + await browser.status() + await consumer.initialize() + await browser.status() + await consumer.run() + await browser.status() + + out, err = capfd.readouterr() + + assert "hello" in out + assert callback_called + await consumer.cleanup() + + async def test_count_jobs(async_redis_client, redis_version): producer = Producer(async_redis_client) consumer = Consumer(async_redis_client, delay_job)