diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py
index b9870403..50fb5b92 100644
--- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py
+++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -38,6 +38,14 @@ def is_set(self) -> bool:
         return self._shutdown.is_set()
 
 
+class ThreadPoolExecutorReadOnly:
+    def __init__(self, tp: ThreadPoolExecutor) -> None:
+        self._tp = tp
+
+    def submit(self, *args, **kwargs):
+        return self._tp.submit(*args, **kwargs)
+
+
 class BatchHandler:
     def __init__(
         self,
@@ -53,16 +61,16 @@ def __init__(
         performs batching around the model.
 
         Args:
-            model: BaseTransformer, implements fn (core|pre|post)_encode
-            max_batch_size: max batch size of the models
-            max_queue_wait: max items to queue in the batch, default 32_000 sentences
-            batch_delay: sleep in seconds, wait time for pre/post methods.
+            model (BaseTransformer): model to be batched
+            max_batch_size (int): max batch size
+            max_queue_wait (int, optional): max items to queue in the batch, default 32_000
+            batch_delay (float, optional): sleep in seconds, wait time for pre/post methods.
                 Best result: setting to 1/2 the minimal expected
                 time for core_encode method / "gpu inference".
                 Dont set it above 1x minimal expected time of interence.
                 Should not be 0 to not block Python's GIL.
-            vector_disk_cache_path: path to cache vectors on disk.
-            lengths_via_tokenize: if True, use the tokenizer to get the lengths else len()
+            vector_disk_cache_path (str, optional): path to cache vectors on disk.
+            lengths_via_tokenize (bool, optional): if True, use the tokenizer to get the lengths else len()
         """
 
         self._max_queue_wait = max_queue_wait
@@ -87,7 +95,7 @@ def __init__(
             max_batch_size=max_batch_size,
             shutdown=ShutdownReadOnly(self._shutdown),
             model=model,
-            threadpool=self._threadpool,
+            threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
             input_q=self._queue_prio,
             output_q=self._result_queue,
             verbose=verbose,
@@ -251,11 +259,13 @@ async def _get_prios_usage(
                 get_lengths_with_tokenize,
                 self._threadpool,
                 _sentences=[it.str_repr() for it in items],
-                tokenize=self.model_worker._model.tokenize_lengths,
+                tokenize=self.model_worker.tokenize_lengths,
             )
 
     @staticmethod
-    async def _collect_from_model(shutdown: ShutdownReadOnly, result_queue: Queue):
+    async def _collect_from_model(
+        shutdown: ShutdownReadOnly, result_queue: Queue, tp: ThreadPoolExecutor
+    ):
         try:
             while not shutdown.is_set():
                 try:
@@ -263,9 +273,7 @@ def _collect_from_model(shutdown: ShutdownReadOnly, result_queue: Queue):
                 except queue.Empty:
                     # instead use async await to get
                     try:
-                        post_batch = await asyncio.to_thread(
-                            result_queue.get, timeout=1
-                        )
+                        post_batch = await to_thread(result_queue.get, tp, timeout=1)
                     except queue.Empty:
                         # in case of timeout start again
                         continue
@@ -283,9 +291,9 @@ async def spawn(self):
         logger.info("creating batching engine")
         self.loop = asyncio.get_event_loop()
-        asyncio.create_task(
+        self._collect_task = asyncio.create_task(
             self._collect_from_model(
-                ShutdownReadOnly(self._shutdown), self._result_queue
+                ShutdownReadOnly(self._shutdown), self._result_queue, self._threadpool
             )
         )
         self.model_worker.spawn()
@@ -296,8 +304,9 @@ async def shutdown(self):
         """
         Blocking event, until shutdown complete.
         """
         self._shutdown.set()
-        with ThreadPoolExecutor() as tp_temp:
-            await to_thread(self._threadpool.shutdown, tp_temp)
+        await asyncio.to_thread(self._threadpool.shutdown)
+        # collect task
+        self._collect_task.cancel()
 
 
 class ModelWorker:
@@ -306,7 +315,7 @@ def __init__(
         self,
         max_batch_size: int,
         shutdown: ShutdownReadOnly,
         model: BaseTransformer,
-        threadpool: ThreadPoolExecutor,
+        threadpool: ThreadPoolExecutorReadOnly,
         input_q: CustomFIFOQueue,
         output_q: Queue,
         batch_delay: float = 5e-3,
@@ -337,6 +346,9 @@ def spawn(self):
     def capabilities(self) -> Set[ModelCapabilites]:
         return self._model.capabilities
 
+    def tokenize_lengths(self, *args, **kwargs):
+        return self._model.tokenize_lengths(*args, **kwargs)
+
     def _preprocess_batch(self):
         """loops and checks if the _core_batch has worked on all items"""
         logger.info("ready to batch requests.")