diff --git a/backend/main.py b/backend/main.py index c2d090b87..748967fed 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,6 +7,7 @@ from modal import Image, App, asgi_app, Secret from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \ speech_profile, agents, facts, users, processing_memories, trends, sdcard, sync, apps, custom_auth, payment +from utils.other.timeout import TimeoutMiddleware if os.environ.get('SERVICE_ACCOUNT_JSON'): service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"]) @@ -40,6 +41,16 @@ app.include_router(payment.router) + +methods_timeout = { + "GET": os.environ.get('HTTP_GET_TIMEOUT'), + "PUT": os.environ.get('HTTP_PUT_TIMEOUT'), + "PATCH": os.environ.get('HTTP_PATCH_TIMEOUT'), + "DELETE": os.environ.get('HTTP_DELETE_TIMEOUT'), + } + +app.add_middleware(TimeoutMiddleware,methods_timeout=methods_timeout) + modal_app = App( name='backend', secrets=[Secret.from_name("gcp-credentials"), Secret.from_name('envs')], diff --git a/backend/utils/other/timeout.py b/backend/utils/other/timeout.py new file mode 100644 index 000000000..ef51c8b4a --- /dev/null +++ b/backend/utils/other/timeout.py @@ -0,0 +1,40 @@ +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response +from fastapi import Request +import asyncio +import os + +class TimeoutMiddleware(BaseHTTPMiddleware): + def __init__(self, app, methods_timeout: dict = None): + super().__init__(app) + + self.default_timeout = self._get_timeout_from_env("HTTP_DEFAULT_TIMEOUT", default=2 * 60) + + self.methods_timeout = self._parse_methods_timeout(methods_timeout or {}) + + @staticmethod + def _get_timeout_from_env(env_var: str, default: float) -> float: + timeout = os.environ.get(env_var, default) + try: + return float(timeout) + except ValueError: + raise ValueError(f"Invalid timeout value in env {env_var}: {timeout}") + + @staticmethod + def _parse_methods_timeout(methods_timeout: dict) -> dict: + result = {} + for method, timeout in methods_timeout.items(): + if timeout is None: + continue + try: + result[method.upper()] = float(timeout) + except ValueError: + raise ValueError(f"Invalid timeout value for method {method}: {timeout}") + return result + + async def dispatch(self, request: Request, call_next): + timeout = self.methods_timeout.get(request.method, self.default_timeout) + try: + return await asyncio.wait_for(call_next(request), timeout=timeout) + except asyncio.TimeoutError: + return Response(status_code=504)