Skip to content

Commit

Permalink
chore: add timeout for api (#1638)
Browse files Browse the repository at this point in the history
  • Loading branch information
beastoin authored Feb 17, 2025
2 parents 2efd100 + 7f166f1 commit 65bcc0b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
11 changes: 11 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from modal import Image, App, asgi_app, Secret
from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \
speech_profile, agents, facts, users, processing_memories, trends, sdcard, sync, apps, custom_auth, payment
from utils.other.timeout import TimeoutMiddleware

if os.environ.get('SERVICE_ACCOUNT_JSON'):
service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
Expand Down Expand Up @@ -40,6 +41,16 @@

app.include_router(payment.router)


methods_timeout = {
"GET": os.environ.get('HTTP_GET_TIMEOUT'),
"PUT": os.environ.get('HTTP_PUT_TIMEOUT'),
"PATCH": os.environ.get('HTTP_PATCH_TIMEOUT'),
"DELETE": os.environ.get('HTTP_DELETE_TIMEOUT'),
}

app.add_middleware(TimeoutMiddleware,methods_timeout=methods_timeout)

modal_app = App(
name='backend',
secrets=[Secret.from_name("gcp-credentials"), Secret.from_name('envs')],
Expand Down
40 changes: 40 additions & 0 deletions backend/utils/other/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import Request
import asyncio
import os

class TimeoutMiddleware(BaseHTTPMiddleware):
def __init__(self, app, methods_timeout: dict = None):
super().__init__(app)

self.default_timeout = self._get_timeout_from_env("HTTP_DEFAULT_TIMEOUT", default=2 * 60)

self.methods_timeout = self._parse_methods_timeout(methods_timeout or {})

@staticmethod
def _get_timeout_from_env(env_var: str, default: float) -> float:
timeout = os.environ.get(env_var, default)
try:
return float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value in env {env_var}: {timeout}")

@staticmethod
def _parse_methods_timeout(methods_timeout: dict) -> dict:
result = {}
for method, timeout in methods_timeout.items():
if timeout is None:
continue
try:
result[method.upper()] = float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value for method {method}: {timeout}")
return result

async def dispatch(self, request: Request, call_next):
timeout = self.methods_timeout.get(request.method, self.default_timeout)
try:
return await asyncio.wait_for(call_next(request), timeout=timeout)
except asyncio.TimeoutError:
return Response(status_code=504)

0 comments on commit 65bcc0b

Please sign in to comment.