From 5768743d222b157373aec62d3a20c2e8020dff19 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 11 Oct 2024 02:51:03 +0200 Subject: [PATCH 1/4] add base64 test for audio endpoint --- .../tests/end_to_end/test_torch_audio.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index 8fff57c0..d0985004 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -1,4 +1,8 @@ +import base64 + +import numpy as np import pytest +import requests import torch from asgi_lifespan import LifespanManager from fastapi import status @@ -119,9 +123,7 @@ async def test_meta(client, helpers): async def test_audio_multiple(client): for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]: for no_of_audios in [1, 5, 10]: - audio_urls = [ - "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" - ] * no_of_audios + audio_urls = [pytest.AUDIO_SAMPLE_URL] * no_of_audios response = await client.post( route, @@ -141,6 +143,34 @@ async def test_audio_multiple(client): assert len(rdata_results[0]["embedding"]) > 0 +@pytest.mark.anyio +async def test_audio_base64(client): + bytes_downloaded = requests.get(pytest.AUDIO_SAMPLE_URL).content + base_64_audio = base64.b64encode(bytes_downloaded).decode("utf-8") + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={ + "model": MODEL, + "input": [ + "data:audio/wav;base64," + base_64_audio, + pytest.AUDIO_SAMPLE_URL, + ], + }, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + np.testing.assert_array_equal( + rdata_results[0]["embedding"], rdata_results[1]["embedding"] + ) + + @pytest.mark.anyio async def test_audio_fail(client): for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]: From 7407ff95eb20420175163e7bca48858a124b4d0d Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 11 Oct 2024 02:53:26 +0200 Subject: [PATCH 2/4] use audio url constant --- libs/infinity_emb/tests/end_to_end/test_torch_audio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index d0985004..1f7ce7cd 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -50,7 +50,7 @@ async def test_model_route(client): @pytest.mark.anyio async def test_audio_single(client): - audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + audio_url = pytest.AUDIO_SAMPLE_URL response = await client.post( f"{PREFIX}/embeddings_audio", @@ -84,7 +84,7 @@ async def test_audio_single_text_only(client): @pytest.mark.anyio async def test_meta(client, helpers): - audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + audio_url = pytest.AUDIO_SAMPLE_URL text_input = ["a beep", "a horse", "a fish"] audio_input = [audio_url] From c7bc39ce5d606136d255cb5f736f292ce1703f59 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 11 Oct 2024 03:53:36 +0200 Subject: [PATCH 3/4] add failure test for base64 --- .../tests/end_to_end/test_torch_audio.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index 1f7ce7cd..ba3380cc 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -166,11 +166,30 @@ async def test_audio_base64(client): assert rdata_results[0]["object"] == "embedding" assert len(rdata_results[0]["embedding"]) > 0 + print(rdata_results[0]["embedding"]) np.testing.assert_array_equal( rdata_results[0]["embedding"], rdata_results[1]["embedding"] ) +@pytest.mark.anyio +async def test_audio_base64_fail(client): + base_64_audio = "somethingsomething" + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={ + "model": MODEL, + "input": [ + "data:audio/wav;base64," + base_64_audio, + pytest.AUDIO_SAMPLE_URL, + ], + }, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + @pytest.mark.anyio async def test_audio_fail(client): for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]: From f403324b2355c30c63c82956fec267557897681f Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 11 Oct 2024 04:00:06 +0200 Subject: [PATCH 4/4] remove print statement --- libs/infinity_emb/tests/end_to_end/test_torch_audio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index ba3380cc..0983c45d 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -166,7 +166,6 @@ async def test_audio_base64(client): assert rdata_results[0]["object"] == "embedding" assert len(rdata_results[0]["embedding"]) > 0 - print(rdata_results[0]["embedding"]) np.testing.assert_array_equal( rdata_results[0]["embedding"], rdata_results[1]["embedding"] )