Skip to content

Commit

Permalink
Intro callback function for processing result and exception (#42)
Browse files Browse the repository at this point in the history
* Intro callback function for processing result and exception

* Add testcase and docstring for callback function

* Add example for add callback

* Fix readme issue
  • Loading branch information
Wh1isper authored Oct 18, 2024
1 parent 9b0e820 commit 3743b34
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 6 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ Redis >= 6.2, tested with latest redis 6/7 docker image. Recommended to use redi

## Feature

> See [examples](%22./examples%22) for running examples.
> See [examples](./examples) for running examples.
- Defer job and automatic retry error job
- Dead queue for unprocessable job, you can process it later
- Multiple consumers in one consumer group
- No scheduler needed, consumer handles itself
- Using callback function to process job result or exception

## Configuration

Expand Down
30 changes: 28 additions & 2 deletions brq/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
CONSUMER_IDENTIFIER_ENV = "BRQ_CONSUMER_IDENTIFIER"


class CannotProcessError(RuntimeError):
    """Signals that a job can no longer be processed (e.g. it expired).

    Passed to the consumer's result callback instead of a normal result
    when an expired job is moved to the dead queue.
    """


class RunnableMixin:
def __init__(self):
self._stop_event = asyncio.Event()
Expand Down Expand Up @@ -93,6 +97,7 @@ class Consumer(BrqOperator, RunnableMixin):
max_message_len(int, default=1000): The maximum length of a message. Follow redis stream `maxlen`.
delete_message_after_process(bool, default=False): Whether to delete message after process. If many consumer groups are used, this should be set to False.
run_parallel(bool, default=False): Whether to run in parallel.
awaitable_function_callback(Callable[[Job, Exception | Any, "Consumer"], Awaitable[Any]] | None, default=None): The callback function for awaitable_function.
"""

def __init__(
Expand All @@ -116,12 +121,17 @@ def __init__(
max_message_len: int = 1000,
delete_message_after_process: bool = False,
run_parallel: bool = False,
awaitable_function_callback: (
Callable[[Job, Exception | Any, "Consumer"], Awaitable[Any]] | None
) = None,
):

super().__init__(redis, redis_prefix, redis_seperator)
self._stop_event = asyncio.Event()
self._task: None | asyncio.Task = None

self.awaitable_function = awaitable_function
self.awaitable_function_callback = awaitable_function_callback
self.register_function_name = register_function_name or awaitable_function.__name__
self.group_name = group_name
self.consumer_identifier = consumer_identifier
Expand Down Expand Up @@ -166,6 +176,15 @@ def deferred_key(self) -> str:
def dead_key(self) -> str:
return self.get_dead_message_key(self.register_function_name)

async def callback(self, job: Job, exception_or_result: Exception | Any):
    """Run the user-supplied callback for a finished job, if one is set.

    Args:
        job: The job that was processed (or failed/expired).
        exception_or_result: The exception raised by the job function,
            or its return value on success.

    Any exception raised inside the callback itself is logged and
    suppressed, so a faulty callback cannot break the consume loop.
    """
    cb = self.awaitable_function_callback
    if not cb:
        return
    try:
        await cb(job, exception_or_result, self)
    except Exception as e:
        logger.warning(f"Callback error: {e}")
        logger.exception(e)

async def _is_retry_cooldown(self) -> bool:
return await self.redis.exists(self.retry_cooldown_key)

Expand Down Expand Up @@ -234,6 +253,7 @@ async def _move_expired_jobs(self):
logger.info(f"Put expired job {job} to dead queue")
await self.redis.xdel(self.stream_name, message_id)
logger.debug(f"{job} expired")
await self.callback(job, CannotProcessError("Expired"))

async def _process_unacked_job(self):
if not self.enable_reprocess_timeout_job:
Expand Down Expand Up @@ -263,15 +283,17 @@ async def _process_unacked_job(self):
continue
job = Job.from_message(serialized_job)
try:
await self.awaitable_function(*job.args, **job.kwargs)
r = await self.awaitable_function(*job.args, **job.kwargs)
except Exception as e:
logger.exception(e)
await self.callback(job, e)
else:
logger.info(f"Retry {job} successfully")
await self.redis.xack(self.stream_name, self.group_name, message_id)

if self.delete_message_after_process:
await self.redis.xdel(self.stream_name, message_id)
await self.callback(job, r)

async def _pool_job(self):
poll_result = await self.redis.xreadgroup(
Expand All @@ -288,14 +310,16 @@ async def _pool_job(self):
for message_id, serialized_job in messages:
job = Job.from_message(serialized_job)
try:
await self.awaitable_function(*job.args, **job.kwargs)
r = await self.awaitable_function(*job.args, **job.kwargs)
except Exception as e:
logger.exception(e)
await self.callback(job, e)
else:
await self.redis.xack(self.stream_name, self.group_name, message_id)

if self.delete_message_after_process:
await self.redis.xdel(self.stream_name, message_id)
await self.callback(job, r)

async def _pool_job_prallel(self):
pool_result = await self.redis.xreadgroup(
Expand Down Expand Up @@ -325,11 +349,13 @@ async def _job_wrap(message_id, *args, **kwargs):
for message_id, result in zip(jobs, results):
if isinstance(result, Exception):
logger.exception(result)
await self.callback(jobs[message_id], result)
continue

await self.redis.xack(self.stream_name, self.group_name, message_id)
if self.delete_message_after_process:
await self.redis.xdel(self.stream_name, message_id)
await self.callback(jobs[message_id], result)

async def _acquire_retry_lock(self) -> bool:
return await self.redis.get(
Expand Down
23 changes: 21 additions & 2 deletions brq/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ def __init__(
func: Callable | Awaitable,
config: BrqConfig,
register_function_name: str | None = None,
callback_func: Callable | Awaitable | None = None,
):
self._func = func
self._callback_func = callback_func
self.config = config
self.register_function_name = register_function_name or func.__name__

Expand All @@ -34,6 +36,14 @@ async def _serve(self):
self._func,
limiter=CapacityLimiter(total_tokens=self.config.daemon_concurrency),
)
callback_func = (
ensure_awaitable(
self._callback_func,
limiter=CapacityLimiter(total_tokens=self.config.daemon_concurrency),
)
if self._callback_func
else None
)
consumer_builder = partial(
Consumer,
redis_prefix=self.config.redis_key_prefix,
Expand All @@ -55,6 +65,7 @@ async def _serve(self):
max_message_len=self.config.consumer_max_message_len,
delete_message_after_process=self.config.consumer_delete_message_after_process,
run_parallel=self.config.consumer_run_parallel,
awaitable_function_callback=callback_func,
)
daemon = Daemon(*[consumer_builder() for _ in range(self.config.daemon_concurrency)])
await daemon.run_forever()
Expand Down Expand Up @@ -91,6 +102,7 @@ def task(
max_message_len: int | None = None,
delete_message_after_process: bool | None = None,
run_parallel: bool | None = None,
callback_func: Callable | Awaitable | None = None,
# Daemon
daemon_concurrency: int | None = None,
):
Expand All @@ -99,9 +111,10 @@ def _wrapper(
func,
config: BrqConfig | None = None,
register_function_name: str | None = None,
callback_func: Callable | Awaitable | None = None,
):

return BrqTaskWrapper(func, config, register_function_name)
return BrqTaskWrapper(func, config, register_function_name, callback_func)

kwargs = {
k: v
Expand Down Expand Up @@ -153,6 +166,12 @@ def _wrapper(
_wrapper,
config=config,
register_function_name=register_function_name,
callback_func=callback_func,
)
else:
return _wrapper(_func, config=config, register_function_name=register_function_name)
return _wrapper(
_func,
config=config,
register_function_name=register_function_name,
callback_func=callback_func,
)
7 changes: 6 additions & 1 deletion examples/decorator/consumer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from brq import task


@task
def callback(job, exception_or_result, consumer):
    """Example callback: print the job together with its result or exception."""
    outcome = exception_or_result
    print(f"Callback for {job} with {outcome}")


@task(callback_func=callback)
def echo(message):
    """Example brq task: print the received message and return a result.

    The returned string is handed to ``callback`` as ``exception_or_result``
    after the job completes.
    """
    print(f"Received message: {message}")
    return "processed"


if __name__ == "__main__":
Expand Down
41 changes: 41 additions & 0 deletions tests/test_brq.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ async def mock_always_consume_raise_exception(echo_str: str):
raise Exception("raise exception")


callback_called = False


async def mock_callback(job, exception_or_result, consumer):
global callback_called
if callback_called:
raise RuntimeError("Callback already called. Please reset it.")
callback_called = True


@pytest.mark.parametrize("run_parallel", [False, True])
async def test_consume_function(async_redis_client, capfd, run_parallel):
producer = Producer(async_redis_client)
Expand All @@ -53,6 +63,37 @@ async def test_consume_function(async_redis_client, capfd, run_parallel):
await consumer.cleanup()


@pytest.mark.parametrize("run_parallel", [False, True])
async def test_consume_with_callback_function(async_redis_client, capfd, run_parallel):
    """End-to-end check that the consumer invokes the result callback.

    Runs one ``mock_consume`` job through both the sequential and the
    parallel consume path and asserts that ``mock_callback`` fired.
    """
    global callback_called
    # Reset the module-level flag so state cannot leak between runs.
    callback_called = False
    assert not callback_called

    producer = Producer(async_redis_client)
    consumer = Consumer(
        async_redis_client,
        mock_consume,
        run_parallel=run_parallel,
        awaitable_function_callback=mock_callback,
    )
    browser = Browser(async_redis_client)
    await browser.status()
    await producer.run_job("mock_consume", ["hello"])
    jobs = [job async for job in producer.walk_jobs("mock_consume")]
    assert len(jobs) == 1
    await browser.status()
    await consumer.initialize()
    await browser.status()
    await consumer.run()
    await browser.status()

    out, err = capfd.readouterr()

    # mock_consume printed the payload and the callback flipped the flag.
    assert "hello" in out
    assert callback_called
    await consumer.cleanup()


async def test_count_jobs(async_redis_client, redis_version):
producer = Producer(async_redis_client)
consumer = Consumer(async_redis_client, delay_job)
Expand Down

0 comments on commit 3743b34

Please sign in to comment.