Skip to content

Commit

Permalink
add verification check (#374)
Browse files Browse the repository at this point in the history
* add verification check

* format
  • Loading branch information
michaelfeil authored Sep 24, 2024
1 parent 5ab8d51 commit 83d5a3a
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions libs/infinity_emb/infinity_emb/transformer/vision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,22 @@ def resolve_from_img_url(img_url: str) -> ImageSingle:
try:
downloaded_img = requests.get(img_url, stream=True).raw
except Exception as e:
raise ImageCorruption(f"error downloading image from url: {e}")
raise ImageCorruption(
f"error opening an image in your request image from url: {e}"
)

try:
return ImageSingle(image=Image.open(downloaded_img))
img = Image.open(downloaded_img)
if img.size[0] < 3 or img.size[1] < 3:
# https://upload.wikimedia.org/wikipedia/commons/c/ca/1x1.png
raise ImageCorruption(
f"An image in your request is too small for processing {img.size}"
)
return ImageSingle(image=img)
except Exception as e:
raise ImageCorruption(f"error opening image from url: {e}")
raise ImageCorruption(
f"error opening the payload from an image in your request from url: {e}"
)


def resolve_image(img: Union[str, "ImageClassType"]) -> ImageSingle:
Expand Down Expand Up @@ -70,7 +80,7 @@ def resolve_images(images: List[Union[str, "ImageClassType"]]) -> List[ImageSing
return resolved_imgs


def resolve_audio(audio: Union[str, bytes]) -> AudioSingle:
def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> AudioSingle:
if isinstance(audio, bytes):
try:
audio_bytes = io.BytesIO(audio)
Expand All @@ -85,6 +95,10 @@ def resolve_audio(audio: Union[str, bytes]) -> AudioSingle:

try:
data, rate = sf.read(audio_bytes)
if rate != allowed_sampling_rate:
raise AudioCorruption(
f"Audio sample rate is not {allowed_sampling_rate}Mhz, it is {rate}Mhz."
)
return AudioSingle(audio=data, sampling_rate=rate)
except Exception as e:
raise AudioCorruption(f"Error opening audio: {e}.\nError msg: {str(e)}")
Expand All @@ -100,7 +114,7 @@ def resolve_audios(
resolved_audios: list[AudioSingle] = []
for audio in audio_urls:
try:
audio_single = resolve_audio(audio)
audio_single = resolve_audio(audio, allowed_sampling_rate)
resolved_audios.append(audio_single)
except Exception as e:
raise AudioCorruption(f"Failed to resolve audio: {e}")
Expand Down

0 comments on commit 83d5a3a

Please sign in to comment.