Commit

add a batch_handler ReadOnly notion
michaelfeil committed Apr 13, 2024
1 parent d77efd7 commit 7837c85
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -38,6 +38,14 @@ def is_set(self) -> bool:
         return self._shutdown.is_set()
 
 
+class ThreadPoolExecutorReadOnly:
+    def __init__(self, tp: ThreadPoolExecutor) -> None:
+        self._tp = tp
+
+    def submit(self, *args, **kwargs):
+        return self._tp.submit(*args, **kwargs)
+
+
 class BatchHandler:
     def __init__(
         self,
@@ -53,16 +61,16 @@ def __init__(
         performs batching around the model.
         Args:
-            model: BaseTransformer, implements fn (core|pre|post)_encode
-            max_batch_size: max batch size of the models
-            max_queue_wait: max items to queue in the batch, default 32_000 sentences
-            batch_delay: sleep in seconds, wait time for pre/post methods.
+            model (BaseTransformer): model to be batched
+            max_batch_size (int): max batch size
+            max_queue_wait (int, optional): max items to queue in the batch, default 32_000
+            batch_delay (float, optional): sleep in seconds, wait time for pre/post methods.
                 Best result: setting to 1/2 the minimal expected
                 time for core_encode method / "gpu inference".
                 Don't set it above 1x minimal expected time of inference.
                 Should not be 0 to not block Python's GIL.
-            vector_disk_cache_path: path to cache vectors on disk.
-            lengths_via_tokenize: if True, use the tokenizer to get the lengths else len()
+            vector_disk_cache_path (str, optional): path to cache vectors on disk.
+            lengths_via_tokenize (bool, optional): if True, use the tokenizer to get the lengths else len()
         """
 
         self._max_queue_wait = max_queue_wait
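The batch_delay rule of thumb in the Args section above can be made concrete with a short worked example. The snippet below is illustrative only; the timing value is an assumption, not a measurement from the library.

# Hypothetical sizing of batch_delay following the docstring's rule of thumb:
# measure one core_encode / "gpu inference" pass, then use roughly half of it.
measured_core_encode_s = 0.010            # assumption: one forward pass takes ~10 ms
batch_delay = measured_core_encode_s / 2  # 0.005 s, matching the 5e-3 default further down
assert 0 < batch_delay < measured_core_encode_s  # never 0, never above one forward pass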
@@ -87,7 +95,7 @@ def __init__(
             max_batch_size=max_batch_size,
             shutdown=ShutdownReadOnly(self._shutdown),
             model=model,
-            threadpool=self._threadpool,
+            threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
             input_q=self._queue_prio,
             output_q=self._result_queue,
             verbose=verbose,
@@ -251,21 +259,21 @@ async def _get_prios_usage(
             get_lengths_with_tokenize,
             self._threadpool,
             _sentences=[it.str_repr() for it in items],
-            tokenize=self.model_worker._model.tokenize_lengths,
+            tokenize=self.model_worker.tokenize_lengths,
         )
 
     @staticmethod
-    async def _collect_from_model(shutdown: ShutdownReadOnly, result_queue: Queue):
+    async def _collect_from_model(
+        shutdown: ShutdownReadOnly, result_queue: Queue, tp: ThreadPoolExecutor
+    ):
         try:
             while not shutdown.is_set():
                 try:
                     post_batch = result_queue.get_nowait()
                 except queue.Empty:
                     # instead use async await to get
                     try:
-                        post_batch = await asyncio.to_thread(
-                            result_queue.get, timeout=1
-                        )
+                        post_batch = await to_thread(result_queue.get, tp, timeout=1)
                     except queue.Empty:
                         # in case of timeout start again
                         continue
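The rewritten _collect_from_model now awaits a to_thread(...) helper that takes an explicit executor, instead of asyncio.to_thread, which always runs on the event loop's default executor. The helper below is a minimal sketch of what such a function could look like; it is an assumption for illustration, not the actual infinity_emb implementation.

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


async def to_thread(func: Callable[..., Any], tp: ThreadPoolExecutor, *args, **kwargs) -> Any:
    """Run a blocking callable on an explicitly supplied thread pool (sketch)."""
    loop = asyncio.get_running_loop()
    # Bundle args and kwargs into a zero-argument callable, since
    # run_in_executor does not accept keyword arguments for the callable.
    return await loop.run_in_executor(tp, functools.partial(func, *args, **kwargs))

With this shape, await to_thread(result_queue.get, tp, timeout=1) blocks a worker thread of tp rather than a thread of the default executor.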
@@ -283,9 +291,9 @@ async def spawn(self):
         logger.info("creating batching engine")
         self.loop = asyncio.get_event_loop()
 
-        asyncio.create_task(
+        self._collect_task = asyncio.create_task(
             self._collect_from_model(
-                ShutdownReadOnly(self._shutdown), self._result_queue
+                ShutdownReadOnly(self._shutdown), self._result_queue, self._threadpool
             )
         )
         self.model_worker.spawn()
@@ -296,8 +304,9 @@ async def shutdown(self):
         Blocking event, until shutdown complete.
         """
         self._shutdown.set()
-        with ThreadPoolExecutor() as tp_temp:
-            await to_thread(self._threadpool.shutdown, tp_temp)
+        await asyncio.to_thread(self._threadpool.shutdown)
+        # collect task
+        self._collect_task.cancel()
 
 
 class ModelWorker:
@@ -306,7 +315,7 @@ def __init__(
         max_batch_size: int,
         shutdown: ShutdownReadOnly,
         model: BaseTransformer,
-        threadpool: ThreadPoolExecutor,
+        threadpool: ThreadPoolExecutorReadOnly,
         input_q: CustomFIFOQueue,
         output_q: Queue,
         batch_delay: float = 5e-3,
@@ -337,6 +346,9 @@ def spawn(self):
     def capabilities(self) -> Set[ModelCapabilites]:
         return self._model.capabilities
 
+    def tokenize_lengths(self, *args, **kwargs):
+        return self._model.tokenize_lengths(*args, **kwargs)
+
     def _preprocess_batch(self):
         """loops and checks if the _core_batch has worked on all items"""
         logger.info("ready to batch requests.")
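Taken together, the changes in this commit hand ModelWorker restricted views of the resources BatchHandler owns: a ShutdownReadOnly that can only be checked, a ThreadPoolExecutorReadOnly that can only submit work, and a tokenize_lengths delegate instead of direct access to the private _model. The sketch below shows that facade pattern in isolation; the Owner, EventReadOnly, and PoolSubmitOnly names are illustrative, not part of the library.

import threading
from concurrent.futures import ThreadPoolExecutor


class EventReadOnly:
    """Consumers may check the flag; only the owner may set it."""

    def __init__(self, event: threading.Event) -> None:
        self._event = event

    def is_set(self) -> bool:
        return self._event.is_set()


class PoolSubmitOnly:
    """Consumers may submit work; only the owner may shut the pool down."""

    def __init__(self, tp: ThreadPoolExecutor) -> None:
        self._tp = tp

    def submit(self, fn, *args, **kwargs):
        return self._tp.submit(fn, *args, **kwargs)


class Owner:
    """Owns the mutable resources and hands out restricted facades."""

    def __init__(self) -> None:
        self._shutdown = threading.Event()
        self._pool = ThreadPoolExecutor(max_workers=2)
        self.worker_shutdown = EventReadOnly(self._shutdown)
        self.worker_pool = PoolSubmitOnly(self._pool)

    def close(self) -> None:
        self._shutdown.set()   # only the owner flips the flag
        self._pool.shutdown()  # only the owner can stop the pool


if __name__ == "__main__":
    owner = Owner()
    future = owner.worker_pool.submit(sum, [1, 2, 3])
    print(future.result(), owner.worker_shutdown.is_set())  # 6 False
    owner.close()

The wrappers are plain delegation rather than hard enforcement (a determined caller can still reach the underlying objects), but they document intent and keep shutdown authority in a single place.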
