diff --git a/tests/handler_async_test.py b/tests/handler_async_test.py index e8206cb..0ab46a7 100644 --- a/tests/handler_async_test.py +++ b/tests/handler_async_test.py @@ -1,11 +1,37 @@ +import json import os from ipinfo.cache.default import DefaultCache from ipinfo.details import Details from ipinfo.handler_async import AsyncHandler from ipinfo import handler_utils +from ipinfo.error import APIError +from ipinfo.exceptions import RequestQuotaExceededError import ipinfo import pytest +import aiohttp + + +class MockResponse: + def __init__(self, text, status, headers): + self._text = text + self.status = status + self.headers = headers + + def text(self): + return self._text + + async def json(self): + return json.loads(self._text) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def __aenter__(self): + return self + + async def release(self): + pass @pytest.mark.asyncio @@ -103,6 +129,39 @@ async def test_get_details(): await handler.deinit() +@pytest.mark.parametrize( + ("mock_resp_status_code", "mock_resp_headers", "mock_resp_error_msg", "expected_error_json"), + [ + pytest.param(503, {"Content-Type": "text/plain"}, "Service Unavailable", {"error": "Service Unavailable"}, id="5xx_not_json"), + pytest.param(403, {"Content-Type": "application/json"}, '{"message": "missing token"}', {"message": "missing token"}, id="4xx_json"), + pytest.param(400, {"Content-Type": "application/json"}, '{"message": "missing field"}', {"message": "missing field"}, id="400"), + ] +) +@pytest.mark.asyncio +async def test_get_details_error(monkeypatch, mock_resp_status_code, mock_resp_headers, mock_resp_error_msg, expected_error_json): + async def mock_get(*args, **kwargs): + response = MockResponse(status=mock_resp_status_code, text=mock_resp_error_msg, headers=mock_resp_headers) + return response + + monkeypatch.setattr(aiohttp.ClientSession, 'get', lambda *args, **kwargs: aiohttp.client._RequestContextManager(mock_get())) + token = os.environ.get("IPINFO_TOKEN", "") + handler = AsyncHandler(token) + with pytest.raises(APIError) as exc_info: + await handler.getDetails("8.8.8.8") + assert exc_info.value.error_code == mock_resp_status_code + assert exc_info.value.error_json == expected_error_json + +@pytest.mark.asyncio +async def test_get_details_quota_error(monkeypatch): + async def mock_get(*args, **kwargs): + response = MockResponse(status=429, text="Quota exceeded", headers={}) + return response + + monkeypatch.setattr(aiohttp.ClientSession, 'get', lambda *args, **kwargs: aiohttp.client._RequestContextManager(mock_get())) + token = os.environ.get("IPINFO_TOKEN", "") + handler = AsyncHandler(token) + with pytest.raises(RequestQuotaExceededError): + await handler.getDetails("8.8.8.8") ############# # BATCH TESTS diff --git a/tests/handler_test.py b/tests/handler_test.py index 8b1100e..3767622 100644 --- a/tests/handler_test.py +++ b/tests/handler_test.py @@ -1,13 +1,14 @@ -from ipaddress import IPv4Address -import json import os from ipinfo.cache.default import DefaultCache from ipinfo.details import Details from ipinfo.handler import Handler from ipinfo import handler_utils +from ipinfo.error import APIError +from ipinfo.exceptions import RequestQuotaExceededError import ipinfo import pytest +import requests def test_init(): @@ -98,6 +99,43 @@ def test_get_details(): assert "total" in domains assert len(domains["domains"]) == 5 +@pytest.mark.parametrize( + ("mock_resp_status_code", "mock_resp_headers", "mock_resp_error_msg", "expected_error_json"), + [ + pytest.param(503, {"Content-Type": "text/plain"}, b"Service Unavailable", {"error": "Service Unavailable"}, id="5xx_not_json"), + pytest.param(403, {"Content-Type": "application/json"}, b'{"message": "missing token"}', {"message": "missing token"}, id="4xx_json"), + pytest.param(400, {"Content-Type": "application/json"}, b'{"message": "missing field"}', {"message": "missing field"}, id="400"), + ] +) +def test_get_details_error(monkeypatch, mock_resp_status_code, mock_resp_headers, mock_resp_error_msg, expected_error_json): + def mock_get(*args, **kwargs): + response = requests.Response() + response.status_code = mock_resp_status_code + response.headers = mock_resp_headers + response._content = mock_resp_error_msg + return response + + monkeypatch.setattr(requests, 'get', mock_get) + token = os.environ.get("IPINFO_TOKEN", "") + handler = Handler(token) + + with pytest.raises(APIError) as exc_info: + handler.getDetails("8.8.8.8") + assert exc_info.value.error_code == mock_resp_status_code + assert exc_info.value.error_json == expected_error_json + +def test_get_details_quota_error(monkeypatch): + def mock_get(*args, **kwargs): + response = requests.Response() + response.status_code = 429 + return response + + monkeypatch.setattr(requests, 'get', mock_get) + token = os.environ.get("IPINFO_TOKEN", "") + handler = Handler(token) + + with pytest.raises(RequestQuotaExceededError): + handler.getDetails("8.8.8.8") ############# # BATCH TESTS