Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add GET endpoints for documents #547

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a9ebd9c
Give option for database to return all documents
smokestacklightnin Jan 27, 2025
ce37b0c
Add `get_documents`
smokestacklightnin Jan 27, 2025
d0197f5
Merge remote-tracking branch 'upstream/main' into ui/enh/document-viewer
smokestacklightnin Jan 28, 2025
b970b96
Add `GET` endpoint for `/documents`
smokestacklightnin Jan 28, 2025
f33e8ea
Add `GET` endpoint for a specific document `/documents/{id}`
smokestacklightnin Jan 28, 2025
c68f874
Fix mypy error
smokestacklightnin Jan 28, 2025
e7bf467
Add `get_document` to engine for convenience
smokestacklightnin Jan 28, 2025
ce6a4f2
Clean up
smokestacklightnin Jan 28, 2025
c4bbcaf
Add support for MIME types in `core.Document`s
smokestacklightnin Jan 28, 2025
1ab6f58
Add `GET` `/documents/{id}/content` endpoint
smokestacklightnin Jan 28, 2025
e57950c
Call correct method
smokestacklightnin Jan 28, 2025
1bc57b0
Use the builtin `mimetypes` library instead of custom logic
smokestacklightnin Jan 28, 2025
54e20a3
Merge remote-tracking branch 'upstream/main' into ui/enh/document-viewer
smokestacklightnin Jan 28, 2025
f67997e
Add mime_type to `Document` schema
smokestacklightnin Jan 28, 2025
23898f0
Add MIME type to `Document` ORM object
smokestacklightnin Jan 28, 2025
e53b180
Add MIME type to ORM <> Schema converters and Core <> Schema converters
smokestacklightnin Jan 28, 2025
2cfa02a
Add `mime_type` to initializer for `LocalDocument`
smokestacklightnin Jan 28, 2025
32dc1ab
Remove unnecessary type conversion
smokestacklightnin Jan 29, 2025
32e11ba
Make code more concise
smokestacklightnin Jan 29, 2025
162a3ff
Enforce keyword arguments
smokestacklightnin Jan 29, 2025
e122713
Help expression scale
smokestacklightnin Jan 29, 2025
06b5ab7
Use `__getitem__` instead of `next(iter(...))`
smokestacklightnin Jan 29, 2025
a9bf0fd
Use traditional `if` statement instead of ternary operator
smokestacklightnin Jan 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import io
import mimetypes
import uuid
from functools import cached_property
from pathlib import Path
Expand All @@ -25,11 +26,15 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
self.id = id or uuid.uuid4()
self.name = name
self.metadata = metadata
self.handler = handler or self.get_handler(name)
self.mime_type = (
mime_type or mimetypes.guess_type(name)[0] or "application/octet-stream"
)

@staticmethod
def supported_suffixes() -> set[str]:
Expand Down Expand Up @@ -76,8 +81,11 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
super().__init__(id=id, name=name, metadata=metadata, handler=handler)
super().__init__(
id=id, name=name, metadata=metadata, handler=handler, mime_type=mime_type
)
if "path" not in self.metadata:
metadata["path"] = str(ragna.local_root() / "documents" / str(self.id))

Expand Down
23 changes: 23 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import uuid
from typing import Annotated, Any, AsyncIterator

Expand Down Expand Up @@ -40,6 +41,28 @@ async def content_stream() -> AsyncIterator[bytes]:
],
)

@router.get("/documents")
async def get_documents(user: UserDependency) -> list[schemas.Document]:
return engine.get_documents(user=user.name)

@router.get("/documents/{id}")
async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document:
return engine.get_document(user=user.name, id=id)

@router.get("/documents/{id}/content")
async def get_document_content(
user: UserDependency, id: uuid.UUID
) -> StreamingResponse:
schema_document = engine.get_document(user=user.name, id=id)
core_document = engine._to_core.document(schema_document)
headers = {"Content-Disposition": f"inline; filename={schema_document.name}"}

return StreamingResponse(
io.BytesIO(core_document.read()),
media_type=core_document.mime_type,
headers=headers,
)

@router.get("/components")
def get_components() -> schemas.Components:
return engine.get_components()
Expand Down
22 changes: 13 additions & 9 deletions ragna/deploy/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,24 @@ def add_documents(
session.commit()

def _get_orm_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[orm.Document]:
# FIXME also check if the user is allowed to access the documents
# FIXME: maybe just take the user id to avoid getting it twice in add_chat?
documents = (
session.execute(select(orm.Document).where(orm.Document.id.in_(ids)))
.scalars()
.all()
)
if len(documents) != len(ids):
expr = select(orm.Document)
if ids is not None:
expr = expr.where(orm.Document.id.in_(ids))
documents = session.execute(expr).scalars().all()

if (ids is not None) and (len(documents) != len(ids)):
raise RagnaException(
str(set(ids) - {document.id for document in documents})
)

return documents # type: ignore[no-any-return]

def get_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
return [
self._to_schema.document(document)
Expand Down Expand Up @@ -288,6 +288,7 @@ def document(
user_id=user_id,
name=document.name,
metadata_=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> orm.Source:
Expand Down Expand Up @@ -354,7 +355,10 @@ def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey:

def document(self, document: orm.Document) -> schemas.Document:
return schemas.Document(
id=document.id, name=document.name, metadata=document.metadata_
id=document.id,
name=document.name,
metadata=document.metadata_,
mime_type=document.mime_type,
)

def source(self, source: orm.Source) -> schemas.Source:
Expand Down
18 changes: 13 additions & 5 deletions ragna/deploy/_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import secrets
import uuid
from typing import Any, AsyncIterator, Optional, cast
from typing import Any, AsyncIterator, Collection, Optional, cast

from fastapi import status as http_status_code

Expand Down Expand Up @@ -182,17 +182,23 @@ async def store_documents(

streams = dict(ids_and_streams)

with self._database.get_session() as session:
documents = self._database.get_documents(
session, user=user, ids=streams.keys()
)
documents = self.get_documents(user=user, ids=streams.keys())

for document in documents:
core_document = cast(
ragna.core.LocalDocument, self._to_core.document(document)
)
await core_document._write(streams[document.id])

def get_documents(
self, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
with self._database.get_session() as session:
return self._database.get_documents(session, user=user, ids=ids)

def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document:
return next(iter(self.get_documents(user=user, ids=[id])))

def create_chat(
self, *, user: str, chat_creation: schemas.ChatCreation
) -> schemas.Chat:
Expand Down Expand Up @@ -280,6 +286,7 @@ def document(self, document: schemas.Document) -> core.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> core.Source:
Expand Down Expand Up @@ -328,6 +335,7 @@ def document(self, document: core.Document) -> schemas.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: core.Source) -> schemas.Source:
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Document(Base):
# Mind the trailing underscore here. Unfortunately, this is necessary, because
# metadata without the underscore is reserved by SQLAlchemy
metadata_ = Column(Json, nullable=False)
mime_type = Column(types.String, nullable=False)
chats = relationship(
"Chat",
secondary=document_chat_association_table,
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class Document(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
name: str
metadata: dict[str, Any]
mime_type: str


class Source(BaseModel):
Expand Down