Skip to content

Commit

Permalink
Merge pull request #183 from michaelfeil/pydatnic-cli-validation
Browse files Browse the repository at this point in the history
pydantic cli / args validation
  • Loading branch information
michaelfeil authored Mar 30, 2024
2 parents 7b10965 + 8bbea84 commit 1f0258c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
3 changes: 2 additions & 1 deletion libs/infinity_emb/infinity_emb/_optional_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def mark_dirty(self, exception: Exception) -> None:
self._marked_as_dirty = exception

def mark_required(self) -> bool:
if self.is_available or self._marked_as_dirty:
if not self.is_available or self._marked_as_dirty:
self._raise_error()
return True

Expand All @@ -49,6 +49,7 @@ def _raise_error(self) -> None:
)
if self._marked_as_dirty:
raise ImportError(msg) from self._marked_as_dirty
raise ImportError(msg)


CHECK_DISKCACHE = OptionalImports("diskcache", "cache")
Expand Down
27 changes: 25 additions & 2 deletions libs/infinity_emb/infinity_emb/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Optional

from infinity_emb._optional_imports import CHECK_PYDANTIC
from infinity_emb.primitives import (
Device,
Dtype,
Expand All @@ -11,6 +12,8 @@
PoolingMethod,
)

if CHECK_PYDANTIC.is_available:
from pydantic.dataclasses import dataclass as dataclass_pydantic
# if python>=3.10 use kw_only
dataclass_args = {"kw_only": True} if sys.version_info >= (3, 10) else {}

Expand All @@ -21,7 +24,6 @@ class EngineArgs:
Args:
model_name_or_path, str: Defaults to "michaelfeil/bge-small-en-v1.5".
served_model_name, str: Defaults to bge-small-en-v1.5
batch_size, int: Defaults to 32.
revision, str: Defaults to None.
trust_remote_code, bool: Defaults to True.
Expand All @@ -36,10 +38,10 @@ class EngineArgs:
dtype, Dtype or str: data type to use for inference. Defaults to Dtype.auto.
pooling_method, PoolingMethod or str: pooling method to use. Defaults to PoolingMethod.auto.
lengths_via_tokenize, bool: schedule by token usage. Defaults to False.
served_model_name, str: Defaults to readable name of model_name_or_path.
"""

model_name_or_path: str = "michaelfeil/bge-small-en-v1.5"
served_model_name: Optional[str] = None
batch_size: int = 32
revision: Optional[str] = None
trust_remote_code: bool = True
Expand All @@ -53,6 +55,7 @@ class EngineArgs:
pooling_method: PoolingMethod = PoolingMethod.auto
lengths_via_tokenize: bool = False
embedding_dtype: EmbeddingDtype = EmbeddingDtype.float32
served_model_name: str = None # type: ignore

def __post_init__(self):
# convert the following strings to enums
Expand All @@ -71,3 +74,23 @@ def __post_init__(self):
object.__setattr__(
self, "embedding_dtype", EmbeddingDtype[self.embedding_dtype]
)
if self.served_model_name is None:
object.__setattr__(
self,
"served_model_name",
"/".join(self.model_name_or_path.split("/")[-2:]),
)

# after all done -> check if the dataclass is valid
if CHECK_PYDANTIC.is_available:
# convert to pydantic dataclass
# and check if the dataclass is valid
@dataclass_pydantic(frozen=True, **dataclass_args)
class EngineArgsPydantic(EngineArgs):
def __post_init__(self):
# overwrite the __post_init__ method
# to avoid infinite recursion
pass

# validate
EngineArgsPydantic(**self.__dict__)
12 changes: 4 additions & 8 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def create_server(
instrumentator = Instrumentator().instrument(app)
app.add_exception_handler(errors.OpenAIException, errors.openai_exception_handler)

MODEL_RESPONSE_NAME = engine_args.served_model_name or "/".join(
engine_args.model_name_or_path.split("/")[-2:]
)

@app.on_event("startup")
async def _startup():
instrumentator.expose(app)
Expand Down Expand Up @@ -109,7 +105,7 @@ async def _models():
return dict(
data=[
dict(
id=MODEL_RESPONSE_NAME,
id=engine_args.served_model_name,
stats=dict(
queue_fraction=s.queue_fraction,
queue_absolute=s.queue_absolute,
Expand Down Expand Up @@ -154,7 +150,7 @@ async def _embeddings(data: OpenAIEmbeddingInput):
logger.debug("[✅] Done in %s ms", duration)

res = list_embeddings_to_response(
embeddings=embedding, model=MODEL_RESPONSE_NAME, usage=usage
embeddings=embedding, model=engine_args.served_model_name, usage=usage
)

return res
Expand Down Expand Up @@ -202,7 +198,7 @@ async def _rerank(data: RerankInput):
res = to_rerank_response(
scores=scores,
documents=docs,
model=MODEL_RESPONSE_NAME,
model=engine_args.served_model_name,
usage=usage,
)

Expand Down Expand Up @@ -277,7 +273,6 @@ def _start_uvicorn(

engine_args = EngineArgs(
model_name_or_path=model_name_or_path,
served_model_name=served_model_name,
batch_size=batch_size,
revision=revision,
trust_remote_code=trust_remote_code,
Expand All @@ -290,6 +285,7 @@ def _start_uvicorn(
pooling_method=PoolingMethod[pooling_method.value], # type: ignore
compile=compile,
bettertransformer=bettertransformer,
served_model_name=served_model_name, # type: ignore
)

app = create_server(
Expand Down

0 comments on commit 1f0258c

Please sign in to comment.