Skip to content

Commit

Permalink
chore: add timeout for api
Browse files Browse the repository at this point in the history
  • Loading branch information
nquang29 committed Jan 17, 2025
1 parent 2d977da commit ac328e5
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 5 deletions.
56 changes: 54 additions & 2 deletions app/lib/backend/http/api/conversations.dart
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ Future<List<ServerConversation>> getConversations(
int offset = 0,
List<ConversationStatus> statuses = const [],
bool includeDiscarded = true}) async {
var response = await makeApiCall(

int segmentLimit = 50;

var response = await makeApiCall(
url:
'${Env.apiBaseUrl}v1/memories?include_discarded=$includeDiscarded&limit=$limit&offset=$offset&statuses=${statuses.map((val) => val.toString().split(".").last).join(",")}',
'${Env.apiBaseUrl}v1/memories?include_discarded=$includeDiscarded&limit=$limit&offset=$offset&statuses=${statuses.map((val) => val.toString().split(".").last).join(",")}&segment_limit=$segmentLimit',
headers: {},
method: 'GET',
body: '');
Expand All @@ -54,13 +57,62 @@ Future<List<ServerConversation>> getConversations(
var memories =
(jsonDecode(body) as List<dynamic>).map((conversation) => ServerConversation.fromJson(conversation)).toList();
debugPrint('getMemories length: ${memories.length}');

for (var memory in memories) {
if (memory.transcriptSegments.length < segmentLimit) {
continue;
}
// Get all transcript segments for this memory, with paging
List<TranscriptSegment> allSegments = [];
int segmentOffset = memory.transcriptSegments.length;

while (true) {
var segments = await getTranscriptSegmentsForConversation(memory.id, segmentLimit, offset:segmentOffset);
if (segments.isEmpty) break;

allSegments.addAll(segments);
segmentOffset += segmentLimit;
}
memory.addTranscriptSegments(allSegments);
}

return memories;
} else {
debugPrint('getMemories error ${response.statusCode}');
}
return [];
}


Future<List<TranscriptSegment>> getTranscriptSegmentsForConversation(String conversationId, int? limit, {int offset = 0}) async {
var url = '${Env.apiBaseUrl}v1/memories/$conversationId/transcript_segments';
if (limit != null) {
url += '?limit=$limit&offset=$offset';
}

var response = await makeApiCall(
url: url,
headers: {},
method: 'GET',
body: ''
);

if (response == null) return [];

if (response.statusCode == 200) {
var body = utf8.decode(response.bodyBytes);
var segments = (jsonDecode(body) as List<dynamic>).map((segment) {
return TranscriptSegment.fromJson(segment);
}).toList();

return segments;
} else {
debugPrint('getTranscriptSegmentsForConversation error ${response.statusCode}');
}

return [];
}

Future<ServerConversation?> reProcessConversationServer(String conversationId) async {
var response = await makeApiCall(
url: '${Env.apiBaseUrl}v1/memories/$conversationId/reprocess',
Expand Down
4 changes: 4 additions & 0 deletions app/lib/backend/schema/conversation.dart
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ class ServerConversation {
return transcript;
}
}

void addTranscriptSegments(List<TranscriptSegment> newSegments) {
transcriptSegments.addAll(newSegments);
}
}

class SyncLocalFilesResponse {
Expand Down
11 changes: 11 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from modal import Image, App, asgi_app, Secret
from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \
speech_profile, agents, facts, users, processing_memories, trends, sdcard, sync, apps, custom_auth, payment
from utils.other.timeout import TimeoutMiddleware

if os.environ.get('SERVICE_ACCOUNT_JSON'):
service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
Expand Down Expand Up @@ -40,6 +41,16 @@

app.include_router(payment.router)


methods_timeout = {
"GET": os.environ.get('HTTP_GET_TIMEOUT'),
"PUT": os.environ.get('HTTP_PUT_TIMEOUT'),
"PATCH": os.environ.get('HTTP_PATCH_TIMEOUT'),
"DELETE": os.environ.get('HTTP_DELETE_TIMEOUT'),
}

app.add_middleware(TimeoutMiddleware,methods_timeout=methods_timeout)

modal_app = App(
name='backend',
secrets=[Secret.from_name("gcp-credentials"), Secret.from_name('envs')],
Expand Down
18 changes: 15 additions & 3 deletions backend/routers/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,28 @@ def reprocess_memory(


@router.get('/v1/memories', response_model=List[Memory], tags=['memories'])
def get_memories(limit: int = 100, offset: int = 0, statuses: str = "", include_discarded: bool = True, uid: str = Depends(auth.get_current_user_uid)):
def get_memories(limit: int = 100, offset: int = 0, statuses: str = "", include_discarded: bool = True, segment_limit: int = None, uid: str = Depends(auth.get_current_user_uid)):
print('get_memories', uid, limit, offset, statuses)
return memories_db.get_memories(uid, limit, offset, include_discarded=include_discarded,
memories = memories_db.get_memories(uid, limit, offset, include_discarded=include_discarded,
statuses=statuses.split(",") if len(statuses) > 0 else [])

if segment_limit is not None:
for memory in memories:
memory["transcript_segments"] = memory.get("transcript_segments", [])[:segment_limit]
return memories

@router.get("/v1/memories/{memory_id}", response_model=Memory, tags=['memories'])
def get_memory_by_id(memory_id: str, uid: str = Depends(auth.get_current_user_uid)):
return _get_memory_by_id(uid, memory_id)

@router.get("/v1/memories/{memory_id}/transcript_segments", response_model=List[TranscriptSegment], tags=['memories'])
def get_transcript_segments_memory(memory_id: str, uid: str = Depends(auth.get_current_user_uid), limit: int = None, offset: int = 0):
memory = _get_memory_by_id(uid, memory_id)
transcript_segments = memory["transcript_segments"]
if limit is not None:
transcript_segments = transcript_segments[offset:limit+offset]
return transcript_segments



@router.patch("/v1/memories/{memory_id}/title", tags=['memories'])
def patch_memory_title(memory_id: str, title: str, uid: str = Depends(auth.get_current_user_uid)):
Expand Down
40 changes: 40 additions & 0 deletions backend/utils/other/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import Request
import asyncio
import os

class TimeoutMiddleware(BaseHTTPMiddleware):
def __init__(self, app, methods_timeout: dict = None):
super().__init__(app)

self.default_timeout = self._get_timeout_from_env("HTTP_DEFAULT_TIMEOUT", default=2 * 60)

self.methods_timeout = self._parse_methods_timeout(methods_timeout or {})

@staticmethod
def _get_timeout_from_env(env_var: str, default: float) -> float:
timeout = os.environ.get(env_var, default)
try:
return float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value in env {env_var}: {timeout}")

@staticmethod
def _parse_methods_timeout(methods_timeout: dict) -> dict:
result = {}
for method, timeout in methods_timeout.items():
if timeout is None:
continue
try:
result[method.upper()] = float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value for method {method}: {timeout}")
return result

async def dispatch(self, request: Request, call_next):
timeout = self.methods_timeout.get(request.method, self.default_timeout)
try:
return await asyncio.wait_for(call_next(request), timeout=timeout)
except asyncio.TimeoutError:
return Response(status_code=504)

0 comments on commit ac328e5

Please sign in to comment.