Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add timeout for api #1638

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions app/lib/backend/http/api/conversations.dart
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ Future<List<ServerConversation>> getConversations(
int offset = 0,
List<ConversationStatus> statuses = const [],
bool includeDiscarded = true}) async {
var response = await makeApiCall(

int segmentLimit = 50;

var response = await makeApiCall(
url:
'${Env.apiBaseUrl}v1/memories?include_discarded=$includeDiscarded&limit=$limit&offset=$offset&statuses=${statuses.map((val) => val.toString().split(".").last).join(",")}',
'${Env.apiBaseUrl}v1/memories?include_discarded=$includeDiscarded&limit=$limit&offset=$offset&statuses=${statuses.map((val) => val.toString().split(".").last).join(",")}&segment_limit=$segmentLimit',
headers: {},
method: 'GET',
body: '');
Expand All @@ -54,13 +57,62 @@ Future<List<ServerConversation>> getConversations(
var memories =
(jsonDecode(body) as List<dynamic>).map((conversation) => ServerConversation.fromJson(conversation)).toList();
debugPrint('getMemories length: ${memories.length}');

for (var memory in memories) {
if (memory.transcriptSegments.length < segmentLimit) {
continue;
}
// Get all transcript segments for this memory, with paging
List<TranscriptSegment> allSegments = [];
int segmentOffset = memory.transcriptSegments.length;

while (true) {
var segments = await getTranscriptSegmentsForConversation(memory.id, segmentLimit, offset:segmentOffset);
if (segments.isEmpty) break;

allSegments.addAll(segments);
segmentOffset += segmentLimit;
}
memory.addTranscriptSegments(allSegments);
}

return memories;
} else {
debugPrint('getMemories error ${response.statusCode}');
}
return [];
}


Future<List<TranscriptSegment>> getTranscriptSegmentsForConversation(String conversationId, int? limit, {int offset = 0}) async {
var url = '${Env.apiBaseUrl}v1/memories/$conversationId/transcript_segments';
if (limit != null) {
url += '?limit=$limit&offset=$offset';
}

var response = await makeApiCall(
url: url,
headers: {},
method: 'GET',
body: ''
);

if (response == null) return [];

if (response.statusCode == 200) {
var body = utf8.decode(response.bodyBytes);
var segments = (jsonDecode(body) as List<dynamic>).map((segment) {
return TranscriptSegment.fromJson(segment);
}).toList();

return segments;
} else {
debugPrint('getTranscriptSegmentsForConversation error ${response.statusCode}');
}

return [];
}

Future<ServerConversation?> reProcessConversationServer(String conversationId) async {
var response = await makeApiCall(
url: '${Env.apiBaseUrl}v1/memories/$conversationId/reprocess',
Expand Down
4 changes: 4 additions & 0 deletions app/lib/backend/schema/conversation.dart
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ class ServerConversation {
return transcript;
}
}

void addTranscriptSegments(List<TranscriptSegment> newSegments) {
transcriptSegments.addAll(newSegments);
}
}

class SyncLocalFilesResponse {
Expand Down
11 changes: 11 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from modal import Image, App, asgi_app, Secret
from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \
speech_profile, agents, facts, users, processing_memories, trends, sdcard, sync, apps, custom_auth, payment
from utils.other.timeout import TimeoutMiddleware

if os.environ.get('SERVICE_ACCOUNT_JSON'):
service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
Expand Down Expand Up @@ -40,6 +41,16 @@

app.include_router(payment.router)


methods_timeout = {
"GET": os.environ.get('HTTP_GET_TIMEOUT'),
"PUT": os.environ.get('HTTP_PUT_TIMEOUT'),
"PATCH": os.environ.get('HTTP_PATCH_TIMEOUT'),
"DELETE": os.environ.get('HTTP_DELETE_TIMEOUT'),
}

app.add_middleware(TimeoutMiddleware,methods_timeout=methods_timeout)

modal_app = App(
name='backend',
secrets=[Secret.from_name("gcp-credentials"), Secret.from_name('envs')],
Expand Down
18 changes: 15 additions & 3 deletions backend/routers/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,28 @@ def reprocess_memory(


@router.get('/v1/memories', response_model=List[Memory], tags=['memories'])
def get_memories(limit: int = 100, offset: int = 0, statuses: str = "", include_discarded: bool = True, uid: str = Depends(auth.get_current_user_uid)):
def get_memories(limit: int = 100, offset: int = 0, statuses: str = "", include_discarded: bool = True, segment_limit: int = None, uid: str = Depends(auth.get_current_user_uid)):
print('get_memories', uid, limit, offset, statuses)
return memories_db.get_memories(uid, limit, offset, include_discarded=include_discarded,
memories = memories_db.get_memories(uid, limit, offset, include_discarded=include_discarded,
statuses=statuses.split(",") if len(statuses) > 0 else [])

if segment_limit is not None:
for memory in memories:
memory["transcript_segments"] = memory.get("transcript_segments", [])[:segment_limit]
return memories

@router.get("/v1/memories/{memory_id}", response_model=Memory, tags=['memories'])
def get_memory_by_id(memory_id: str, uid: str = Depends(auth.get_current_user_uid)):
return _get_memory_by_id(uid, memory_id)

@router.get("/v1/memories/{memory_id}/transcript_segments", response_model=List[TranscriptSegment], tags=['memories'])
def get_transcript_segments_memory(memory_id: str, uid: str = Depends(auth.get_current_user_uid), limit: int = None, offset: int = 0):
memory = _get_memory_by_id(uid, memory_id)
transcript_segments = memory["transcript_segments"]
if limit is not None:
transcript_segments = transcript_segments[offset:limit+offset]
return transcript_segments



@router.patch("/v1/memories/{memory_id}/title", tags=['memories'])
def patch_memory_title(memory_id: str, title: str, uid: str = Depends(auth.get_current_user_uid)):
Expand Down
40 changes: 40 additions & 0 deletions backend/utils/other/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import Request
import asyncio
import os

class TimeoutMiddleware(BaseHTTPMiddleware):
def __init__(self, app, methods_timeout: dict = None):
super().__init__(app)

self.default_timeout = self._get_timeout_from_env("HTTP_DEFAULT_TIMEOUT", default=2 * 60)

self.methods_timeout = self._parse_methods_timeout(methods_timeout or {})

@staticmethod
def _get_timeout_from_env(env_var: str, default: float) -> float:
timeout = os.environ.get(env_var, default)
try:
return float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value in env {env_var}: {timeout}")

@staticmethod
def _parse_methods_timeout(methods_timeout: dict) -> dict:
result = {}
for method, timeout in methods_timeout.items():
if timeout is None:
continue
try:
result[method.upper()] = float(timeout)
except ValueError:
raise ValueError(f"Invalid timeout value for method {method}: {timeout}")
return result

async def dispatch(self, request: Request, call_next):
timeout = self.methods_timeout.get(request.method, self.default_timeout)
try:
return await asyncio.wait_for(call_next(request), timeout=timeout)
except asyncio.TimeoutError:
return Response(status_code=504)