Skip to content

Commit

Permalink
Follow up PR for Audio End to End testing (#390)
Browse files Browse the repository at this point in the history
* update readme

* extract audio related code into audio utils

* add test cases for audio and vision

* revert docs v2

* revert docs v2

* fix test cases

* add test for text only vision case

* add text only case for audio

* format code

* skip text test for now to see updated coverage

* revert cli doc from main branch

* add changes to support text and urls

* address comments. Report correct usage

* update endpoint usage for mixed case. Extend vision cases
  • Loading branch information
wirthual authored Oct 1, 2024
1 parent 638205c commit 5881a74
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 4 deletions.
6 changes: 6 additions & 0 deletions libs/infinity_emb/tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np
import pytest
from numpy import dot
from numpy.linalg import norm


class Helpers:
Expand Down Expand Up @@ -98,6 +100,10 @@ async def embedding_verify(client, model_base, prefix, model_name, decimal=3):
embedding["embedding"], st_embedding, decimal=decimal
)

@staticmethod
def cosine_similarity(a, b):
return dot(a, b) / (norm(a) * norm(b))


@pytest.fixture
def helpers():
Expand Down
49 changes: 47 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ async def test_audio_single(client):


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
text = "a sound of a at"

response = await client.post(
f"{PREFIX}/embeddings_audio",
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text},
)
assert response.status_code == 200
Expand All @@ -79,6 +78,43 @@ async def test_audio_single_text_only(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_meta(client, helpers):
    """The beep clip's audio embedding must be closer to "a beep" than to other texts.

    Text captions are embedded through the ``embeddings`` route and the audio
    URL through ``embeddings_audio``; the results are compared with cosine
    similarity.
    """
    beep_wav = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
    captions = ["a beep", "a horse", "a fish"]

    text_reply = await client.post(
        f"{PREFIX}/embeddings",
        json={"model": MODEL, "input": captions},
    )
    audio_reply = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": [beep_wav]},
    )

    assert text_reply.status_code == 200
    assert audio_reply.status_code == 200

    text_rows = text_reply.json()["data"]
    audio_rows = audio_reply.json()["data"]

    audio_vec = audio_rows[0]["embedding"]
    # Rows are read in the same order as the captions were submitted.
    beep_vec, horse_vec, fish_vec = (text_rows[i]["embedding"] for i in range(3))

    matching_score = helpers.cosine_similarity(audio_vec, beep_vec)
    assert matching_score > helpers.cosine_similarity(audio_vec, fish_vec)
    assert matching_score > helpers.cosine_similarity(audio_vec, horse_vec)


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
Expand Down Expand Up @@ -120,3 +156,12 @@ async def test_audio_empty(client):
json={"model": MODEL, "input": audio_url_empty},
)
assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
    """Posting to the ``classify`` route with this model must yield HTTP 400."""
    route = f"{PREFIX}/classify"
    payload = {"model": MODEL, "input": ["test"]}

    classify_reply = await client.post(route, json=payload)

    assert classify_reply.status_code == status.HTTP_400_BAD_REQUEST
49 changes: 47 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ async def test_vision_single(client):


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_vision_single_text_only(client):
text = "a image of a cat"

response = await client.post(
f"{PREFIX}/embeddings_image",
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text},
)
assert response.status_code == 200
Expand All @@ -79,6 +78,43 @@ async def test_vision_single_text_only(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_meta(client, helpers):
    """The image embedding must be closer to "a cat" than to the other captions.

    Text captions are embedded through the ``embeddings`` route and the image
    URL through ``embeddings_image``; the results are compared with cosine
    similarity.
    """
    picture_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    captions = ["a cat", "a car", "a fridge"]

    text_reply = await client.post(
        f"{PREFIX}/embeddings",
        json={"model": MODEL, "input": captions},
    )
    image_reply = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": [picture_url]},
    )

    assert text_reply.status_code == 200
    assert image_reply.status_code == 200

    text_rows = text_reply.json()["data"]
    image_rows = image_reply.json()["data"]

    image_vec = image_rows[0]["embedding"]
    # Rows are read in the same order as the captions were submitted.
    cat_vec, car_vec, fridge_vec = (text_rows[i]["embedding"] for i in range(3))

    matching_score = helpers.cosine_similarity(image_vec, cat_vec)
    assert matching_score > helpers.cosine_similarity(image_vec, car_vec)
    assert matching_score > helpers.cosine_similarity(image_vec, fridge_vec)


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_images", [1, 5, 10])
async def test_vision_multiple(client, no_of_images):
Expand Down Expand Up @@ -119,3 +155,12 @@ async def test_vision_empty(client):
json={"model": MODEL, "input": image_url_empty},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
    """Posting to the ``classify`` route with this model must yield HTTP 400."""
    route = f"{PREFIX}/classify"
    payload = {"model": MODEL, "input": ["test"]}

    classify_reply = await client.post(route, json=payload)

    assert classify_reply.status_code == status.HTTP_400_BAD_REQUEST

0 comments on commit 5881a74

Please sign in to comment.