Skip to content

Commit

Permalink
format. Use future annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
wirthual committed Dec 6, 2024
1 parent 7917c56 commit d8ad010
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
11 changes: 8 additions & 3 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeilfeil
from __future__ import annotations

from asyncio import Semaphore
from typing import Iterable, Iterator, Optional, Union
Expand Down Expand Up @@ -318,7 +319,7 @@ async def astop(self):
await engine.astop()

async def embed(
self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int]=None
self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple sentences
Expand Down Expand Up @@ -393,7 +394,11 @@ async def classify(
return await self[model].classify(sentences=sentences, raw_scores=raw_scores)

async def image_embed(
self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim:Optional[int]=None
self,
*,
model: str,
images: list[Union[str, "ImageClassType"]],
matryoshka_dim: Optional[int] = None,
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple images
Expand Down Expand Up @@ -432,7 +437,7 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine":
)

async def audio_embed(
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple audios
Expand Down
9 changes: 6 additions & 3 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
)

async def embed(
self, sentences: list[str], matryoshka_dim:Optional[int]=None
self, sentences: list[str], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.
Expand Down Expand Up @@ -240,7 +240,10 @@ async def classify(
return classifications, usage

async def image_embed(
self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim:Optional[int]=None
self,
*,
images: list[Union[str, "ImageClassType", bytes]],
matryoshka_dim: Optional[int] = None,
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a images and sentences to be embedded. Awaits until embedded.
Expand Down Expand Up @@ -269,7 +272,7 @@ async def image_embed(
return embeddings, usage

async def audio_embed(
self, *, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None
self, *, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule audios and sentences to be embedded. Awaits until embedded.
Expand Down
10 changes: 7 additions & 3 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def stop(self):
self.async_run(self.async_engine_array.astop).result()

@add_start_docstrings(AsyncEngineArray.embed.__doc__)
def embed(self, *, model: str, sentences: list[str], matryoshka_dim:Optional[int]=None):
def embed(self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int] = None):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.embed,
Expand Down Expand Up @@ -211,7 +211,9 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False
)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None):
def image_embed(
self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.image_embed,
Expand All @@ -221,7 +223,9 @@ def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka
)

@add_start_docstrings(AsyncEngineArray.audio_embed.__doc__)
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None):
def audio_embed(
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim: Optional[int] = None
):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.audio_embed,
Expand Down

0 comments on commit d8ad010

Please sign in to comment.