From 55f65c9a1d15881f40c5f32cc7fc9631d5ca2174 Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Sun, 3 Mar 2024 21:50:50 -0300 Subject: [PATCH] Add chat-identifiers endpoint --- ragna/deploy/_api/core.py | 7 +++++++ ragna/deploy/_api/database.py | 17 +++++++++++++++++ ragna/deploy/_api/schemas.py | 5 +++++ 3 files changed, 29 insertions(+) diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 3aab3640..bb515e50 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -200,6 +200,13 @@ async def create_chat( database.add_chat(session, user=user, chat=chat) return chat + @app.get("/chat-identifiers") + async def get_chat_identifiers( + user: UserDependency + ) -> list[schemas.ChatIdentifier]: + with get_session() as session: + return database.get_chat_identifiers(session, user=user) + @app.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: with get_session() as session: diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py index b4433a33..b9ca7eff 100644 --- a/ragna/deploy/_api/database.py +++ b/ragna/deploy/_api/database.py @@ -98,6 +98,23 @@ def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None: session.commit() +def _orm_to_schema_identifier(chat: orm.Chat) -> schemas.ChatIdentifier: + return schemas.ChatIdentifier(id=chat.id, name=chat.name) + + +def get_chat_identifiers( + session: Session, *, user: str +) -> list[schemas.ChatIdentifier]: + return [ + _orm_to_schema_identifier(chat) + for chat in session.execute( + select(orm.Chat).where(orm.Chat.user_id == _get_user_id(session, user)) + ) + .scalars() + .all() + ] + + def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: documents = [ schemas.Document(id=document.id, name=document.name) diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py index 7439f0d5..88784e72 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_api/schemas.py @@ -82,3 +82,8 @@ class Chat(BaseModel): metadata: ChatMetadata messages: list[Message] = Field(default_factory=list) prepared: bool = False + + +class ChatIdentifier(BaseModel): + id: uuid.UUID + name: str