Merge pull request #50 from michaelfeil/add-rerank
Add rerank
michaelfeil authored Jan 7, 2024
2 parents 58f01f1 + 6be6b2a commit 7556042
Showing 3 changed files with 119 additions and 23 deletions.
28 changes: 27 additions & 1 deletion libs/infinity_emb/infinity_emb/fastapi_schemas/convert.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

from infinity_emb.primitives import EmbeddingReturnType

@@ -20,3 +20,29 @@ def list_embeddings_to_response(
        ],
        usage=dict(prompt_tokens=usage, total_tokens=usage),
    )
+
+
+def to_rerank_response(
+    scores: List[float],
+    model: str,
+    usage: int,
+    documents: Optional[List[str]] = None,
+) -> Dict[str, Any]:
+    if documents is None:
+        return dict(
+            model=model,
+            results=[
+                dict(relevance_score=score, index=count)
+                for count, score in enumerate(scores)
+            ],
+            usage=dict(prompt_tokens=usage, total_tokens=usage),
+        )
+    else:
+        return dict(
+            model=model,
+            results=[
+                dict(relevance_score=score, index=count, document=doc)
+                for count, (score, doc) in enumerate(zip(scores, documents))
+            ],
+            usage=dict(prompt_tokens=usage, total_tokens=usage),
+        )
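For reference, a quick sketch (not part of the commit) of the response shape this converter produces, using made-up scores, a placeholder model name, and an arbitrary token count:

```python
from infinity_emb.fastapi_schemas.convert import to_rerank_response

# Hypothetical inputs for illustration only.
resp = to_rerank_response(
    scores=[0.91, 0.17],
    model="my-reranker",
    usage=42,
    documents=["Munich is in Germany.", "Paris is in France."],
)
assert resp == {
    "model": "my-reranker",
    "results": [
        {"relevance_score": 0.91, "index": 0, "document": "Munich is in Germany."},
        {"relevance_score": 0.17, "index": 1, "document": "Paris is in France."},
    ],
    "usage": {"prompt_tokens": 42, "total_tokens": 42},
}
```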
27 changes: 27 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
@@ -35,6 +35,33 @@ class OpenAIEmbeddingResult(BaseModel):
    created: int = Field(default_factory=lambda: int(time.time()))


+class RerankInput(BaseModel):
+    query: Annotated[
+        str, StringConstraints(max_length=1024 * 10, strip_whitespace=True)
+    ]
+    documents: conlist(  # type: ignore
+        Annotated[str, StringConstraints(max_length=1024 * 10, strip_whitespace=True)],
+        min_length=1,
+        max_length=2048,
+    )
+    return_documents: bool = False
+
+
+class _ReRankObject(BaseModel):
+    relevance_score: float
+    index: int
+    document: Optional[str] = None
+
+
+class ReRankResult(BaseModel):
+    object: Literal["rerank"] = "rerank"
+    data: List[_ReRankObject]
+    model: str
+    usage: _Usage
+    id: str = Field(default_factory=lambda: f"infinity-{uuid4()}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+
+
class ModelInfo(BaseModel):
    id: str
    stats: Dict[str, Any]
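To make the constraints above concrete, a small sketch of how `RerankInput` validates requests (assuming pydantic v2, which is where `StringConstraints` and the `min_length`/`max_length` arguments of `conlist` come from):

```python
from pydantic import ValidationError

from infinity_emb.fastapi_schemas.pymodels import RerankInput

# Whitespace is stripped and each string is capped at 10 * 1024 characters.
ok = RerankInput(query="  Where is Munich?  ", documents=["Munich is in Germany."])
assert ok.query == "Where is Munich?"
assert ok.return_documents is False  # default

# An empty document list violates min_length=1 and is rejected.
try:
    RerankInput(query="Where is Munich?", documents=[])
except ValidationError as err:
    print(err.error_count(), "validation error(s)")
```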
87 changes: 65 additions & 22 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -2,18 +2,21 @@

# prometheus
import infinity_emb
+from infinity_emb.engine import AsyncEmbeddingEngine
from infinity_emb.fastapi_schemas import docs, errors
-from infinity_emb.fastapi_schemas.convert import list_embeddings_to_response
+from infinity_emb.fastapi_schemas.convert import (
+    list_embeddings_to_response,
+    to_rerank_response,
+)
from infinity_emb.fastapi_schemas.pymodels import (
    OpenAIEmbeddingInput,
    OpenAIEmbeddingResult,
    OpenAIModelInfo,
+    RerankInput,
)
from infinity_emb.inference import (
-    BatchHandler,
    Device,
    DeviceTypeHint,
-    select_model,
)
from infinity_emb.inference.caching_layer import INFINITY_CACHE_VECTORS
from infinity_emb.log_handler import UVICORN_LOG_LEVELS, logger
@@ -55,29 +58,23 @@ def create_server(
    vector_disk_cache_path = (
        f"{engine}_{model_name_or_path.replace('/','_')}" if vector_disk_cache else ""
    )
+    model_name_offical = "".join(model_name_or_path.split("/")[-2:])

    @app.on_event("startup")
    async def _startup():
        instrumentator.expose(app)

-        model, min_inference_t = select_model(
+        app.model = AsyncEmbeddingEngine(
            model_name_or_path=model_name_or_path,
            batch_size=batch_size,
            engine=engine,
            model_warmup=model_warmup,
-            device=device,
-        )
-
-        app.batch_handler = BatchHandler(
-            max_batch_size=batch_size,
-            model=model,
-            verbose=verbose,
-            batch_delay=min_inference_t / 2,
            vector_disk_cache_path=vector_disk_cache_path,
            device=device,
            lengths_via_tokenize=lengths_via_tokenize,
        )
-        # start in a threadpool
-        await app.batch_handler.spawn()
+        await app.model.astart()

        logger.info(
            docs.startup_message(
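The startup hunk above replaces the internal `select_model`/`BatchHandler` pair with the public `AsyncEmbeddingEngine`, so the server and library users now share one code path. A minimal standalone sketch of that engine API, using only the constructor arguments and methods visible in this diff and relying on defaults for the rest (an assumption):

```python
import asyncio

from infinity_emb.engine import AsyncEmbeddingEngine

async def main() -> None:
    # Constructor arguments mirror the ones passed in _startup above;
    # remaining parameters are assumed to have defaults.
    engine = AsyncEmbeddingEngine(
        model_name_or_path="BAAI/bge-small-en-v1.5",
        batch_size=64,
    )
    await engine.astart()  # takes over what BatchHandler.spawn() used to do
    try:
        embeddings, usage = await engine.embed(["A sentence to encode."])
        print(f"{len(embeddings)} embeddings, {usage} prompt tokens")
    finally:
        await engine.astop()  # counterpart of the old BatchHandler.shutdown()

asyncio.run(main())
```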
@@ -89,7 +86,7 @@ async def _startup():

@app.on_event("shutdown")
async def _shutdown():
await app.batch_handler.shutdown()
await app.model.astop()

@app.get("/ready")
async def _ready() -> float:
@@ -105,7 +102,7 @@ async def _ready() -> float:
    )
    async def _models():
        """get models endpoint"""
-        s = app.batch_handler.overload_status()  # type: ignore
+        s = app.model.overload_status()  # type: ignore
        return dict(
            data=dict(
                id=model_name_or_path,
@@ -132,8 +129,8 @@ async def _embeddings(data: OpenAIEmbeddingInput):
requests.post("http://..:7997/v1/embeddings",
json={"model":"bge-small-en-v1.5","input":["A sentence to encode."]})
"""
bh: BatchHandler = app.batch_handler # type: ignore
if bh.is_overloaded():
model: AsyncEmbeddingEngine = app.model # type: ignore
if model.is_overloaded():
raise errors.OpenAIException(
"model overloaded", code=status.HTTP_429_TOO_MANY_REQUESTS
)
@@ -142,19 +139,65 @@ async def _embeddings(data: OpenAIEmbeddingInput):
logger.debug("[📝] Received request with %s inputs ", len(data.input))
start = time.perf_counter()

embedding, usage = await bh.embed(data.input)
embedding, usage = await model.embed(data.input)

duration = (time.perf_counter() - start) * 1000
logger.debug("[✅] Done in %s ms", duration)

res = list_embeddings_to_response(
embeddings=embedding, model=data.model, usage=usage
embeddings=embedding, model=model_name_offical, usage=usage
)

return res
except Exception as ex:
raise errors.OpenAIException(
f"internal server error {ex}",
f"InternalServerError: {ex}",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

+    @app.post(
+        f"{url_prefix}/rerank",
+        # response_model=ReRankResult,
+        response_class=responses.ORJSONResponse,
+    )
+    async def _rerank(data: RerankInput):
+        """Rerank documents by relevance to a query
+        ```python
+        import requests
+        requests.post("http://..:7997/rerank",
+            json={"query":"Where is Munich?","documents":["Munich is in Germany."]})
+        """
+        model: AsyncEmbeddingEngine = app.model  # type: ignore
+        if model.is_overloaded():
+            raise errors.OpenAIException(
+                "model overloaded", code=status.HTTP_429_TOO_MANY_REQUESTS
+            )
+
+        try:
+            logger.debug("[📝] Received request with %s docs ", len(data.documents))
+            start = time.perf_counter()
+
+            scores, usage = await model.rerank(
+                query=data.query, docs=data.documents, raw_scores=False
+            )
+
+            duration = (time.perf_counter() - start) * 1000
+            logger.debug("[✅] Done in %s ms", duration)
+
+            if data.return_documents:
+                docs = data.documents
+            else:
+                docs = None
+
+            res = to_rerank_response(
+                scores=scores, documents=docs, model=model_name_offical, usage=usage
+            )
+
+            return res
+        except Exception as ex:
+            raise errors.OpenAIException(
+                f"InternalServerError: {ex}",
+                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+
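A client-side sketch of the new route (assuming a local server started with the defaults in this diff: port 7997 and the now-empty `url_prefix`, so the endpoint is plain `/rerank`):

```python
import requests

payload = {
    "query": "Where is Munich?",
    "documents": ["Munich is in Germany.", "The sky is blue."],
    "return_documents": False,  # set True to echo each document back
}
r = requests.post("http://localhost:7997/rerank", json=payload)
r.raise_for_status()

# The body follows to_rerank_response(): a "results" list of
# {"relevance_score", "index"} objects plus a "usage" dict.
for result in r.json()["results"]:
    print(result["index"], result["relevance_score"])
```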
@@ -164,7 +207,7 @@ async def _embeddings(data: OpenAIEmbeddingInput):
def _start_uvicorn(
    model_name_or_path: str = "BAAI/bge-small-en-v1.5",
    batch_size: int = 64,
-    url_prefix: str = "/v1",
+    url_prefix: str = "",
    host: str = "0.0.0.0",
    port: int = 7997,
    log_level: UVICORN_LOG_LEVELS = UVICORN_LOG_LEVELS.info.name,  # type: ignore
@@ -181,7 +224,7 @@ def _start_uvicorn(
        model_name_or_path, str: Huggingface model, e.g.
            "BAAI/bge-small-en-v1.5".
        batch_size, int: batch size for forward pass.
-        url_prefix, str: prefix for api. typically "/v1".
+        url_prefix, str: prefix for api. typically "".
        host, str: host-url, typically either "0.0.0.0" or "127.0.0.1".
        port, int: port that you want to expose.
        log_level: logging level.
