From 83d5a3a37b35b2e4006475af58c20c97b0134f3c Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:13:17 -0700 Subject: [PATCH] add verification check (#374) * add verification check * format --- .../infinity_emb/transformer/vision/utils.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/transformer/vision/utils.py b/libs/infinity_emb/infinity_emb/transformer/vision/utils.py index adb0a2b2..e529a54a 100644 --- a/libs/infinity_emb/infinity_emb/transformer/vision/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/vision/utils.py @@ -32,12 +32,22 @@ def resolve_from_img_url(img_url: str) -> ImageSingle: try: downloaded_img = requests.get(img_url, stream=True).raw except Exception as e: - raise ImageCorruption(f"error downloading image from url: {e}") + raise ImageCorruption( + f"error opening an image in your request image from url: {e}" + ) try: - return ImageSingle(image=Image.open(downloaded_img)) + img = Image.open(downloaded_img) + if img.size[0] < 3 or img.size[1] < 3: + # https://upload.wikimedia.org/wikipedia/commons/c/ca/1x1.png + raise ImageCorruption( + f"An image in your request is too small for processing {img.size}" + ) + return ImageSingle(image=img) except Exception as e: - raise ImageCorruption(f"error opening image from url: {e}") + raise ImageCorruption( + f"error opening the payload from an image in your request from url: {e}" + ) def resolve_image(img: Union[str, "ImageClassType"]) -> ImageSingle: @@ -70,7 +80,7 @@ def resolve_images(images: List[Union[str, "ImageClassType"]]) -> List[ImageSing return resolved_imgs -def resolve_audio(audio: Union[str, bytes]) -> AudioSingle: +def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> AudioSingle: if isinstance(audio, bytes): try: audio_bytes = io.BytesIO(audio) @@ -85,6 +95,10 @@ def resolve_audio(audio: Union[str, bytes]) -> AudioSingle: try: data, rate = sf.read(audio_bytes) + if rate != allowed_sampling_rate: + raise AudioCorruption( + f"Audio sample rate is not {allowed_sampling_rate}Mhz, it is {rate}Mhz." + ) return AudioSingle(audio=data, sampling_rate=rate) except Exception as e: raise AudioCorruption(f"Error opening audio: {e}.\nError msg: {str(e)}") @@ -100,7 +114,7 @@ def resolve_audios( resolved_audios: list[AudioSingle] = [] for audio in audio_urls: try: - audio_single = resolve_audio(audio) + audio_single = resolve_audio(audio, allowed_sampling_rate) resolved_audios.append(audio_single) except Exception as e: raise AudioCorruption(f"Failed to resolve audio: {e}")