From a0b5cc41135c1ee5775a2114d5b7b0767df11108 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 6 Dec 2024 05:08:13 +0100 Subject: [PATCH 1/5] initial commits for matryoshka_dim --- libs/infinity_emb/infinity_emb/engine.py | 41 +++++--- .../infinity_emb/fastapi_schemas/pymodels.py | 1 + .../infinity_emb/inference/batch_handler.py | 18 ++-- libs/infinity_emb/infinity_emb/sync_engine.py | 27 ++++-- .../end_to_end/test_api_with_dummymodel.py | 23 +++++ .../end_to_end/test_openapi_client_compat.py | 10 ++ .../tests/unit_test/test_engine.py | 95 +++++++++++++++++++ 7 files changed, 190 insertions(+), 25 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index 73867a7d..68b2ca8f 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -130,11 +130,14 @@ def capabilities(self) -> set[ModelCapabilites]: def engine_args(self) -> EngineArgs: return self._engine_args - async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]: + async def embed( + self, sentences: list[str], matryoshka_dim: int | None = None + ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple sentences Kwargs: sentences (list[str]): sentences to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -148,7 +151,9 @@ async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"] """ self._assert_running() - embeddings, usage = await self._batch_handler.embed(sentences=sentences) + embeddings, usage = await self._batch_handler.embed( + sentences=sentences, matryoshka_dim=matryoshka_dim + ) return embeddings, usage async def rerank( @@ -213,12 +218,16 @@ async def classify( return scores, usage async def image_embed( - self, *, images: list[Union[str, "ImageClassType", bytes]] + self, + *, + images: list[Union[str, "ImageClassType", bytes]], + matryoshka_dim: int | None = None, ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple images Kwargs: images (list[Union[str, ImageClassType]]): list of image urls or ImageClassType objects, to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -232,16 +241,19 @@ async def image_embed( """ self._assert_running() - embeddings, usage = await self._batch_handler.image_embed(images=images) + embeddings, usage = await self._batch_handler.image_embed( + images=images, matryoshka_dim=matryoshka_dim + ) return embeddings, usage async def audio_embed( - self, *, audios: list[Union[str, bytes]] + self, *, audios: list[Union[str, bytes]], matryoshka_dim: int | None = None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple audios Kwargs: audios (list[Union[str, Audiobytes]]): list of audio data, to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -255,7 +267,9 @@ async def audio_embed( """ self._assert_running() - embeddings, usage = await self._batch_handler.audio_embed(audios=audios) + embeddings, usage = await self._batch_handler.audio_embed( + audios=audios, matryoshka_dim=matryoshka_dim + ) return embeddings, usage def _assert_running(self): @@ -304,13 +318,14 @@ async def astop(self): await engine.astop() async def embed( - self, *, model: str, sentences: list[str] + self, *, model: str, sentences: list[str], matryoshka_dim=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple sentences Kwargs: model (str): model name to be used sentences (list[str]): sentences to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -322,7 +337,7 @@ async def embed( 2D list-array of shape( len(sentences),embed_dim ) int: token usage """ - return await self[model].embed(sentences) + return await self[model].embed(sentences, matryoshka_dim=matryoshka_dim) def is_running(self) -> bool: return all(engine.is_running for engine in self.engines_dict.values()) @@ -378,13 +393,14 @@ async def classify( return await self[model].classify(sentences=sentences, raw_scores=raw_scores) async def image_embed( - self, *, model: str, images: list[Union[str, "ImageClassType"]] + self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple images Kwargs: model (str): model name to be used images (list[Union[str, ImageClassType]]): list of image urls or ImageClassType objects, to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -396,7 +412,7 @@ async def image_embed( 2D list-array of shape( len(sentences),embed_dim ) int: token usage """ - return await self[model].image_embed(images=images) + return await self[model].image_embed(images=images, matryoshka_dim=matryoshka_dim) def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine": """resolve engine by model name -> Auto resolve if only one engine is present @@ -416,13 +432,14 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine": ) async def audio_embed( - self, *, model: str, audios: list[Union[str, bytes]] + self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple audios Kwargs: model (str): model name to be used audios (list[Union[str, bytes]]): list of audio data, to be embedded + matryoshka_dim (int): Length of matryoshka embedding Raises: ValueError: raised if engine is not started yet @@ -434,4 +451,4 @@ async def audio_embed( 2D list-array of shape( len(sentences),embed_dim ) int: token usage """ - return await self[model].audio_embed(audios=audios) + return await self[model].audio_embed(audios=audios, matryoshka_dim=matryoshka_dim) diff --git a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py index 3a623471..a706f090 100644 --- a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py +++ b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py @@ -54,6 +54,7 @@ class _OpenAIEmbeddingInput(BaseModel): model: str = "default/not-specified" encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float user: Optional[str] = None + dimensions: Optional[int] = None class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput): diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 7ed4b83d..49696599 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -136,7 +136,9 @@ def __init__( " Consider increasing queue size" ) - async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]: + async def embed( + self, sentences: list[str], matryoshka_dim=None + ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a sentence to be embedded. Awaits until embedded. Args: @@ -157,6 +159,8 @@ async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"] input_sentences = [EmbeddingSingle(sentence=s) for s in sentences] embeddings, usage = await self._schedule(input_sentences) + if matryoshka_dim: + embeddings = [embedding[:matryoshka_dim] for embedding in embeddings] return embeddings, usage async def rerank( @@ -236,9 +240,7 @@ async def classify( return classifications, usage async def image_embed( - self, - *, - images: list[Union[str, "ImageClassType", bytes]], + self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim=None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a images and sentences to be embedded. Awaits until embedded. @@ -262,12 +264,12 @@ async def image_embed( items = await resolve_images(images) embeddings, usage = await self._schedule(items) + if matryoshka_dim: + embeddings = [embedding[:matryoshka_dim] for embedding in embeddings] return embeddings, usage async def audio_embed( - self, - *, - audios: list[Union[str, bytes]], + self, *, audios: list[Union[str, bytes]], matryoshka_dim=None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule audios and sentences to be embedded. Awaits until embedded. @@ -294,6 +296,8 @@ async def audio_embed( getattr(self.model_worker[0]._model, "sampling_rate", -42), ) embeddings, usage = await self._schedule(items) + if matryoshka_dim: + embeddings = [embedding[:matryoshka_dim] for embedding in embeddings] return embeddings, usage async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[list[Any], int]: diff --git a/libs/infinity_emb/infinity_emb/sync_engine.py b/libs/infinity_emb/infinity_emb/sync_engine.py index c6a907d8..a398488a 100644 --- a/libs/infinity_emb/infinity_emb/sync_engine.py +++ b/libs/infinity_emb/infinity_emb/sync_engine.py @@ -171,9 +171,14 @@ def stop(self): self.async_run(self.async_engine_array.astop).result() @add_start_docstrings(AsyncEngineArray.embed.__doc__) - def embed(self, *, model: str, sentences: list[str]): + def embed(self, *, model: str, sentences: list[str], matryoshka_dim=None): """sync interface of AsyncEngineArray""" - return self.async_run(self.async_engine_array.embed, model=model, sentences=sentences) + return self.async_run( + self.async_engine_array.embed, + model=model, + sentences=sentences, + matryoshka_dim=matryoshka_dim, + ) @add_start_docstrings(AsyncEngineArray.rerank.__doc__) def rerank( @@ -206,14 +211,24 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False ) @add_start_docstrings(AsyncEngineArray.image_embed.__doc__) - def image_embed(self, *, model: str, images: list[Union[str, bytes]]): + def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim=None): """sync interface of AsyncEngineArray""" - return self.async_run(self.async_engine_array.image_embed, model=model, images=images) + return self.async_run( + self.async_engine_array.image_embed, + model=model, + images=images, + matryoshka_dim=matryoshka_dim, + ) @add_start_docstrings(AsyncEngineArray.audio_embed.__doc__) - def audio_embed(self, *, model: str, audios: list[Union[str, bytes]]): + def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None): """sync interface of AsyncEngineArray""" - return self.async_run(self.async_engine_array.audio_embed, model=model, audios=audios) + return self.async_run( + self.async_engine_array.audio_embed, + model=model, + audios=audios, + matryoshka_dim=matryoshka_dim, + ) def __del__(self): self.stop() diff --git a/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py b/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py index 73e68e2f..1e8d1aa4 100644 --- a/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py +++ b/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py @@ -170,3 +170,26 @@ async def test_openapi_same_as_docs_file(client): tc.assertDictEqual(openapi_json["info"], openapi_json_expected["info"]) tc.assertDictEqual(openapi_json["paths"], openapi_json_expected["paths"]) # tc.assertDictEqual(openapi_json["components"], openapi_json_expected["components"]) + + +@pytest.mark.anyio +async def test_matryoshka_embedding(client): + matryoshka_dim = 10 + + possible_inputs = [ + ["This is a test sentence."], + ["This is a test sentence.", "This is another test sentence."], + ] + for inp in possible_inputs: + response = await client.post( + f"{PREFIX}/embeddings", + json=dict(input=inp, model=MODEL_NAME, dimensions=matryoshka_dim), + ) + assert response.status_code == 200, f"{response.status_code}, {response.text}" + rdata = response.json() + assert "data" in rdata and isinstance(rdata["data"], list) + assert all("embedding" in d for d in rdata["data"]) + assert len(rdata["data"]) == len(inp) + for embedding, sentence in zip(rdata["data"], inp): + assert len(sentence) == embedding["embedding"][0] + assert len(embedding["embedding"]) == matryoshka_dim diff --git a/libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py b/libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py index 79b5ebab..19f5a385 100644 --- a/libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py +++ b/libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py @@ -120,6 +120,16 @@ async def test_openai(client: AsyncClient): extra_body={"modality": "text"}, ) + # test: text matryoshka + emb_1_text_matryoshka_dim = await client_oai.embeddings.create( + model=pytest.DEFAULT_BERT_MODEL, + input=["a cat", "a cat", "a bird"], + encoding_format="float", + dimensions=64, + extra_body={"modality": "text"}, + ) + assert len(emb_1_text_matryoshka_dim.data[0].embedding) == 64 + # test AUDIO: cosine distance of beep to cat and dog np.testing.assert_allclose( emb1_audio.data[0].embedding, emb1_1_audio.data[0].embedding, rtol=1e-5 diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index c3bf054b..99cecbbb 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -380,3 +380,98 @@ def test_args_between_array_and_engine_same(method_name: str): assert sorted(array_method.args + array_method.kwonlyargs) == sorted( engine_method.args + engine_method.kwonlyargs + ["model"] ) + + +@pytest.mark.anyio +async def test_async_api_torch_matryoshka(): + matryoshka_dim = 64 + + sentences = ["Hi", "how"] + engine = AsyncEmbeddingEngine.from_args( + EngineArgs( + model_name_or_path="nomic-ai/nomic-embed-text-v1.5", + engine=InferenceEngine.torch, + revision="main", + device="cpu", + ) + ) + assert engine.capabilities == {"embed"} + async with engine: + embeddings, usage = await engine.embed(sentences=sentences, matryoshka_dim=matryoshka_dim) + assert isinstance(embeddings, list) + assert isinstance(embeddings[0], np.ndarray) + embeddings = np.array(embeddings) + assert usage == sum([len(s) for s in sentences]) + assert embeddings.shape[0] == len(sentences) + assert embeddings.shape[1] >= 10 + + assert len(embeddings[0]) == 64 + + # test if model denies classification and reranking + with pytest.raises(ModelNotDeployedError): + await engine.classify(sentences=sentences) + with pytest.raises(ModelNotDeployedError): + await engine.rerank(query="dummy", docs=sentences) + + +@pytest.mark.anyio +async def test_torch_clip_embed_matryoshka(): + matryoshka_dim = 128 + + image_urls = ["http://images.cocodataset.org/val2017/000000039769.jpg"] # a photo of two cats + sentences = [ + "a photo of two cats", + "a photo of a cat", + "a photo of a dog", + "a photo of a car", + ] + engine = AsyncEmbeddingEngine.from_args( + EngineArgs( + model_name_or_path="jinaai/jina-clip-v2", + engine=InferenceEngine.torch, + model_warmup=True, + ) + ) + async with engine: + t1, t2 = ( + asyncio.create_task(engine.embed(sentences=sentences, matryoshka_dim=matryoshka_dim)), + asyncio.create_task( + engine.image_embed(images=image_urls, matryoshka_dim=matryoshka_dim) + ), + ) + emb_text, usage_text = await t1 + emb_image, usage_image = await t2 + emb_text_np = np.array(emb_text) # type: ignore + emb_image_np = np.array(emb_image) # type: ignore + + assert len(emb_text_np[0]) == matryoshka_dim + assert len(emb_image_np[0]) == matryoshka_dim + + # check if cat image and two cats are most similar + for i in range(1, len(sentences)): + assert np.dot(emb_text_np[0], emb_image_np[0]) > np.dot(emb_text_np[i], emb_image_np[0]) + + +@pytest.mark.anyio +async def test_clap_like_model_matryoshka(audio_sample): + matryoshka_dim = 64 + + model_name = pytest.DEFAULT_AUDIO_MODEL + engine = AsyncEmbeddingEngine.from_args( + EngineArgs(model_name_or_path=model_name, dtype="float32") + ) + url = audio_sample[1] + bytes_url = audio_sample[0].content + + inputs = ["a sound of a cat", "a sound of a cat"] + audios = [url, bytes_url] + async with engine: + embeddings_text, usage_1 = await engine.embed( + sentences=inputs, matryoshka_dim=matryoshka_dim + ) + embeddings_audio, usage_2 = await engine.audio_embed( + audios=audios, matryoshka_dim=matryoshka_dim + ) + + assert len(embeddings_text[0]) == matryoshka_dim + assert len(embeddings_audio[0]) == matryoshka_dim From 7917c56cc450de032e4322006411d0b1f8fe55e8 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 6 Dec 2024 05:20:58 +0100 Subject: [PATCH 2/5] add missing type hints --- libs/infinity_emb/infinity_emb/engine.py | 6 +++--- libs/infinity_emb/infinity_emb/inference/batch_handler.py | 6 +++--- libs/infinity_emb/infinity_emb/sync_engine.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index 68b2ca8f..d98a1877 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -318,7 +318,7 @@ async def astop(self): await engine.astop() async def embed( - self, *, model: str, sentences: list[str], matryoshka_dim=None + self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple sentences @@ -393,7 +393,7 @@ async def classify( return await self[model].classify(sentences=sentences, raw_scores=raw_scores) async def image_embed( - self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim=None + self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim:Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple images @@ -432,7 +432,7 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine": ) async def audio_embed( - self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None + self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple audios diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 49696599..94d2ecaa 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -137,7 +137,7 @@ def __init__( ) async def embed( - self, sentences: list[str], matryoshka_dim=None + self, sentences: list[str], matryoshka_dim:Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a sentence to be embedded. Awaits until embedded. @@ -240,7 +240,7 @@ async def classify( return classifications, usage async def image_embed( - self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim=None + self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim:Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a images and sentences to be embedded. Awaits until embedded. @@ -269,7 +269,7 @@ async def image_embed( return embeddings, usage async def audio_embed( - self, *, audios: list[Union[str, bytes]], matryoshka_dim=None + self, *, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule audios and sentences to be embedded. Awaits until embedded. diff --git a/libs/infinity_emb/infinity_emb/sync_engine.py b/libs/infinity_emb/infinity_emb/sync_engine.py index a398488a..2475269d 100644 --- a/libs/infinity_emb/infinity_emb/sync_engine.py +++ b/libs/infinity_emb/infinity_emb/sync_engine.py @@ -171,7 +171,7 @@ def stop(self): self.async_run(self.async_engine_array.astop).result() @add_start_docstrings(AsyncEngineArray.embed.__doc__) - def embed(self, *, model: str, sentences: list[str], matryoshka_dim=None): + def embed(self, *, model: str, sentences: list[str], matryoshka_dim:Optional[int]=None): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.embed, @@ -211,7 +211,7 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False ) @add_start_docstrings(AsyncEngineArray.image_embed.__doc__) - def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim=None): + def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.image_embed, @@ -221,7 +221,7 @@ def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka ) @add_start_docstrings(AsyncEngineArray.audio_embed.__doc__) - def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None): + def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.audio_embed, From d8ad01009937b0fcbbc012f24fcf9e2c295c8616 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 6 Dec 2024 05:27:25 +0100 Subject: [PATCH 3/5] format. Use future annotations --- libs/infinity_emb/infinity_emb/engine.py | 11 ++++++++--- .../infinity_emb/inference/batch_handler.py | 9 ++++++--- libs/infinity_emb/infinity_emb/sync_engine.py | 10 +++++++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index d98a1877..153e15ba 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2023-now michaelfeilfeil +from __future__ import annotations from asyncio import Semaphore from typing import Iterable, Iterator, Optional, Union @@ -318,7 +319,7 @@ async def astop(self): await engine.astop() async def embed( - self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int]=None + self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple sentences @@ -393,7 +394,11 @@ async def classify( return await self[model].classify(sentences=sentences, raw_scores=raw_scores) async def image_embed( - self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim:Optional[int]=None + self, + *, + model: str, + images: list[Union[str, "ImageClassType"]], + matryoshka_dim: Optional[int] = None, ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple images @@ -432,7 +437,7 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine": ) async def audio_embed( - self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None + self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None ) -> tuple[list["EmbeddingReturnType"], int]: """embed multiple audios diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 94d2ecaa..1edda315 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -137,7 +137,7 @@ def __init__( ) async def embed( - self, sentences: list[str], matryoshka_dim:Optional[int]=None + self, sentences: list[str], matryoshka_dim: Optional[int] = None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a sentence to be embedded. Awaits until embedded. @@ -240,7 +240,10 @@ async def classify( return classifications, usage async def image_embed( - self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim:Optional[int]=None + self, + *, + images: list[Union[str, "ImageClassType", bytes]], + matryoshka_dim: Optional[int] = None, ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a images and sentences to be embedded. Awaits until embedded. @@ -269,7 +272,7 @@ async def image_embed( return embeddings, usage async def audio_embed( - self, *, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None + self, *, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None ) -> tuple[list["EmbeddingReturnType"], int]: """Schedule audios and sentences to be embedded. Awaits until embedded. diff --git a/libs/infinity_emb/infinity_emb/sync_engine.py b/libs/infinity_emb/infinity_emb/sync_engine.py index 2475269d..deda66d9 100644 --- a/libs/infinity_emb/infinity_emb/sync_engine.py +++ b/libs/infinity_emb/infinity_emb/sync_engine.py @@ -171,7 +171,7 @@ def stop(self): self.async_run(self.async_engine_array.astop).result() @add_start_docstrings(AsyncEngineArray.embed.__doc__) - def embed(self, *, model: str, sentences: list[str], matryoshka_dim:Optional[int]=None): + def embed(self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.embed, @@ -211,7 +211,9 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False ) @add_start_docstrings(AsyncEngineArray.image_embed.__doc__) - def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None): + def image_embed( + self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None + ): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.image_embed, @@ -221,7 +223,9 @@ def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka ) @add_start_docstrings(AsyncEngineArray.audio_embed.__doc__) - def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None): + def audio_embed( + self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None + ): """sync interface of AsyncEngineArray""" return self.async_run( self.async_engine_array.audio_embed, From 1de6c5278fa94b1c7c3cbc03027fdd1e9b1b2bb1 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 6 Dec 2024 17:46:22 +0100 Subject: [PATCH 4/5] add dims to server --- libs/infinity_emb/infinity_emb/infinity_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 0ac1c835..582d7fa3 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -354,21 +354,21 @@ def url_to_base64(url, modality = "image"): "[📝] Received request with %s input texts ", len(input_), # type: ignore ) - embedding, usage = await engine.embed(sentences=input_) + embedding, usage = await engine.embed(sentences=input_,matryoshka_dim=data_root.dimensions) elif modality == Modality.audio: urls_or_bytes = _resolve_mixed_input(data_root.input) # type: ignore logger.debug( "[📝] Received request with %s input audios ", len(urls_or_bytes), # type: ignore ) - embedding, usage = await engine.audio_embed(audios=urls_or_bytes) + embedding, usage = await engine.audio_embed(audios=urls_or_bytes,matryoshka_dim=data_root.dimensions) elif modality == Modality.image: urls_or_bytes = _resolve_mixed_input(data_root.input) # type: ignore logger.debug( "[📝] Received request with %s input images ", len(urls_or_bytes), # type: ignore ) - embedding, usage = await engine.image_embed(images=urls_or_bytes) + embedding, usage = await engine.image_embed(images=urls_or_bytes,matryoshka_dim=data_root.dimensions) duration = (time.perf_counter() - start) * 1000 logger.debug("[✅] Done in %s ms", duration) From 9c811fbc7bebb4c09f26fe3e7ca19a27ced52c1a Mon Sep 17 00:00:00 2001 From: wirthual Date: Mon, 9 Dec 2024 17:42:28 +0100 Subject: [PATCH 5/5] add constraints for dimensions --- libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py index a706f090..d7c392e9 100644 --- a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py +++ b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py @@ -54,7 +54,7 @@ class _OpenAIEmbeddingInput(BaseModel): model: str = "default/not-specified" encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float user: Optional[str] = None - dimensions: Optional[int] = None + dimensions: Optional[Annotated[int, Field(strict=True, gt=0, lt=8193)]] = None class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):