diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index 8fff57c0..0983c45d 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -1,4 +1,8 @@ +import base64 + +import numpy as np import pytest +import requests import torch from asgi_lifespan import LifespanManager from fastapi import status @@ -46,7 +50,7 @@ async def test_model_route(client): @pytest.mark.anyio async def test_audio_single(client): - audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + audio_url = pytest.AUDIO_SAMPLE_URL response = await client.post( f"{PREFIX}/embeddings_audio", @@ -80,7 +84,7 @@ async def test_audio_single_text_only(client): @pytest.mark.anyio async def test_meta(client, helpers): - audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + audio_url = pytest.AUDIO_SAMPLE_URL text_input = ["a beep", "a horse", "a fish"] audio_input = [audio_url] @@ -119,9 +123,7 @@ async def test_meta(client, helpers): async def test_audio_multiple(client): for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]: for no_of_audios in [1, 5, 10]: - audio_urls = [ - "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" - ] * no_of_audios + audio_urls = [pytest.AUDIO_SAMPLE_URL] * no_of_audios response = await client.post( route, @@ -141,6 +143,52 @@ async def test_audio_multiple(client): assert len(rdata_results[0]["embedding"]) > 0 +@pytest.mark.anyio +async def test_audio_base64(client): + bytes_downloaded = requests.get(pytest.AUDIO_SAMPLE_URL).content + base_64_audio = base64.b64encode(bytes_downloaded).decode("utf-8") + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={ + "model": MODEL, + "input": [ + "data:audio/wav;base64," + base_64_audio, + pytest.AUDIO_SAMPLE_URL, + ], + }, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + np.testing.assert_array_equal( + rdata_results[0]["embedding"], rdata_results[1]["embedding"] + ) + + +@pytest.mark.anyio +async def test_audio_base64_fail(client): + base_64_audio = "somethingsomething" + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={ + "model": MODEL, + "input": [ + "data:audio/wav;base64," + base_64_audio, + pytest.AUDIO_SAMPLE_URL, + ], + }, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + @pytest.mark.anyio async def test_audio_fail(client): for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]: