Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow up PR for Audio End to End testing #390

Merged
merged 16 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,17 @@ class ImageEmbeddingInput(BaseModel):
user: Optional[str] = None


class AudioEmbeddingInput(ImageEmbeddingInput):
pass
class AudioEmbeddingInput(BaseModel):
input: Union[ # type: ignore
conlist( # type: ignore
Union[Annotated[AnyUrl, HttpUrl], str],
**ITEMS_LIMIT_SMALL,
),
Union[Annotated[AnyUrl, HttpUrl], str],
]
model: str = "default/not-specified"
encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float
user: Optional[str] = None


class _EmbeddingObject(BaseModel):
Expand Down
44 changes: 35 additions & 9 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
from contextlib import asynccontextmanager
from typing import Any, Optional
from urllib.parse import urlparse

import infinity_emb
from infinity_emb._optional_imports import CHECK_TYPER, CHECK_UVICORN
Expand Down Expand Up @@ -411,29 +412,54 @@ async def _embeddings_audio(data: AudioEmbeddingInput):
json={"model":"laion/larger_clap_general","input":["https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"]})
"""
engine = _resolve_engine(data.model)
if hasattr(data.input, "host"):
# if it is a single url
audio_inputs = [str(data.input)]
input_list: list[str] = []
if isinstance(data.input, str):
input_list.append(data.input)
else:
audio_inputs = [str(d) for d in data.input] # type: ignore
input_list = data.input # type: ignore
wirthual marked this conversation as resolved.
Show resolved Hide resolved
audio_urls = []
texts = []
is_audios = []
for input in input_list:
parsed_url = urlparse(input)
# Todo: Improve url check
if parsed_url.netloc and parsed_url.scheme:
# if it is a single url
audio_urls.append(str(input))
is_audios.append(True)
else:
texts.append(input) # type: ignore
is_audios.append(False)
try:
logger.debug("[📝] Received request with %s Urls ", len(audio_inputs))
logger.debug(
f"[📝] Received request with {len(audio_urls)} Urls and {len(texts)} sentences"
)
start = time.perf_counter()

embedding, usage = await engine.audio_embed(audios=audio_inputs) # type: ignore
if audio_urls:
audio_embeddings, usage = await engine.audio_embed(audios=audio_urls) # type: ignore
if texts:
text_embeddings, usage = await engine.embed(sentences=texts)
wirthual marked this conversation as resolved.
Show resolved Hide resolved

embeddings_with_restored_order = []
for is_audio in is_audios:
if is_audio:
embeddings_with_restored_order.append(audio_embeddings.pop(0))
else:
embeddings_with_restored_order.append(text_embeddings.pop(0))

duration = (time.perf_counter() - start) * 1000
logger.debug("[✅] Done in %s ms", duration)
logger.debug(f"[✅] Done in {duration} ms")

return OpenAIEmbeddingResult.to_embeddings_response(
embeddings=embedding,
embeddings=embeddings_with_restored_order,
engine_args=engine.engine_args,
encoding_format=data.encoding_format,
usage=usage,
)
except AudioCorruption as ex:
raise errors.OpenAIException(
f"AudioCorruption, could not open {audio_inputs} -> {ex}",
f"AudioCorruption, could not open {audio_urls} -> {ex}",
code=status.HTTP_400_BAD_REQUEST,
)
except ModelNotDeployedError as ex:
Expand Down
56 changes: 55 additions & 1 deletion libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
)


def cosine_similarity(a, b):
from numpy import dot
from numpy.linalg import norm

return dot(a, b) / (norm(a) * norm(b))


@pytest.fixture()
async def client():
async with AsyncClient(
Expand Down Expand Up @@ -62,7 +69,6 @@ async def test_audio_single(client):


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
text = "a sound of a at"

Expand All @@ -79,6 +85,54 @@ async def test_audio_single_text_only(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_input_pairs", [1, 5])
async def test_audio_text_url_mixed(client, no_of_input_pairs):
text = "a sound of a at"
audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

input = [text, audio_url] * no_of_input_pairs

response = await client.post(
f"{PREFIX}/embeddings_audio",
json={"model": MODEL, "input": input},
)
assert response.status_code == 200
rdata = response.json()
assert "model" in rdata
assert "usage" in rdata
rdata_results = rdata["data"]
assert rdata_results[0]["object"] == "embedding"
assert len(rdata_results[0]["embedding"]) > 0
assert len(rdata_results) == len(input)


@pytest.mark.anyio
async def test_meta(client):
audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

input = [audio_url, "a beep", "a horse", "a fish"]

response = await client.post(
f"{PREFIX}/embeddings_audio",
json={"model": MODEL, "input": input},
)
assert response.status_code == 200
rdata = response.json()
rdata_results = rdata["data"]

embeddings_audio_beep = rdata_results[0]["embedding"]
embeddings_text_beep = rdata_results[1]["embedding"]
embeddings_text_horse = rdata_results[2]["embedding"]
embeddings_text_fish = rdata_results[3]["embedding"]
assert cosine_similarity(
embeddings_audio_beep, embeddings_text_beep
) > cosine_similarity(embeddings_audio_beep, embeddings_text_fish)
assert cosine_similarity(
embeddings_audio_beep, embeddings_text_beep
) > cosine_similarity(embeddings_audio_beep, embeddings_text_horse)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Add a small tolerance to comparisons to account for floating-point precision

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a greater than is fine here I guess



@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
Expand Down
Loading