Skip to content

Commit

Permalink
Merge pull request #490 from michaelfeil/matryoshka_dim
Browse files Browse the repository at this point in the history
Support for matryoshka embeddings
  • Loading branch information
wirthual authored Dec 10, 2024
2 parents be48378 + 9c811fb commit efe6096
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 25 deletions.
46 changes: 34 additions & 12 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeilfeil
from __future__ import annotations

from asyncio import Semaphore
from typing import Iterable, Iterator, Optional, Union
Expand Down Expand Up @@ -130,11 +131,14 @@ def capabilities(self) -> set[ModelCapabilites]:
def engine_args(self) -> EngineArgs:
return self._engine_args

async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]:
async def embed(
self, sentences: list[str], matryoshka_dim: int | None = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple sentences
Kwargs:
sentences (list[str]): sentences to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -148,7 +152,9 @@ async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"]
"""

self._assert_running()
embeddings, usage = await self._batch_handler.embed(sentences=sentences)
embeddings, usage = await self._batch_handler.embed(
sentences=sentences, matryoshka_dim=matryoshka_dim
)
return embeddings, usage

async def rerank(
Expand Down Expand Up @@ -213,12 +219,16 @@ async def classify(
return scores, usage

async def image_embed(
self, *, images: list[Union[str, "ImageClassType", bytes]]
self,
*,
images: list[Union[str, "ImageClassType", bytes]],
matryoshka_dim: int | None = None,
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple images
Kwargs:
images (list[Union[str, ImageClassType]]): list of image urls or ImageClassType objects, to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -232,16 +242,19 @@ async def image_embed(
"""

self._assert_running()
embeddings, usage = await self._batch_handler.image_embed(images=images)
embeddings, usage = await self._batch_handler.image_embed(
images=images, matryoshka_dim=matryoshka_dim
)
return embeddings, usage

async def audio_embed(
self, *, audios: list[Union[str, bytes]]
self, *, audios: list[Union[str, bytes]], matryoshka_dim: int | None = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple audios
Kwargs:
audios (list[Union[str, Audiobytes]]): list of audio data, to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -255,7 +268,9 @@ async def audio_embed(
"""

self._assert_running()
embeddings, usage = await self._batch_handler.audio_embed(audios=audios)
embeddings, usage = await self._batch_handler.audio_embed(
audios=audios, matryoshka_dim=matryoshka_dim
)
return embeddings, usage

def _assert_running(self):
Expand Down Expand Up @@ -304,13 +319,14 @@ async def astop(self):
await engine.astop()

async def embed(
self, *, model: str, sentences: list[str]
self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple sentences
Kwargs:
model (str): model name to be used
sentences (list[str]): sentences to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -322,7 +338,7 @@ async def embed(
2D list-array of shape( len(sentences),embed_dim )
int: token usage
"""
return await self[model].embed(sentences)
return await self[model].embed(sentences, matryoshka_dim=matryoshka_dim)

def is_running(self) -> bool:
return all(engine.is_running for engine in self.engines_dict.values())
Expand Down Expand Up @@ -378,13 +394,18 @@ async def classify(
return await self[model].classify(sentences=sentences, raw_scores=raw_scores)

async def image_embed(
self, *, model: str, images: list[Union[str, "ImageClassType"]]
self,
*,
model: str,
images: list[Union[str, "ImageClassType"]],
matryoshka_dim: Optional[int] = None,
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple images
Kwargs:
model (str): model name to be used
images (list[Union[str, ImageClassType]]): list of image urls or ImageClassType objects, to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -396,7 +417,7 @@ async def image_embed(
2D list-array of shape( len(sentences),embed_dim )
int: token usage
"""
return await self[model].image_embed(images=images)
return await self[model].image_embed(images=images, matryoshka_dim=matryoshka_dim)

def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine":
"""resolve engine by model name -> Auto resolve if only one engine is present
Expand All @@ -416,13 +437,14 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine":
)

async def audio_embed(
self, *, model: str, audios: list[Union[str, bytes]]
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple audios
Kwargs:
model (str): model name to be used
audios (list[Union[str, bytes]]): list of audio data, to be embedded
matryoshka_dim (int): Length of matryoshka embedding
Raises:
ValueError: raised if engine is not started yet
Expand All @@ -434,4 +456,4 @@ async def audio_embed(
2D list-array of shape( len(sentences),embed_dim )
int: token usage
"""
return await self[model].audio_embed(audios=audios)
return await self[model].audio_embed(audios=audios, matryoshka_dim=matryoshka_dim)
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class _OpenAIEmbeddingInput(BaseModel):
model: str = "default/not-specified"
encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float
user: Optional[str] = None
dimensions: Optional[Annotated[int, Field(strict=True, gt=0, lt=8193)]] = None


class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):
Expand Down
15 changes: 11 additions & 4 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def __init__(
" Consider increasing queue size"
)

async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]:
async def embed(
self, sentences: list[str], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.
Args:
Expand All @@ -157,6 +159,8 @@ async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"]
input_sentences = [EmbeddingSingle(sentence=s) for s in sentences]

embeddings, usage = await self._schedule(input_sentences)
if matryoshka_dim:
embeddings = [embedding[:matryoshka_dim] for embedding in embeddings]
return embeddings, usage

async def rerank(
Expand Down Expand Up @@ -239,6 +243,7 @@ async def image_embed(
self,
*,
images: list[Union[str, "ImageClassType", bytes]],
matryoshka_dim: Optional[int] = None,
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule images and sentences to be embedded. Awaits until embedded.
Expand All @@ -262,12 +267,12 @@ async def image_embed(

items = await resolve_images(images)
embeddings, usage = await self._schedule(items)
if matryoshka_dim:
embeddings = [embedding[:matryoshka_dim] for embedding in embeddings]
return embeddings, usage

async def audio_embed(
self,
*,
audios: list[Union[str, bytes]],
self, *, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule audios and sentences to be embedded. Awaits until embedded.
Expand All @@ -294,6 +299,8 @@ async def audio_embed(
getattr(self.model_worker[0]._model, "sampling_rate", -42),
)
embeddings, usage = await self._schedule(items)
if matryoshka_dim:
embeddings = [embedding[:matryoshka_dim] for embedding in embeddings]
return embeddings, usage

async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[list[Any], int]:
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,21 +354,21 @@ def url_to_base64(url, modality = "image"):
"[📝] Received request with %s input texts ",
len(input_), # type: ignore
)
embedding, usage = await engine.embed(sentences=input_)
embedding, usage = await engine.embed(sentences=input_,matryoshka_dim=data_root.dimensions)
elif modality == Modality.audio:
urls_or_bytes = _resolve_mixed_input(data_root.input) # type: ignore
logger.debug(
"[📝] Received request with %s input audios ",
len(urls_or_bytes), # type: ignore
)
embedding, usage = await engine.audio_embed(audios=urls_or_bytes)
embedding, usage = await engine.audio_embed(audios=urls_or_bytes,matryoshka_dim=data_root.dimensions)
elif modality == Modality.image:
urls_or_bytes = _resolve_mixed_input(data_root.input) # type: ignore
logger.debug(
"[📝] Received request with %s input images ",
len(urls_or_bytes), # type: ignore
)
embedding, usage = await engine.image_embed(images=urls_or_bytes)
embedding, usage = await engine.image_embed(images=urls_or_bytes,matryoshka_dim=data_root.dimensions)

duration = (time.perf_counter() - start) * 1000
logger.debug("[✅] Done in %s ms", duration)
Expand Down
31 changes: 25 additions & 6 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,14 @@ def stop(self):
self.async_run(self.async_engine_array.astop).result()

@add_start_docstrings(AsyncEngineArray.embed.__doc__)
def embed(self, *, model: str, sentences: list[str]):
def embed(self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None):
"""sync interface of AsyncEngineArray"""
return self.async_run(self.async_engine_array.embed, model=model, sentences=sentences)
return self.async_run(
self.async_engine_array.embed,
model=model,
sentences=sentences,
matryoshka_dim=matryoshka_dim,
)

@add_start_docstrings(AsyncEngineArray.rerank.__doc__)
def rerank(
Expand Down Expand Up @@ -206,14 +211,28 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False
)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
def image_embed(self, *, model: str, images: list[Union[str, bytes]]):
def image_embed(
self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
):
"""sync interface of AsyncEngineArray"""
return self.async_run(self.async_engine_array.image_embed, model=model, images=images)
return self.async_run(
self.async_engine_array.image_embed,
model=model,
images=images,
matryoshka_dim=matryoshka_dim,
)

@add_start_docstrings(AsyncEngineArray.audio_embed.__doc__)
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]]):
def audio_embed(
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
):
"""sync interface of AsyncEngineArray"""
return self.async_run(self.async_engine_array.audio_embed, model=model, audios=audios)
return self.async_run(
self.async_engine_array.audio_embed,
model=model,
audios=audios,
matryoshka_dim=matryoshka_dim,
)

def __del__(self):
self.stop()
23 changes: 23 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,26 @@ async def test_openapi_same_as_docs_file(client):
tc.assertDictEqual(openapi_json["info"], openapi_json_expected["info"])
tc.assertDictEqual(openapi_json["paths"], openapi_json_expected["paths"])
# tc.assertDictEqual(openapi_json["components"], openapi_json_expected["components"])


@pytest.mark.anyio
async def test_matryoshka_embedding(client):
    """Check that the `dimensions` request field truncates returned embeddings.

    Posts a one-sentence and a two-sentence batch to the embeddings route
    with ``dimensions=10`` and verifies the response layout, the first
    component of every vector, and the truncated vector length.
    """
    target_dim = 10
    first_sentence = "This is a test sentence."
    batches = (
        [first_sentence],
        [first_sentence, "This is another test sentence."],
    )

    for batch in batches:
        payload = {"input": batch, "model": MODEL_NAME, "dimensions": target_dim}
        response = await client.post(f"{PREFIX}/embeddings", json=payload)
        assert response.status_code == 200, f"{response.status_code}, {response.text}"

        body = response.json()
        assert "data" in body and isinstance(body["data"], list)
        items = body["data"]
        assert all("embedding" in item for item in items)
        assert len(items) == len(batch)

        for item, sentence in zip(items, batch):
            vector = item["embedding"]
            # The first component is asserted to equal the sentence's character count.
            assert len(sentence) == vector[0]
            # Every vector must be cut down to exactly the requested dimension.
            assert len(vector) == target_dim
10 changes: 10 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ async def test_openai(client: AsyncClient):
extra_body={"modality": "text"},
)

# test: text matryoshka
emb_1_text_matryoshka_dim = await client_oai.embeddings.create(
model=pytest.DEFAULT_BERT_MODEL,
input=["a cat", "a cat", "a bird"],
encoding_format="float",
dimensions=64,
extra_body={"modality": "text"},
)
assert len(emb_1_text_matryoshka_dim.data[0].embedding) == 64

# test AUDIO: cosine distance of beep to cat and dog
np.testing.assert_allclose(
emb1_audio.data[0].embedding, emb1_1_audio.data[0].embedding, rtol=1e-5
Expand Down
Loading

0 comments on commit efe6096

Please sign in to comment.