Skip to content

Commit

Permalink
api changes sync async (#286)
Browse files Browse the repository at this point in the history
* api: fix minor mismatches

* add infer.py
  • Loading branch information
michaelfeil authored Jun 23, 2024
1 parent dd85082 commit f2b3072
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 24 deletions.
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ async def classify(
return await self[model].classify(sentences=sentences, raw_scores=raw_scores)

async def image_embed(
self, model: str, images: list[str]
self, *, model: str, images: list[str]
) -> tuple[list[EmbeddingReturnType], int]:
"""embed multiple images
Expand Down
25 changes: 17 additions & 8 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def async_run(

@add_start_docstrings(AsyncEngineArray.__doc__)
class SyncEngineArray(AsyncLifeMixin):
def __init__(self, engine_args: list[EngineArgs]):
def __init__(self, _engine_args_array: list[EngineArgs]):
super().__init__()
self.async_engine_array = AsyncEngineArray.from_args(engine_args)
self.async_engine_array = AsyncEngineArray.from_args(_engine_args_array)
self.async_run(self.async_engine_array.astart).result()

@classmethod
def from_args(cls, engine_args: list[EngineArgs]) -> "SyncEngineArray":
return cls(engine_args)
def from_args(cls, engine_args_array: list[EngineArgs]) -> "SyncEngineArray":
return cls(_engine_args_array=engine_args_array)

@property
def is_running(self):
Expand All @@ -128,17 +128,26 @@ def embed(self, *, model: str, sentences: list[str]):
)

@add_start_docstrings(AsyncEngineArray.rerank.__doc__)
def rerank(self, *, model: str, query: str, docs: list[str]):
def rerank(
self, *, model: str, query: str, docs: list[str], raw_scores: bool = False
):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.rerank, model=model, query=query, docs=docs
self.async_engine_array.rerank,
model=model,
query=query,
docs=docs,
raw_scores=raw_scores,
)

@add_start_docstrings(AsyncEngineArray.classify.__doc__)
def classify(self, *, model: str, sentences: str):
def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.classify, model=model, sentences=sentences
self.async_engine_array.classify,
model=model,
sentences=sentences,
raw_scores=raw_scores,
)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "infinity_emb"
version = "0.0.48"
version = "0.0.49"
description = "Infinity is a high-throughput, low-latency REST API for serving vector embeddings, supporting a wide range of sentence-transformer models and frameworks."
authors = ["michaelfeil <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions libs/infinity_emb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
pytest.DEFAULT_RERANKER_MODEL = "mixedbread-ai/mxbai-rerank-xsmall-v1"
pytest.DEFAULT_CLASSIFIER_MODEL = "SamLowe/roberta-base-go_emotions"

pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank"]


@pytest.fixture
def anyio_backend():
Expand Down
14 changes: 13 additions & 1 deletion libs/infinity_emb/tests/unit_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import inspect
import sys

import numpy as np
import pytest
import torch
from sentence_transformers import CrossEncoder # type: ignore[import-untyped]

from infinity_emb import AsyncEmbeddingEngine, EngineArgs
from infinity_emb import AsyncEmbeddingEngine, AsyncEngineArray, EngineArgs
from infinity_emb.primitives import (
Device,
EmbeddingDtype,
Expand Down Expand Up @@ -251,3 +252,14 @@ async def test_async_api_failing_revision():
revision="a32952c6d05d45f64f9f709a092c00839bcfe70a",
)
)


@pytest.mark.parametrize("method_name", list(pytest.ENGINE_METHODS)) # type: ignore
def test_args_between_array_and_engine_same(method_name: str):
array_method = inspect.getfullargspec(getattr(AsyncEngineArray, method_name))
engine_method = inspect.getfullargspec(getattr(AsyncEmbeddingEngine, method_name))

assert "model" in array_method.kwonlyargs
assert sorted(array_method.args + array_method.kwonlyargs) == sorted(
engine_method.args + engine_method.kwonlyargs + ["model"]
)
14 changes: 13 additions & 1 deletion libs/infinity_emb/tests/unit_test/test_sync_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import inspect
from uuid import uuid4

import pytest

from infinity_emb import EngineArgs, SyncEngineArray
from infinity_emb import AsyncEngineArray, EngineArgs, SyncEngineArray


def test_sync_engine():
Expand Down Expand Up @@ -75,5 +76,16 @@ def test_sync_engine_on_model(model_id, method: str, payload: dict):
s_eng_array.stop()


@pytest.mark.parametrize("method_name", list(pytest.ENGINE_METHODS) + ["from_args"]) # type: ignore
def test_args_between_sync_and_async_same(method_name: str):
sync_method = inspect.getfullargspec(getattr(SyncEngineArray, method_name))
async_method = inspect.getfullargspec(getattr(AsyncEngineArray, method_name))
if method_name in list(pytest.ENGINE_METHODS): # type: ignore
assert "model" in sync_method.kwonlyargs
assert "model" in async_method.kwonlyargs
assert sync_method.args == async_method.args
assert sync_method.kwonlyargs == async_method.kwonlyargs


if __name__ == "__main__":
test_sync_engine()
22 changes: 12 additions & 10 deletions libs/simpleinference/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/simpleinference/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ priority = "explicit"

[tool.poetry.dependencies]
python = ">=3.9,<4"
infinity_emb = {version = "0.0.47", extras = ["optimum","vision"]}
infinity_emb = {path = "../infinity_emb", extras = ["optimum","vision"]}

[tool.poetry.group.test.dependencies]
pytest = "^7.0.0"
Expand Down
4 changes: 3 additions & 1 deletion libs/simpleinference/simpleinference/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(
)
for m, e, d, edt in zip(model_id, engine, device, embedding_dtype)
]
self._engine_array = SyncEngineArray.from_args(engine_args=self._engine_args)
self._engine_array = SyncEngineArray.from_args(
engine_args_array=self._engine_args
)

def stop(self):
self._engine_array.stop()
Expand Down

0 comments on commit f2b3072

Please sign in to comment.