From 98dffcad0680cf19087a7ca94001c19620e5f026 Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Sun, 10 Nov 2024 23:32:21 -0800
Subject: [PATCH] Colbert (#456)

* initial commit

* fmt

* lint

* fix: typos, more docs

* colbert v1

* fix: usage of uvloop
---
 libs/infinity_emb/Makefile                    |  2 +-
 .../infinity_emb/infinity_server.py           | 18 +++--
 .../embedder/sentence_transformer.py          | 42 +++++++++--
 libs/infinity_emb/tests/conftest.py           |  1 +
 .../test_sentence_transformers_colbert.py     | 75 +++++++++++++++++++
 5 files changed, 121 insertions(+), 17 deletions(-)
 create mode 100644 libs/infinity_emb/tests/end_to_end/test_sentence_transformers_colbert.py

diff --git a/libs/infinity_emb/Makefile b/libs/infinity_emb/Makefile
index 06d39848..971afc7b 100644
--- a/libs/infinity_emb/Makefile
+++ b/libs/infinity_emb/Makefile
@@ -63,7 +63,7 @@ benchmark_embed: tests/data/benchmark/benchmark_embed.json
 # sudo apt-get install apache2-utils
 
 benchmark_embed_vision: tests/data/benchmark/benchmark_embed_image.json
-	ab -n 10000 -c 10 -l -s 480 \
+	ab -n 100 -c 50 -l -s 480 \
 	-T 'application/json' \
 	-p $< \
 	http://127.0.0.1:7997/embeddings

diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py
index fb386710..0f701888 100644
--- a/libs/infinity_emb/infinity_emb/infinity_server.py
+++ b/libs/infinity_emb/infinity_emb/infinity_server.py
@@ -661,14 +661,16 @@ def typer_option_resolve(*args):
     import typer
     import uvicorn
 
-    try:
-        import uvloop
-
-        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-        loopname = "uvloop"
-    except ImportError:
-        # Windows does not support uvloop
-        loopname = "auto"
+    loopname = "auto"
+    if sys.version_info < (3, 12):
+        try:
+            import uvloop
+
+            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+            loopname = "uvloop"
+        except ImportError:
+            # Windows does not support uvloop
+            pass
 
     tp = typer.Typer()

diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
index b1e304d9..0439cf5e 100644
--- a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
+++ b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
@@ -70,6 +70,14 @@ def __init__(self, *, engine_args=EngineArgs):
         # to be able to count the tokens in another thread
         # without corrupting the original.
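+        # note: ColBERT-style models emit one embedding per token rather than a
+        # single pooled sentence vector; `mode_colbert` below switches pooling-time
+        # L2 normalization off for them (see encode_core/encode_post further down)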
         fm = self._first_module()
+
+        self.normalize_embeddings = True
+
+        self.mode_colbert = False
+        if "colbert" in fm.auto_model.config.architectures[0].lower():
+            self.mode_colbert = True
+            self.normalize_embeddings = False
+
         self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
         self.eval()
         self.engine_args = engine_args
@@ -102,20 +110,38 @@ def encode_core(self, features: dict[str, "Tensor"]) -> "Tensor":
         with torch.no_grad():
             features = util.batch_to_device(features, self.device)  # type: ignore
-            out_features: "Tensor" = self.forward(features)["sentence_embedding"]
+            out: dict[str, "Tensor"] = self.forward(features)
+            if not self.mode_colbert:
+                out_features = out["sentence_embedding"].detach().cpu()
+            else:
+                out_features = {  # type: ignore # noqa
+                    "token_embeddings": out["token_embeddings"].detach().cpu(),
+                    "attention_mask": out["attention_mask"].detach().cpu(),
+                }
 
-        return out_features.detach().cpu()
+        return out_features
 
     @quant_embedding_decorator()
     def encode_post(
-        self, out_features: "Tensor", normalize_embeddings: bool = True
+        self,
+        out_features: "Tensor",
     ) -> "EmbeddingReturnType":
         with torch.inference_mode():
-            embeddings: "Tensor" = out_features.to(torch.float32)
-            if normalize_embeddings:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-            embeddings_np: np.ndarray = embeddings.numpy()
+            if not self.mode_colbert:
+                embeddings: "Tensor" = out_features.to(torch.float32)
+                if self.normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                embeddings_np: np.ndarray = embeddings.numpy()
+            else:
+                # drop padded positions via the attention mask; e.g. for two inputs
+                # with 5 and 3 tokens the mask is [[1,1,1,1,1],[1,1,1,0,0]],
+                # then convert each input's token embeddings to its own numpy array
+                embeddings_np = [  # type: ignore # noqa
+                    z[m].numpy()
+                    for z, m in zip(
+                        out_features["token_embeddings"].to(torch.float32),  # type: ignore
+                        out_features["attention_mask"].bool(),  # type: ignore
+                    )
+                ]
 
         return embeddings_np

diff --git a/libs/infinity_emb/tests/conftest.py b/libs/infinity_emb/tests/conftest.py
index b030d2c0..6bd03f35 100644
--- a/libs/infinity_emb/tests/conftest.py
+++ b/libs/infinity_emb/tests/conftest.py
@@ -13,6 +13,7 @@
 pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused"
 pytest.DEFAULT_IMAGE_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
 pytest.DEFAULT_IMAGE_COLPALI_MODEL = "michaelfeil/colpali-v12-random-testing"
+pytest.DEFAULT_COLBERT_MODEL = "michaelfeil/colbert-tiny-random"
 pytest.IMAGE_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/06fd1f4d8f0a869f4482fc1c78b62a75ccbb66a1/docs/assets/cats_coco_sample.jpg"
 pytest.AUDIO_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

diff --git a/libs/infinity_emb/tests/end_to_end/test_sentence_transformers_colbert.py b/libs/infinity_emb/tests/end_to_end/test_sentence_transformers_colbert.py
new file mode 100644
index 00000000..503db7a0
--- /dev/null
+++ b/libs/infinity_emb/tests/end_to_end/test_sentence_transformers_colbert.py
@@ -0,0 +1,75 @@
+import pytest
+import torch
+from asgi_lifespan import LifespanManager
+from httpx import AsyncClient
+import numpy as np
+
+from infinity_emb import create_server
+from infinity_emb.args import EngineArgs
+from infinity_emb.primitives import Device, InferenceEngine
+
+PREFIX = "/v1_sentence_transformers_colbert"
+MODEL: str = pytest.DEFAULT_COLBERT_MODEL  # type: ignore[assignment]
+batch_size = 64 if torch.cuda.is_available() else 8
+
+app = create_server(
+    url_prefix=PREFIX,
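+    # the app is created once at import time; engine startup and shutdown are
+    # driven by the LifespanManager in the `client` fixture below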
+    engine_args_list=[
+        EngineArgs(
+            model_name_or_path=MODEL,
+            batch_size=batch_size,
+            engine=InferenceEngine.torch,
+            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
+        )
+    ],
+)
+
+
+# @pytest.fixture
+# def model_base() -> SentenceTransformer:
+#     model = SentenceTransformer(MODEL)
+#     if model.device == "cuda":
+#         model = model.to(torch.float16)
+#     return model
+
+
+@pytest.fixture()
+async def client():
+    async with AsyncClient(app=app, base_url="http://test", timeout=20) as client, LifespanManager(
+        app
+    ):
+        yield client
+
+
+# def test_load_model(model_base):
+#     # this makes sure that the error below is not caused by a slow download
+#     # or internal pytorch errors
+#     model_base.encode(["This is a test sentence."])
+
+
+@pytest.mark.anyio
+async def test_model_route(client):
+    response = await client.get(f"{PREFIX}/models")
+    assert response.status_code == 200
+    rdata = response.json()
+    assert "data" in rdata
+    assert rdata["data"][0].get("id", "") == MODEL
+    assert isinstance(rdata["data"][0].get("stats"), dict)
+
+
+@pytest.mark.anyio
+async def test_embedding(client):
+    response = await client.post(
+        f"{PREFIX}/embeddings", json=dict(input=["This is a test", "hi", "hi"], model=MODEL)
+    )
+    assert response.status_code == 200
+    rdata = response.json()
+    assert "data" in rdata
+    assert len(rdata["data"]) == 3
+    # TODO: check whether the start/end special tokens should be embedded
+    # TODO: check whether normalization is applied, and whether it should be
+    assert len(rdata["data"][0]["embedding"]) == 6  # "This is a test" -> 6 tokens
+    assert len(rdata["data"][1]["embedding"]) == 3  # "hi" -> 3 tokens
+    # identical inputs must produce (near-)identical token embeddings
+    np.testing.assert_allclose(
+        rdata["data"][1]["embedding"], rdata["data"][2]["embedding"], atol=5e-3
+    )
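
The attention-mask filtering in encode_post above is the core of the new ColBERT
output path. Below is a minimal standalone sketch of the same list comprehension,
using the two-input example from the code comment (5 and 3 real tokens, padded to
length 5); the tensor shapes and random values are illustrative, not taken from
the patch:

    import torch

    # a batch of 2 inputs padded to seq_len 5, with hidden dim 4 (illustrative)
    token_embeddings = torch.randn(2, 5, 4)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                                   [1, 1, 1, 0, 0]])

    # keep only the real (non-padding) token vectors of each input,
    # as the else-branch of encode_post does
    per_input = [
        z[m].numpy()
        for z, m in zip(token_embeddings.to(torch.float32), attention_mask.bool())
    ]
    assert per_input[0].shape == (5, 4)
    assert per_input[1].shape == (3, 4)  # the two padding rows are dropped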
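
With this change, data[i]["embedding"] in the /embeddings response is a list of
per-token vectors whose length varies with the input (6 and 3 in the test above).
A client can turn those into relevance scores with ColBERT-style late interaction
(MaxSim). A hedged sketch follows: maxsim is an illustrative helper, not part of
the infinity_emb API, and it L2-normalizes locally because, per the TODO above,
server-side normalization is not guaranteed; the rdata_query/rdata_doc variables
in the usage comment are hypothetical /embeddings responses:

    import numpy as np

    def maxsim(query_tokens: np.ndarray, doc_tokens: np.ndarray) -> float:
        # query_tokens: (q, dim) and doc_tokens: (t, dim) per-token embeddings
        q = query_tokens / np.linalg.norm(query_tokens, axis=-1, keepdims=True)
        d = doc_tokens / np.linalg.norm(doc_tokens, axis=-1, keepdims=True)
        sim = q @ d.T  # (q, t) cosine similarity of every query/doc token pair
        # best-matching document token per query token, summed over the query
        return float(sim.max(axis=1).sum())

    # usage (hypothetical variables):
    # query_emb = np.asarray(rdata_query["data"][0]["embedding"])
    # doc_emb = np.asarray(rdata_doc["data"][0]["embedding"])
    # score = maxsim(query_emb, doc_emb)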