-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #280 from michaelfeil/easyinference
add easyinference
- Loading branch information
Showing
15 changed files
with
3,229 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests docker_tests help extended_tests | ||
|
||
# Default target executed when no arguments are given to make. | ||
all: help | ||
|
||
precommit : | format spell_fix spell_check lint poetry_check test | ||
|
||
###################### | ||
# TESTING AND COVERAGE | ||
###################### | ||
|
||
# Define a variable for the test file path. | ||
TEST_FILE ?= | ||
|
||
# Run unit tests and generate a coverage report. | ||
coverage: | ||
poetry run coverage run --source ./easyinference -m pytest --doctest-modules | ||
poetry run coverage report -m | ||
poetry run coverage xml | ||
|
||
test tests: | ||
poetry run pytest --doctest-modules | ||
|
||
|
||
###################### | ||
# LINTING AND FORMATTING | ||
###################### | ||
|
||
# Define a variable for Python and notebook files. | ||
PYTHON_FILES=. | ||
lint format: PYTHON_FILES=. | ||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/infinity_emb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') | ||
|
||
lint lint_diff: | ||
poetry run ruff . | ||
[ "$(PYTHON_FILES)" = "" ] || poetry run black $(PYTHON_FILES) --check | ||
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) | ||
|
||
format format_diff: | ||
[ "$(PYTHON_FILES)" = "" ] || poetry run black $(PYTHON_FILES) | ||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES) | ||
|
||
poetry_check: | ||
poetry check | ||
|
||
spell_check: | ||
poetry run codespell --toml pyproject.toml | ||
|
||
spell_fix: | ||
poetry run codespell --toml pyproject.toml -w | ||
|
||
benchmark_embed: tests/data/benchmark/benchmark_embed.json | ||
ab -n 10 -c 10 -l -s 480 \ | ||
-T 'application/json' \ | ||
-p $< \ | ||
http://127.0.0.1:7997/embeddings | ||
# sudo apt-get apache2-utils | ||
|
||
###################### | ||
# HELP | ||
###################### | ||
|
||
help: | ||
@echo '====================' | ||
@echo 'clean - run docs_clean and api_docs_clean' | ||
@echo 'docs_build - build the documentation' | ||
@echo 'docs_clean - clean the documentation build artifacts' | ||
@echo 'docs_linkcheck - run linkchecker on the documentation' | ||
@echo 'api_docs_build - build the API Reference documentation' | ||
@echo 'api_docs_clean - clean the API Reference documentation build artifacts' | ||
@echo 'api_docs_linkcheck - run linkchecker on the API Reference documentation' | ||
@echo '-- LINTING --' | ||
@echo 'format - run code formatters' | ||
@echo 'lint - run linters' | ||
@echo 'spell_check - run codespell on the project' | ||
@echo 'spell_fix - run codespell on the project and fix the errors' | ||
@echo 'poetry_check - run poetry check' | ||
@echo '-- TESTS --' | ||
@echo 'coverage - run unit tests and generate coverage report' | ||
@echo 'test - run unit tests' | ||
@echo 'tests - run unit tests (alias for "make test")' | ||
@echo 'test TEST_FILE=<test_file> - run all tests in file' | ||
@echo 'extended_tests - run only extended unit tests' | ||
@echo 'test_watch - run unit tests in watch mode' | ||
@echo 'integration_tests - run integration tests' | ||
@echo 'docker_tests - run unit tests in docker' | ||
@echo '-- DOCUMENTATION tasks are from the top-level Makefile --' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from easyinference.infer import EasyInference | ||
|
||
__all__ = ["EasyInference"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from concurrent.futures import Future | ||
from typing import Iterable, Literal, Union | ||
|
||
from infinity_emb import EngineArgs, SyncEngineArray | ||
|
||
__all__ = ["EasyInference"] | ||
|
||
Device = Literal["cpu", "cuda"] | ||
ModelID = str | ||
Engine = Literal["torch", "optimum"] | ||
EmbeddingDtype = Literal["float32", "int8", "binary"] | ||
ModelIndex = Union[int, str] | ||
|
||
|
||
class EasyInference: | ||
def __init__( | ||
self, | ||
*, | ||
model_id: Union[ModelID, Iterable[ModelID]], | ||
engine: Union[Engine, Iterable[Engine]] = "optimum", | ||
device: Union[Device, Iterable[Device]] = "cpu", | ||
embedding_dtype: Union[EmbeddingDtype, Iterable[EmbeddingDtype]] = "float32", | ||
): | ||
"""An easy interface to infer with multiple models. | ||
>>> ei = EasyInference(model_id="michaelfeil/bge-small-en-v1.5") | ||
>>> ei | ||
EasyInference(['michaelfeil/bge-small-en-v1.5']) | ||
>>> ei.stop() | ||
""" | ||
|
||
if isinstance(model_id, str): | ||
model_id = [model_id] | ||
if isinstance(engine, str): | ||
engine = [engine] | ||
if isinstance(device, str): | ||
device = [device] | ||
if isinstance(embedding_dtype, str): | ||
embedding_dtype = [embedding_dtype] | ||
self._engine_args = [ | ||
EngineArgs( | ||
model_name_or_path=m, | ||
engine=e, # type: ignore | ||
device=d, # type: ignore | ||
served_model_name=m, | ||
embedding_dtype=edt, # type: ignore | ||
lengths_via_tokenize=True, | ||
model_warmup=False, | ||
) | ||
for m, e, d, edt in zip(model_id, engine, device, embedding_dtype) | ||
] | ||
self._engine_array = SyncEngineArray.from_args(engine_args=self._engine_args) | ||
|
||
def stop(self): | ||
self._engine_array.stop() | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}({[a.model_name_or_path for a in self._engine_args]})" | ||
|
||
def embed( | ||
self, | ||
*, | ||
sentences: list[str], | ||
model_id: ModelIndex = 0, | ||
) -> Future[tuple[list[list[float]], int]]: | ||
"""Embed sentences with a model. | ||
>>> ei = EasyInference(model_id="michaelfeil/bge-small-en-v1.5") | ||
>>> embed_result = ei.embed(model_id="michaelfeil/bge-small-en-v1.5", sentences=["Hello, world!"]) | ||
>>> type(embed_result) | ||
<class 'concurrent.futures._base.Future'> | ||
>>> embed_result.result()[0][0].shape # embedding | ||
(384,) | ||
>>> embed_result.result()[1] # embedding and usage of 6 tokens | ||
6 | ||
>>> ei.stop() | ||
""" | ||
return self._engine_array.embed(model=model_id, sentences=sentences) | ||
|
||
def image_embed( | ||
self, | ||
*, | ||
images: list[str], | ||
model_id: ModelIndex = 0, | ||
) -> Future[tuple[list[list[float]], int]]: | ||
"""Embed images with a model. | ||
>>> ei = EasyInference(model_id="wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M", engine="torch") | ||
>>> image_embed_result = ei.image_embed(model_id="wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M", images=["http://images.cocodataset.org/val2017/000000039769.jpg"]) | ||
>>> type(image_embed_result) | ||
<class 'concurrent.futures._base.Future'> | ||
>>> image_embed_result.result()[0][0].shape | ||
(512,) | ||
>>> ei.stop() | ||
""" | ||
return self._engine_array.image_embed(model=model_id, images=images) | ||
|
||
def classify( | ||
self, | ||
*, | ||
sentences: list[str], | ||
model_id: ModelIndex = 0, | ||
) -> Future[tuple[list[list[dict[str, float]]], int]]: | ||
"""Classify sentences with a model. | ||
>>> ei = EasyInference(model_id="philschmid/tiny-bert-sst2-distilled", engine="torch") | ||
>>> classify_result = ei.classify(model_id="philschmid/tiny-bert-sst2-distilled", sentences=["I love this movie"]) | ||
>>> type(classify_result) | ||
<class 'concurrent.futures._base.Future'> | ||
>>> label_0 = classify_result.result()[0][0][0] | ||
>>> label_0["label"], round(label_0["score"], 4) | ||
('positive', 0.9996) | ||
>>> ei.stop() | ||
""" | ||
return self._engine_array.classify(model=model_id, sentences=sentences) | ||
|
||
def rerank( | ||
self, | ||
*, | ||
query: str, | ||
docs: list[str], | ||
model_id: ModelIndex = 0, | ||
) -> Future[list[str]]: | ||
""" | ||
>>> ei = EasyInference(model_id="mixedbread-ai/mxbai-rerank-xsmall-v1") | ||
>>> docs = ["Paris is nice", "Paris is in France", "In Germany"] | ||
>>> rerank_result = ei.rerank(model_id="mixedbread-ai/mxbai-rerank-xsmall-v1", query="Where is Paris?", docs=docs) | ||
>>> type(rerank_result) | ||
<class 'concurrent.futures._base.Future'> | ||
>>> [round(score, 3) for score in rerank_result.result()[0]] | ||
[0.288, 0.742, 0.022] | ||
>>> ei.stop() | ||
""" | ||
return self._engine_array.rerank(model=model_id, query=query, docs=docs) |
Empty file.
Oops, something went wrong.