diff --git a/pyproject.toml b/pyproject.toml index f849db91..08e60305 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "aiofiles", - "emoji", "fastapi", "httpx", "packaging", diff --git a/ragna/assistants/_demo.py b/ragna/assistants/_demo.py index f9bf644e..2f002f15 100644 --- a/ragna/assistants/_demo.py +++ b/ragna/assistants/_demo.py @@ -30,11 +30,11 @@ def answer(self, messages: list[Message]) -> Iterator[str]: def _markdown_answer(self) -> str: return textwrap.dedent( """ - | String | Integer | Float | Emoji | - | :----- | :------: | ----: | ------------------ | - | foo | 0 | 1.0 | :unicorn: | - | `bar` | 1 | -1.23 | :metal: | - | "baz" | -1 | 1e6 | :eye: :lips: :eye: | + | String | Integer | Float | Emoji | + | :----- | :------: | ----: | ------ | + | foo | 0 | 1.0 | 🦄 | + | `bar` | 1 | -1.23 | 🤘 | + | "baz" | -1 | 1e6 | 👁👄👁 | """ ).strip() diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py deleted file mode 100644 index 2f8f37b2..00000000 --- a/ragna/deploy/_ui/api_wrapper.py +++ /dev/null @@ -1,62 +0,0 @@ -import uuid -from datetime import datetime - -import emoji -import panel as pn -import param - -from ragna.deploy import _schemas as schemas -from ragna.deploy._engine import Engine - - -class ApiWrapper(param.Parameterized): - def __init__(self, engine: Engine): - super().__init__() - self._user = pn.state.user - self._engine = engine - - async def get_corpus_names(self): - return await self._engine.get_corpuses() - - async def get_corpus_metadata(self): - return await self._engine.get_corpus_metadata() - - async def get_chats(self): - json_data = [ - chat.model_dump(mode="json") - for chat in self._engine.get_chats(user=self._user) - ] - for chat in json_data: - chat["messages"] = [self.improve_message(msg) for msg in chat["messages"]] - return json_data - - async def answer(self, chat_id, prompt): - async for message in self._engine.answer_stream( - user=self._user, chat_id=uuid.UUID(chat_id), prompt=prompt - ): - yield self.improve_message(message.model_dump(mode="json")) - - def get_components(self): - return self._engine.get_components() - - async def start_and_prepare( - self, name, input, corpus_name, source_storage, assistant, params - ): - chat = self._engine.create_chat( - user=self._user, - chat_creation=schemas.ChatCreation( - name=name, - input=input, - source_storage=source_storage, - assistant=assistant, - corpus_name=corpus_name, - params=params, - ), - ) - await self._engine.prepare_chat(user=self._user, id=chat.id) - return str(chat.id) - - def improve_message(self, msg): - msg["timestamp"] = datetime.strptime(msg["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") - msg["content"] = emoji.emojize(msg["content"], language="alias") - return msg diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index 4163c378..aa691311 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -5,7 +5,6 @@ from . import js from . import styles as ui -from .api_wrapper import ApiWrapper from .main_page import MainPage pn.extension( @@ -68,10 +67,8 @@ def get_template(self): return template def index_page(self): - api_wrapper = ApiWrapper(self._engine) - template = self.get_template() - main_page = MainPage(api_wrapper=api_wrapper, template=template) + main_page = MainPage(engine=self._engine, template=template) template.main.append(main_page) return template diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 147ee686..d72226bd 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -8,6 +8,7 @@ from panel.reactive import ReactiveHTML from ragna.core._metadata_filter import MetadataFilter +from ragna.deploy._schemas import Chat from . import styles as ui @@ -161,18 +162,19 @@ def _build_message(self, *args, **kwargs) -> Optional[RagnaChatMessage]: # We only ever hit this function for user inputs, since we control the # generation of the system and assistant messages manually. Thus, we can # unconditionally create a user message here. - return RagnaChatMessage(message.object, role="user", user=self.user) + return RagnaChatMessage( + message.object, role="user", user=cast(str, pn.state.user) + ) class CentralView(pn.viewable.Viewer): - current_chat = param.ClassSelector(class_=dict, default=None) + current_chat = param.ClassSelector(class_=Chat, default=None) - def __init__(self, api_wrapper, **params): + def __init__(self, engine, **params): super().__init__(**params) # FIXME: make this dynamic from the login - self.user = "" - self.api_wrapper = api_wrapper + self._engine = engine self.chat_info_button = pn.widgets.Button( # The name will be filled at runtime in self.header name="", @@ -187,24 +189,24 @@ def on_click_chat_info_wrapper(self, event): return # see _api/schemas.py for `input` type definitions - if self.current_chat["documents"] is not None: + if self.current_chat.documents is not None: title = "Uploaded Files" pills = "".join( [ f"""
{d['name']}
""" - for d in self.current_chat["documents"] + for d in self.current_chat.documents ] ) details = f"
{pills}

\n\n" - grid_height = len(self.current_chat["documents"]) // 3 + grid_height = len(self.current_chat.documents) // 3 - elif self.current_chat["metadata_filter"] is not None: + elif self.current_chat.metadata_filter is not None: title = "Metadata Filter" metadata_filters_readable = ( - str(MetadataFilter.from_primitive(self.current_chat["metadata_filter"])) + str(MetadataFilter.from_primitive(self.current_chat.metadata_filter)) .replace("\n", "
") .replace(" ", " ") ) @@ -225,14 +227,14 @@ def on_click_chat_info_wrapper(self, event): details, "----", "**Source Storage**", - f"""{self.current_chat['source_storage']}\n""", + f"""{self.current_chat.source_storage}\n""", "----", "**Assistant**", - f"""{self.current_chat['assistant']}\n""", + f"""{self.current_chat.assistant}\n""", "**Advanced configuration**", *[ f"- **{key.replace('_', ' ').title()}**: {value}" - for key, value in self.current_chat["params"].items() + for key, value in self.current_chat.params.items() ], ] ) @@ -300,9 +302,9 @@ def get_user_from_role(self, role: Literal["system", "user", "assistant"]) -> st if role == "system": return "Ragna" elif role == "user": - return cast(str, self.user) + return cast(str, pn.state.user) elif role == "assistant": - return cast(str, self.current_chat["assistant"]) + return cast(str, self.current_chat.assistant) else: raise RuntimeError @@ -310,21 +312,25 @@ async def chat_callback( self, content: str, user: str, instance: pn.chat.ChatInterface ): try: - answer_stream = self.api_wrapper.answer(self.current_chat["id"], content) + answer_stream = self._engine.answer_stream( + user=pn.state.user, + chat_id=self.current_chat.id, + prompt=content, + ) answer = await anext(answer_stream) message = RagnaChatMessage( - answer["content"], + answer.content, role="assistant", user=self.get_user_from_role("assistant"), - sources=answer["sources"], + sources=answer.sources, on_click_source_info_callback=self.on_click_source_info_wrapper, assistant_toolbar_visible=False, ) yield message async for chunk in answer_stream: - message.content_pane.object += chunk["content"] + message.content_pane.object += chunk.content message.clipboard_button.value = message.content_pane.object message.assistant_toolbar.visible = True @@ -354,17 +360,17 @@ def chat_interface(self): return RagnaChatInterface( *[ RagnaChatMessage( - message["content"], - role=message["role"], - user=self.get_user_from_role(message["role"]), - sources=message["sources"], - timestamp=message["timestamp"], + message.content, + role=message.role, + user=self.get_user_from_role(message.role), + sources=message.sources, + timestamp=message.timestamp, on_click_source_info_callback=self.on_click_source_info_wrapper, ) - for message in self.current_chat["messages"] + for message in self.current_chat.messages ], callback=self.chat_callback, - user=self.user, + user=pn.state.user, get_user_from_role=self.get_user_from_role, show_rerun=False, show_undo=False, @@ -393,7 +399,7 @@ def header(self): current_chat_name = "" if self.current_chat is not None: - current_chat_name = self.current_chat["name"] + current_chat_name = self.current_chat.name chat_name_header = pn.pane.HTML( f"

{current_chat_name}

", @@ -402,8 +408,8 @@ def header(self): ) chat_documents_pills = [] - if self.current_chat is not None and self.current_chat["documents"] is not None: - doc_names = [d["name"] for d in self.current_chat["documents"]] + if self.current_chat is not None and self.current_chat.documents is not None: + doc_names = [d.name for d in self.current_chat.documents] # FIXME: Instead of setting a hard limit of 20 documents here, this should # scale automatically with the width of page @@ -417,7 +423,7 @@ def header(self): chat_documents_pills.append(pill) self.chat_info_button.name = ( - f"{self.current_chat['assistant']} | {self.current_chat['source_storage']}" + f"{self.current_chat.assistant} | {self.current_chat.source_storage}" ) return pn.Row( diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index b55ce1eb..5ab7d3e9 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -11,10 +11,10 @@ class LeftSidebar(pn.viewable.Viewer): current_chat_id = param.String(default=None) refresh_counter = param.Integer(default=0) - def __init__(self, api_wrapper, **params): + def __init__(self, engine, **params): super().__init__(**params) - self.api_wrapper = api_wrapper + self._engine = engine self.on_click_chat = None self.on_click_new_chat = None @@ -63,9 +63,10 @@ def refresh(self): @pn.depends("refresh_counter", "chats", "current_chat_id", on_init=True) def __panel__(self): epoch = datetime(1970, 1, 1) + self.chats.sort( key=lambda chat: ( - epoch if not chat["messages"] else chat["messages"][-1]["timestamp"] + epoch if not chat.messages else chat.messages[-1].timestamp ), reverse=True, ) @@ -73,7 +74,7 @@ def __panel__(self): self.chat_buttons = [] for chat in self.chats: button = pn.widgets.Button( - name=chat["name"], + name=chat.name, css_classes=["chat_button"], ) button.on_click(lambda event, c=chat: self.on_click_chat_wrapper(event, c)) @@ -105,7 +106,7 @@ def __panel__(self): + self.chat_buttons + [ pn.layout.VSpacer(), - pn.pane.HTML(f"user: {self.api_wrapper._user}"), + pn.pane.HTML(f"user: {pn.state.user}"), pn.pane.HTML(f"version: {ragna_version}"), # self.footer() ] diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py index 7e4822ae..6c7034c8 100644 --- a/ragna/deploy/_ui/main_page.py +++ b/ragna/deploy/_ui/main_page.py @@ -13,9 +13,9 @@ class MainPage(pn.viewable.Viewer, param.Parameterized): current_chat_id = param.String(default=None) chats = param.List(default=None) - def __init__(self, api_wrapper, template): + def __init__(self, engine, template): super().__init__() - self.api_wrapper = api_wrapper + self._engine = engine self.template = template self.components = None @@ -23,12 +23,12 @@ def __init__(self, api_wrapper, template): self.corpus_names = None self.modal = None - self.central_view = CentralView(api_wrapper=self.api_wrapper) + self.central_view = CentralView(engine=self._engine) self.central_view.on_click_chat_info = ( lambda event, title, content: self.show_right_sidebar(title, content) ) - self.left_sidebar = LeftSidebar(api_wrapper=self.api_wrapper) + self.left_sidebar = LeftSidebar(engine=self._engine) self.left_sidebar.on_click_chat = self.on_click_chat self.left_sidebar.on_click_new_chat = self.open_modal @@ -41,26 +41,23 @@ def __init__(self, api_wrapper, template): ) async def refresh_data(self): - self.chats = await self.api_wrapper.get_chats() - self.components = self.api_wrapper.get_components() - self.corpus_metadata = await self.api_wrapper.get_corpus_metadata() - self.corpus_names = await self.api_wrapper.get_corpus_names() + self.chats = self._engine.get_chats(user=pn.state.user) + self.components = self._engine.get_components() + self.corpus_metadata = await self._engine.get_corpus_metadata() + self.corpus_names = await self._engine.get_corpuses() @param.depends("chats", watch=True) def after_update_chats(self): self.left_sidebar.chats = self.chats if len(self.chats) > 0: - chat_id_exist = ( - len([c["id"] for c in self.chats if c["id"] == self.current_chat_id]) - > 0 - ) + chat_id_exist = any(c.id == self.current_chat_id for c in self.chats) if self.current_chat_id is None or not chat_id_exist: - self.current_chat_id = self.chats[0]["id"] + self.current_chat_id = str(self.chats[0].id) for c in self.chats: - if c["id"] == self.current_chat_id: + if str(c.id) == self.current_chat_id: self.central_view.set_current_chat(c) break @@ -73,7 +70,7 @@ async def open_modal(self, event): await self.refresh_data() self.modal = ModalConfiguration( - api_wrapper=self.api_wrapper, + engine=self._engine, components=self.components, corpus_metadata=self.corpus_metadata, corpus_names=self.corpus_names, diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 6acf3a33..78de6ba3 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -89,12 +89,10 @@ class ModalConfiguration(pn.viewable.Viewer): error = param.Boolean(default=False) - def __init__( - self, api_wrapper, components, corpus_names, corpus_metadata, **params - ): + def __init__(self, engine, components, corpus_names, corpus_metadata, **params): super().__init__(chat_name=get_default_chat_name(), **params) - self.api_wrapper = api_wrapper + self._engine = engine self.corpus_names = corpus_names self.corpus_metadata = corpus_metadata @@ -105,7 +103,7 @@ def __init__( self.document_uploader = pn.widgets.FileInput( multiple=True, css_classes=["file-input"], - accept=",".join(self.api_wrapper.get_components().documents), + accept=",".join(self._engine.get_components().documents), ) # Most widgets (including those that use from_param) should be placed after the super init call @@ -158,15 +156,15 @@ async def did_click_on_start_chat_button(self, event): return self.start_chat_button.disabled = True - documents = self.api_wrapper._engine.register_documents( - user=self.api_wrapper._user, + documents = self._engine.register_documents( + user=pn.state.user, document_registrations=[ schemas.DocumentRegistration(name=name) for name in self.document_uploader.filename ], ) - if self.api_wrapper._engine.supports_store_documents: + if self._engine.supports_store_documents: def make_content_stream(data: bytes) -> AsyncIterator[bytes]: async def content_stream() -> AsyncIterator[bytes]: @@ -174,8 +172,8 @@ async def content_stream() -> AsyncIterator[bytes]: return content_stream() - await self.api_wrapper._engine.store_documents( - user=self.api_wrapper._user, + await self._engine.store_documents( + user=pn.state.user, ids_and_streams=[ (document.id, make_content_stream(data)) for document, data in zip( @@ -207,14 +205,19 @@ async def did_finish_upload(self, input, corpus_name=None): corpus_name = self.corpus_name_input.value try: - new_chat_id = await self.api_wrapper.start_and_prepare( - name=self.chat_name, - input=input, - corpus_name=corpus_name, - source_storage=self.config.source_storage_name, - assistant=self.config.assistant_name, - params=self.config.to_params_dict(), + chat = self._engine.create_chat( + user=pn.state.user, + chat_creation=schemas.ChatCreation( + name=self.chat_name, + input=input, + corpus_name=corpus_name, + source_storage=self.config.source_storage_name, + assistant=self.config.assistant_name, + params=self.config.to_params_dict(), + ), ) + await self._engine.prepare_chat(user=pn.state.user, id=chat.id) + new_chat_id = str(chat.id) self.start_chat_button.disabled = False @@ -249,7 +252,7 @@ def change_upload_files_label(self, mode="normal"): def create_config(self, components): if self.config is None: # Retrieve the components from the API and build a config object - components = self.api_wrapper.get_components() + components = self._engine.get_components() # TODO : use the components to set up the default values for the various params config = ChatConfig()