
Commit

initial commit
michaelfeil committed Dec 26, 2023
1 parent 58f01f1 commit cf223b7
Showing 7 changed files with 268 additions and 103 deletions.
178 changes: 178 additions & 0 deletions experimental/caching/multiprocessing.py
@@ -0,0 +1,178 @@
from typing import Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import torch.multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
import torch
from queue import Empty
import time
from enum import Enum

class SignalMessages(Enum):
    POISON_KILL = 1
    WAKE_ON_NOTIFY = 2

class SpecialQueue:
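    """Thin wrapper around a torch.multiprocessing Queue that understands the
    control signals above: a POISON_KILL is re-queued (so every consumer sees it)
    and raised as an error, while WAKE_ON_NOTIFY is swallowed and returned as
    None so blocked consumers can flush partial work."""
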
    def __init__(self, *args, **kwargs):
        self.queue = mp.Queue(*args, **kwargs)

    def get(self, block: bool = True, timeout: float | None = None):
        item = self.queue.get(block, timeout)
        if isinstance(item, SignalMessages):
            if item == SignalMessages.POISON_KILL:
                self.queue.put(SignalMessages.POISON_KILL)
                raise ValueError("Poison Kill")
            elif item == SignalMessages.WAKE_ON_NOTIFY:
                return None
        else:
            return item

    def put(self, item, block: bool = True, timeout: float | None = None):
        self.queue.put(item, block, timeout)

    def close(self):
        self.queue.close()

    def notify(self):
        self.queue.put(SignalMessages.WAKE_ON_NOTIFY)
# SpecialQueue = Queue

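# Regroups incoming lists of items into batches of `batch_size`. Up to `sort_n`
# batches are pulled from the waiting list at once and sorted before being
# re-split, presumably so that similar inputs land in the same batch; a None
# from the queue (WAKE_ON_NOTIFY) flushes a partial batch instead of waiting
# for it to fill up.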
def queuebatcher(queue_in: SpecialQueue, queue_out: SpecialQueue, batch_size: int, sort_n: int = 4):
    waiting_list = []
    while True:
        new_items = queue_in.get()

        if new_items is not None and len(new_items) > 0:
            waiting_list.extend(new_items)

        if new_items is None and waiting_list:
            # flush without waiting for a full batch
            queue_out.put(waiting_list[:batch_size])
            waiting_list = waiting_list[batch_size:]
        elif len(waiting_list) >= batch_size:
            # pop up to sort_n batches
            max_pop = min(sort_n, len(waiting_list) // batch_size)
            to_add, waiting_list = waiting_list[:batch_size * max_pop], waiting_list[batch_size * max_pop:]
            # sort and append
            to_add = sorted(to_add)
            for bs in range(0, len(to_add), batch_size):
                queue_out.put(to_add[bs:bs + batch_size])


class BoringPipeline(object):
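    """Minimal worker base class: subclasses implement post_init() (heavy setup
    done inside the child process) and working_function(); loop_forever() then
    pulls items from queue_in, applies working_function, and pushes the result
    to queue_out until a shutdown signal arrives."""
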
    def __init__(self):
        pass

    def working_function(self, item):
        raise NotImplementedError

    def post_init(self, **kwargs):
        raise NotImplementedError

    def post_init_and_loop(self, queue_in: SpecialQueue, queue_out: SpecialQueue, **kwargs):
        self.queue_in = queue_in
        self.queue_out = queue_out
        self.post_init(**kwargs)
        self.loop_forever()

    def loop_forever(self):
        try:
            while True:
                item = self.queue_in.get()
                processed = self.working_function(item)
                self.queue_out.put(processed)
        except KeyboardInterrupt:
            pass

class TokenizePipeline(BoringPipeline):
    def post_init(self, device: str):
        self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
        self.device = device

    def working_function(self, item):
        assert isinstance(item, list) and all(isinstance(i, str) for i in item)
        try:
            with torch.inference_mode():
                return self.tokenizer(item, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        except Exception as ex:
            print(ex)
            return None

class ModelPipeline(BoringPipeline):
    def post_init(self, model_device: str):
        self.model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5").to(model_device)
        self.model.eval()
        self.model.half()

    def working_function(self, item):
        with torch.inference_mode():
            return self.model(**item).last_hidden_state.shape

def main():
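    # Wiring: raw strings -> queuebatcher (batches of 64) -> TokenizePipeline
    # (tokenizes and moves tensors to CUDA) -> ModelPipeline (forward pass).
    # Each stage runs in its own spawned process and communicates over SpecialQueues.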
    mp.set_start_method('spawn')
    queues = [SpecialQueue(), SpecialQueue(), SpecialQueue(), SpecialQueue()]

    # fill with some data
    items = [f"{i}" for i in range(5000)]
    # go

    processes = []
    processes.append(mp.Process(target=queuebatcher, args=(queues[0], queues[1], 64)))
    processes[-1].start()

    processes.append(mp.Process(target=TokenizePipeline().post_init_and_loop, kwargs=dict(
        queue_in=queues[1], queue_out=queues[2], device="cuda")))
    processes[-1].start()

    processes.append(mp.Process(target=ModelPipeline().post_init_and_loop, kwargs=dict(
        queue_in=queues[2], queue_out=queues[3], model_device="cuda")))
    processes[-1].start()
    queues[0].put(items[1:33])
    time.sleep(5)


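    # Feed the items in uneven chunks of 17, then drain results until the output
    # queue stays empty for 0.5 s; at that point a WAKE_ON_NOTIFY is sent so the
    # batcher flushes any partially filled batch, and the loop exits.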
    s = time.perf_counter()
    for bs in range(0, len(items), 17):
        queues[0].put(items[bs:bs + 17])
    time.sleep(2)
    try:
        i = 0
        while i < 1:
            try:
                item = queues[-1].get(timeout=0.5)
            except Empty:
                queues[0].put(SignalMessages.WAKE_ON_NOTIFY)
                i += 1
                continue
            print(item)
    finally:
        print(time.perf_counter() - s, "seconds")
        print("Shutting down")
        for i in range(5):
            for q in queues:
                q.put(SignalMessages.POISON_KILL)
        time.sleep(3)
        print("closing queues")
        for queue in queues:
            queue.close()
        # print("joining processes")
        # for p in processes:
        #     p.join()
        queues = None
        # time.sleep(3)
        print("Done")

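    # Baseline: encode the same items in-process with sentence-transformers for a
    # rough timing comparison against the multiprocessing pipeline.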
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("BAAI/bge-small-en-v1.5")
    start = time.perf_counter()
    model.encode(items, batch_size=64, show_progress_bar=True)
    print(time.perf_counter() - start, "sentence transformers")




if __name__ == "__main__":
    main()
13 changes: 6 additions & 7 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -10,7 +10,7 @@
import numpy as np

from infinity_emb.inference.caching_layer import Cache
from infinity_emb.inference.queue import CustomFIFOQueue, ResultKVStoreFuture
from infinity_emb.inference.queue import CustomFIFOQueue, ResultKVStoreFuture, QueueSignal
from infinity_emb.inference.threading_asyncio import to_thread
from infinity_emb.log_handler import logger
from infinity_emb.primitives import (
@@ -30,7 +30,6 @@
from infinity_emb.transformer.abstract import BaseTransformer
from infinity_emb.transformer.utils import get_lengths_with_tokenize


class BatchHandler:
def __init__(
self,
@@ -185,7 +184,7 @@ async def _schedule(
item = PrioritizedQueueItem(
priority=p,
item=inner_item(
content=re, future=self.loop.create_future() # type: ignore
content=re # type: ignore
),
)
new_prioqueue.append(item)
Expand Down Expand Up @@ -339,7 +338,7 @@ async def _postprocess_batch(self):
except queue.Empty:
# in case of timeout start again
continue

if (
self._postprocess_queue.empty()
and self._last_inference
@@ -354,7 +353,8 @@ async def _postprocess_batch(self):
embed, batch = post_batch
results = self.model.encode_post(embed)
for i, item in enumerate(batch):
await item.complete(results[i])
item.set_result(results[i])
await self._result_store.mark_item_ready(item)

self._postprocess_queue.task_done()
except Exception as ex:
@@ -383,7 +383,6 @@ async def spawn(self):
if self._ready:
raise ValueError("previous threads are still running.")
logger.info("creating batching engine")
self.loop = asyncio.get_event_loop()
self._threadpool.submit(self._preprocess_batch)
self._threadpool.submit(self._core_batch)
asyncio.create_task(self._postprocess_batch())
@@ -392,7 +391,7 @@ async def shutdown(self):
async def shutdown(self):
"""
set the shutdown event and close threadpool.
Blocking event, until shutdown complete.
Blocking event, until shutdown.
"""
self._shutdown.set()
with ThreadPoolExecutor() as tp_temp:
23 changes: 14 additions & 9 deletions libs/infinity_emb/infinity_emb/inference/caching_layer.py
@@ -60,19 +60,24 @@ def _consume_queue(self) -> None:
self._threadpool.shutdown(wait=True)

def _get(self, sentence: str) -> Union[None, EmbeddingReturnType, List[float]]:
"""sets the item.complete() and sets embedding, if in cache."""
return self._cache.get(key=self._hash(sentence))

async def aget_complete(self, item: QueueItemInner) -> None:
"""sets the item.complete() and sets embedding, if in cache."""
async def aget(self, item: QueueItemInner, future: asyncio.Future) -> None:
"""Sets result to item and future, if in cache.
If not in cache, sets future to be done when result is set.
"""
item_as_str = item.content.str_repr()
result = await to_thread(self._get, self._threadpool, item_as_str)
if result is not None:
# update item with cached result
if not item.future.done():
await item.complete(result)
if item.get_result() is None:
item.set_result(result)
try:
future.set_result(None)
except asyncio.InvalidStateError:
pass
else:
# result is not in cache yet, lets wait for it and add it
result_new = await item.get_result()
await asyncio.sleep(1e-3)
self._add_q.put((item_as_str, result_new))
await future
result = item.get_result()
self._add_q.put((item_as_str, result))

35 changes: 30 additions & 5 deletions libs/infinity_emb/infinity_emb/inference/queue.py
@@ -1,14 +1,16 @@
import asyncio
import threading
from typing import Dict, List, Optional, Union
import enum

from infinity_emb.inference.caching_layer import Cache
from infinity_emb.primitives import (
EmbeddingReturnType,
PrioritizedQueueItem,
QueueItemInner,
)

class QueueSignal(enum.Enum):
KILL = "kill"

class CustomFIFOQueue:
def __init__(self) -> None:
@@ -66,7 +68,7 @@ def pop_optimal_batches(
for i in range(n_batches):
mini_batch = new_items_l[size * i : size * (i + 1)]
mini_batch_e: List[QueueItemInner] = [
mi.item for mi in mini_batch if not mi.item.future.done()
mi.item for mi in mini_batch
]
if mini_batch_e:
new_items.append(mini_batch_e)
@@ -78,14 +80,37 @@

class ResultKVStoreFuture:
def __init__(self, cache: Optional[Cache] = None) -> None:
self._kv: Dict[str, EmbeddingReturnType] = {}
self._kv: Dict[str, asyncio.Future] = {}
self._cache = cache
self._loop = None

async def _loop_running(self):
if self._loop is None:
self._loop = asyncio.get_running_loop()

def __len__(self):
return len(self._kv)

async def wait_for_response(self, item: QueueItemInner) -> EmbeddingReturnType:
"""wait for future to return"""
await self._loop_running()
uuid = item.get_id()
fut = self._loop.create_future()
self._kv[uuid] = fut
if self._cache:
asyncio.create_task(self._cache.aget_complete(item))
return await item.future
asyncio.create_task(self._cache.aget(item))
await fut
return item.get_result()


async def mark_item_ready(self, item: QueueItemInner) -> None:
"""mark item as ready. Item.get_result() must be set before calling this"""
await self._loop_running()
uuid = item.get_id()
fut = self._kv[uuid]
try:
fut.set_result(None)
except asyncio.InvalidStateError:
pass
del self._kv[uuid]


