Commit 58f01f1

Merge pull request #53 from michaelfeil/fix-delayed-warmup
fixing delayed warmup and expose capabilities
michaelfeil authored Dec 22, 2023
2 parents abca605 + 8aced6c commit 58f01f1
Showing 7 changed files with 35 additions and 17 deletions.
10 changes: 7 additions & 3 deletions libs/infinity_emb/infinity_emb/engine.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Set, Tuple, Union
 
 # prometheus
 from infinity_emb.inference import (
@@ -7,7 +7,7 @@
     select_model,
 )
 from infinity_emb.log_handler import logger
-from infinity_emb.primitives import EmbeddingReturnType
+from infinity_emb.primitives import EmbeddingReturnType, ModelCapabilites
 from infinity_emb.transformer.utils import InferenceEngine
 
 
@@ -40,7 +40,7 @@ def __init__(
     ```python
     from infinity_emb import AsyncEmbeddingEngine, transformer
     sentences = ["Embedded this via Infinity.", "Paris is in France."]
-    engine = AsyncEmbeddingEngine(engine=transformer.InferenceEngine.torch)
+    engine = AsyncEmbeddingEngine(engine="torch")
     async with engine: # engine starts with engine.astart()
         embeddings = np.array(await engine.embed(sentences))
     # engine stops with engine.astop().
@@ -104,6 +104,10 @@ def is_overloaded(self) -> bool:
         self._check_running()
         return self._batch_handler.is_overloaded()
 
+    @property
+    def capabilities(self) -> Set[ModelCapabilites]:
+        return self._model.capabilities
+
     async def embed(
         self, sentences: List[str]
     ) -> Tuple[List[EmbeddingReturnType], int]:
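
The new `capabilities` property lets callers check what a loaded model supports before dispatching requests. A minimal usage sketch, assuming the default model and the constructor call shown in the docstring above (`capabilities` returns a set such as {"embed"}, {"rerank"}, or {"classify"}, per the tests below):

    import asyncio
    import numpy as np
    from infinity_emb import AsyncEmbeddingEngine

    async def main():
        engine = AsyncEmbeddingEngine(engine="torch")
        async with engine:
            # guard each call with the matching capability
            if "embed" in engine.capabilities:
                embeddings, usage = await engine.embed(
                    ["Embedded this via Infinity.", "Paris is in France."]
                )
                print(np.array(embeddings).shape, usage)

    asyncio.run(main())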
25 changes: 16 additions & 9 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -196,7 +196,8 @@ async def _schedule(
         )
         return result, usage
 
-    def get_capabilities(self) -> Set[ModelCapabilites]:
+    @property
+    def capabilities(self) -> Set[ModelCapabilites]:
         return self.model.capabilities
 
     def is_overloaded(self) -> bool:
@@ -362,14 +363,20 @@ async def _postprocess_batch(self):
 
     async def _delayed_warmup(self):
         """in case there is no warmup -> perform some warmup."""
-        await asyncio.sleep(10)
-        logger.debug("Sending a warm up through embedding.")
-        if "embed" in self.model.capabilities:
-            await self.embed(sentences=["test"] * self.max_batch_size)
-        if "rerank" in self.model.capabilities:
-            await self.rerank(query="query", docs=["test"] * self.max_batch_size)
-        if "classify" in self.model.capabilities:
-            await self.classify(sentences=["test"] * self.max_batch_size)
+        await asyncio.sleep(5)
+        if not self._shutdown.is_set():
+            logger.debug("Sending a warm up through embedding.")
+            try:
+                if "embed" in self.model.capabilities:
+                    await self.embed(sentences=["test"] * self.max_batch_size)
+                if "rerank" in self.model.capabilities:
+                    await self.rerank(
+                        query="query", docs=["test"] * self.max_batch_size
+                    )
+                if "classify" in self.model.capabilities:
+                    await self.classify(sentences=["test"] * self.max_batch_size)
+            except Exception:
+                pass
 
     async def spawn(self):
         """set up the resources in batch"""
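
The warmup fix follows a common pattern for best-effort background tasks: delay, check for shutdown, and swallow failures so a warmup error can never take down the batch handler. A standalone sketch of that pattern with a hypothetical class (the real handler's `_shutdown` flag is assumed to behave like a threading.Event):

    import asyncio
    import threading

    class Handler:
        def __init__(self) -> None:
            self._shutdown = threading.Event()

        async def _delayed_warmup(self) -> None:
            await asyncio.sleep(5)  # give startup a head start
            if not self._shutdown.is_set():  # skip if already stopping
                try:
                    await self._warmup_batch()
                except Exception:
                    pass  # warmup is best-effort; never propagate

        async def _warmup_batch(self) -> None:
            print("sending warmup batch")

    asyncio.run(Handler()._delayed_warmup())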
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/inference/caching_layer.py
@@ -57,6 +57,7 @@ def _consume_queue(self) -> None:
             k, v = item
             self._cache.add(key=self._hash(k), value=v, expire=86400)
             self._add_q.task_done()
+        self._threadpool.shutdown(wait=True)
 
     def _get(self, sentence: str) -> Union[None, EmbeddingReturnType, List[float]]:
         """sets the item.complete() and sets embedding, if in cache."""
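
The one-line addition makes the caching layer release its `ThreadPoolExecutor` once the queue loop exits. For reference, `shutdown(wait=True)` is the standard-library call that blocks until pending work finishes and then frees the worker threads:

    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=1)
    pool.submit(print, "flushing last item")
    pool.shutdown(wait=True)  # block until queued work completes, then release threads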
8 changes: 4 additions & 4 deletions libs/infinity_emb/infinity_emb/transformer/abstract.py
@@ -29,7 +29,7 @@ def tokenize_lengths(self, sentences: List[str]) -> List[int]:
         """gets the lengths of each sentences according to tokenize/len etc."""
 
     @abstractmethod
-    def warmup(self, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
+    def warmup(self, *, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
         """warmup the model
         returns embeddings per second, inference time, and a log message"""
@@ -46,7 +46,7 @@ def encode_pre(self, sentences: List[str]) -> INPUT_FEATURE:
     def encode_post(self, embedding: OUT_FEATURES) -> EmbeddingReturnType:
         """runs post encoding such as normlization"""
 
-    def warmup(self, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
+    def warmup(self, *, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
         sample = ["warm" * n_tokens] * batch_size
         inp = [
             EmbeddingInner(content=EmbeddingSingle(s), future=None)  # type: ignore
@@ -66,7 +66,7 @@ def encode_pre(self, sentences: List[str]) -> INPUT_FEATURE:
     def encode_post(self, embedding: OUT_FEATURES) -> Dict[str, float]:
         """runs post encoding such as normlization"""
 
-    def warmup(self, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
+    def warmup(self, *, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
         sample = ["warm" * n_tokens] * batch_size
         inp = [
             PredictInner(content=PredictSingle(s), future=None)  # type: ignore
@@ -86,7 +86,7 @@ def encode_pre(self, queries_docs: List[Tuple[str, str]]) -> INPUT_FEATURE:
     def encode_post(self, embedding: OUT_FEATURES) -> List[float]:
         """runs post encoding such as normlization"""
 
-    def warmup(self, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
+    def warmup(self, *, batch_size: int = 64, n_tokens=1) -> Tuple[float, float, str]:
         sample = ["warm" * n_tokens] * batch_size
         inp = [
             ReRankInner(
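
Each `warmup` signature gains a bare `*`, which makes `batch_size` and `n_tokens` keyword-only, so callers can no longer pass them positionally and silently swap them. A minimal illustration of the mechanics, independent of this codebase:

    def warmup(*, batch_size: int = 64, n_tokens: int = 1) -> str:
        return f"batch_size={batch_size}, n_tokens={n_tokens}"

    print(warmup(batch_size=32, n_tokens=2))  # OK: keyword arguments
    # warmup(32, 2)  # TypeError: takes 0 positional arguments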
2 changes: 1 addition & 1 deletion libs/infinity_emb/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "infinity_emb"
-version = "0.0.15"
+version = "0.0.16"
 description = "Infinity is a high-throughput, low-latency REST API for serving vector embeddings, supporting a wide range of sentence-transformer models and frameworks."
 authors = ["michaelfeil <[email protected]>"]
 license = "MIT"
1 change: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ async def load_patched_bh() -> Tuple[SentenceTransformerPatched, BatchHandler]:
 @pytest.mark.anyio
 async def test_batch_performance_raw(get_sts_bechmark_dataset, load_patched_bh):
     model, bh = load_patched_bh
+    assert bh.capabilities == {"embed"}
     try:
         sentences = []
         for d in get_sts_bechmark_dataset:
5 changes: 5 additions & 0 deletions libs/infinity_emb/tests/unit_test/test_engine.py
@@ -28,6 +28,7 @@ async def test_async_api_torch():
         engine=transformer.InferenceEngine.torch,
         device="auto",
     )
+    assert engine.capabilities == {"embed"}
     async with engine:
         embeddings, usage = await engine.embed(sentences)
         assert isinstance(embeddings, list)
@@ -58,6 +59,9 @@ async def test_async_api_torch_CROSSENCODER():
         device="auto",
         model_warmup=True,
     )
+
+    assert engine.capabilities == {"rerank"}
+
     async with engine:
         rankings, usage = await engine.rerank(query=query, docs=documents)
 
@@ -100,6 +104,7 @@ async def test_async_api_torch_CLASSIFY():
         engine="torch",
         model_warmup=True,
     )
+    assert engine.capabilities == {"classify"}
 
     async with engine:
         predictions, usage = await engine.classify(sentences=sentences)
