Colbert (#456)
* initial commit

* fmt

* lint

* fix: typos, more docs

* colbert v1

* fix: usage of uvloop
michaelfeil authored Nov 11, 2024
1 parent 317e809 commit 98dffca
Showing 5 changed files with 121 additions and 17 deletions.
libs/infinity_emb/Makefile (2 changes: 1 addition & 1 deletion)
@@ -63,7 +63,7 @@ benchmark_embed: tests/data/benchmark/benchmark_embed.json
 # sudo apt-get apache2-utils

 benchmark_embed_vision: tests/data/benchmark/benchmark_embed_image.json
-	ab -n 10000 -c 10 -l -s 480 \
+	ab -n 100 -c 50 -l -s 480 \
 	-T 'application/json' \
 	-p $< \
 	http://127.0.0.1:7997/embeddings
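For reference on the changed target: `ab -n` is the total number of requests, `-c` the concurrency level, `-l` accepts variable-length responses, `-s` the socket timeout in seconds, `-T` the Content-Type header, and `-p` the file whose contents are POSTed. The change trades a long, lightly concurrent run (10000 requests at concurrency 10) for a shorter, more contended one (100 requests at concurrency 50).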
libs/infinity_emb/infinity_emb/infinity_server.py (18 changes: 10 additions & 8 deletions)
@@ -661,14 +661,16 @@ def typer_option_resolve(*args):
     import typer
     import uvicorn

-    try:
-        import uvloop
-
-        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-        loopname = "uvloop"
-    except ImportError:
-        # Windows does not support uvloop
-        loopname = "auto"
+    loopname = "auto"
+    if sys.version_info < (3, 12):
+        try:
+            import uvloop
+
+            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+            loopname = "uvloop"
+        except ImportError:
+            # Windows does not support uvloop
+            pass

     tp = typer.Typer()
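The guard above stops the server from forcing uvloop on Python 3.12 and newer, where it now falls back to "auto" and lets uvicorn pick the loop implementation itself. A minimal sketch of how such a resolved loop name is typically forwarded, assuming it is handed to uvicorn.run (the import path is hypothetical and the port is taken from the Makefile target above):

import uvicorn

# uvicorn's `loop` option accepts "none", "auto", "asyncio", or "uvloop";
# with "auto", uvicorn probes for uvloop itself and falls back to asyncio.
uvicorn.run(
    "infinity_emb.infinity_server:app",  # illustrative import path, not taken from this diff
    host="127.0.0.1",
    port=7997,
    loop="auto",
)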
@@ -70,6 +70,14 @@ def __init__(self, *, engine_args=EngineArgs):
         # to be able to count the tokens in another thread
         # without corrupting the original.
         fm = self._first_module()
+
+        self.normalize_embeddings = True
+
+        self.mode_colbert = False
+        if "colbert" in fm.auto_model.config.architectures[0].lower():
+            self.mode_colbert = True
+            self.normalize_embeddings = False
+
         self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
         self.eval()
         self.engine_args = engine_args
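The new flags key off the Hugging Face config's `architectures` entry: any checkpoint whose recorded class name contains "colbert" switches the embedder into per-token mode and disables L2 normalization. A minimal standalone sketch of the same check, using the tiny test model registered in conftest.py below (assuming its config.json records a ColBERT class name):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("michaelfeil/colbert-tiny-random")
architecture = config.architectures[0]  # model class name recorded in config.json
mode_colbert = "colbert" in architecture.lower()
normalize_embeddings = not mode_colbert
print(architecture, mode_colbert, normalize_embeddings)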
@@ -102,20 +110,38 @@ def encode_core(self, features: dict[str, "Tensor"]) -> "Tensor":

         with torch.no_grad():
             features = util.batch_to_device(features, self.device)  # type: ignore
-            out_features: "Tensor" = self.forward(features)["sentence_embedding"]
+            out: dict[str, "Tensor"] = self.forward(features)
+            if not self.mode_colbert:
+                out_features = out["sentence_embedding"].detach().cpu()
+            else:
+                out_features = {  # type: ignore # noqa
+                    "token_embeddings": out["token_embeddings"].detach().cpu(),
+                    "attention_mask": out["attention_mask"].detach().cpu(),
+                }

-        return out_features.detach().cpu()
+        return out_features

     @quant_embedding_decorator()
     def encode_post(
-        self, out_features: "Tensor", normalize_embeddings: bool = True
+        self,
+        out_features: "Tensor",
     ) -> "EmbeddingReturnType":
         with torch.inference_mode():
-            embeddings: "Tensor" = out_features.to(torch.float32)
-            if normalize_embeddings:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-            embeddings_np: np.ndarray = embeddings.numpy()
+            if not self.mode_colbert:
+                embeddings: "Tensor" = out_features.to(torch.float32)
+                if self.normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                embeddings_np: np.ndarray = embeddings.numpy()
+            else:
+                # drop padded positions via the attention mask (for inputs of 5 and 3
+                # tokens the mask is [[1,1,1,1,1],[1,1,1,0,0]]) and convert each
+                # sequence to its own numpy array of per-token embeddings
+                embeddings_np = [  # type: ignore # noqa
+                    z[m].numpy()
+                    for z, m in zip(
+                        out_features["token_embeddings"].to(torch.float32),  # type: ignore
+                        out_features["attention_mask"].bool(),  # type: ignore
+                    )
+                ]

         return embeddings_np
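Why return per-token embeddings at all: ColBERT ranks with late interaction, where every query token vector is matched against its most similar document token vector and those maxima are summed (MaxSim). This commit only produces the token matrices; a sketch of the downstream scoring they enable (not part of this diff) could look like:

import numpy as np

def maxsim(query_toks: np.ndarray, doc_toks: np.ndarray) -> float:
    """Late-interaction score for a (q, dim) query and a (k, dim) document matrix."""
    # ColBERT conventionally L2-normalizes token vectors, making dot products cosines
    q = query_toks / np.linalg.norm(query_toks, axis=1, keepdims=True)
    d = doc_toks / np.linalg.norm(doc_toks, axis=1, keepdims=True)
    sim = q @ d.T  # (q, k) pairwise token similarities
    return float(sim.max(axis=1).sum())  # best document token per query token, summed

# toy shapes mirroring the test below: 6 tokens for "This is a test", 3 for "hi"
rng = np.random.default_rng(0)
print(maxsim(rng.normal(size=(6, 128)), rng.normal(size=(3, 128))))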
libs/infinity_emb/tests/conftest.py (1 change: 1 addition & 0 deletions)
@@ -13,6 +13,7 @@
 pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused"
 pytest.DEFAULT_IMAGE_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
 pytest.DEFAULT_IMAGE_COLPALI_MODEL = "michaelfeil/colpali-v12-random-testing"
+pytest.DEFAULT_COLBERT_MODEL = "michaelfeil/colbert-tiny-random"

 pytest.IMAGE_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/06fd1f4d8f0a869f4482fc1c78b62a75ccbb66a1/docs/assets/cats_coco_sample.jpg"
 pytest.AUDIO_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
@@ -0,0 +1,75 @@
import pytest
import torch
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
import numpy as np
from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

PREFIX = "/v1_sentence_transformers_colbert"
MODEL: str = pytest.DEFAULT_COLBERT_MODEL # type: ignore[assignment]
batch_size = 64 if torch.cuda.is_available() else 8

app = create_server(
url_prefix=PREFIX,
engine_args_list=[
EngineArgs(
model_name_or_path=MODEL,
batch_size=batch_size,
engine=InferenceEngine.torch,
device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
)
],
)


# @pytest.fixture
# def model_base() -> SentenceTransformer:
#     model = SentenceTransformer(MODEL)
#     if model.device == "cuda":
#         model = model.to(torch.float16)
#     return model


@pytest.fixture()
async def client():
async with AsyncClient(app=app, base_url="http://test", timeout=20) as client, LifespanManager(
app
):
yield client


# def test_load_model(model_base):
# # this makes sure that the error below is not based on a slow download
# # or internal pytorch errors
# model_base.encode(["This is a test sentence."])


@pytest.mark.anyio
async def test_model_route(client):
response = await client.get(f"{PREFIX}/models")
assert response.status_code == 200
rdata = response.json()
assert "data" in rdata
assert rdata["data"][0].get("id", "") == MODEL
assert isinstance(rdata["data"][0].get("stats"), dict)


@pytest.mark.anyio
async def test_embedding(client):
response = await client.post(
f"{PREFIX}/embeddings", json=dict(input=["This is a test", "hi", "hi"], model=MODEL)
)
assert response.status_code == 200
rdata = response.json()
assert "data" in rdata
assert len(rdata["data"]) == 3
# TODO: Check if start and end tokens should be embedded
# TODO: Check if normalization is applied or should be applied?
assert len(rdata["data"][0]["embedding"]) == 6 # This is a test -> 6 tokens
assert len(rdata["data"][1]["embedding"]) == 3 # hi -> 3 tokens
np.testing.assert_allclose(
rdata["data"][1]["embedding"], rdata["data"][2]["embedding"], atol=5e-3
)
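Since each input now yields one vector per token, the embedding field in the response is a nested list whose outer length is the token count, which is what the assertions above verify. A hedged sketch of consuming such a response from a running server (the URL prefix comes from this test and the port from the Makefile; both are assumptions about any given deployment):

import httpx
import numpy as np

resp = httpx.post(
    "http://127.0.0.1:7997/v1_sentence_transformers_colbert/embeddings",
    json={"input": ["This is a test", "hi"], "model": "michaelfeil/colbert-tiny-random"},
    timeout=20,
)
resp.raise_for_status()
token_matrices = [np.asarray(item["embedding"]) for item in resp.json()["data"]]
print([m.shape for m in token_matrices])  # e.g. [(6, dim), (3, dim)]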
