diff --git a/webui/requirements.txt b/webui/requirements.txt index bc57de6..0f52043 100644 --- a/webui/requirements.txt +++ b/webui/requirements.txt @@ -1,3 +1,3 @@ -reflex>=0.4.0 -openai>=1.12.0 +reflex>=0.4.4 +openai>=1.13.3 diff --git a/webui/webui/components/chat.py b/webui/webui/components/chat.py index 0586c47..7596f0e 100644 --- a/webui/webui/components/chat.py +++ b/webui/webui/components/chat.py @@ -4,7 +4,12 @@ from webui.state import QA, State -message_style = dict(display="inline-block", padding="1em", border_radius="8px", max_width=["30em", "30em", "50em", "50em", "50em", "50em"]) +message_style = dict( + display="inline-block", + padding="1em", + border_radius="8px", + max_width=["30em", "30em", "50em", "50em", "50em", "50em"], +) def message(qa: QA) -> rx.Component: @@ -44,7 +49,10 @@ def message(qa: QA) -> rx.Component: def chat() -> rx.Component: """List all the messages in a single conversation.""" return rx.vstack( - rx.box(rx.foreach(State.chats[State.current_chat], message), width="100%"), + rx.box( + rx.foreach(State.chats[State.current_chat]["messages"], message), + width="100%", + ), py="8", flex="1", width="100%", diff --git a/webui/webui/components/navbar.py b/webui/webui/components/navbar.py index e1fe77d..0868043 100644 --- a/webui/webui/components/navbar.py +++ b/webui/webui/components/navbar.py @@ -1,28 +1,34 @@ import reflex as rx from webui.state import State + def sidebar_chat(chat: str) -> rx.Component: """A sidebar chat item. Args: chat: The chat item. """ - return rx.drawer.close(rx.hstack( - rx.button( - chat, on_click=lambda: State.set_chat(chat), width="80%", variant="surface" - ), - rx.button( - rx.icon( - tag="trash", - on_click=State.delete_chat, - stroke_width=1, + return rx.drawer.close( + rx.hstack( + rx.button( + chat, + on_click=lambda: State.set_chat(chat), + width="80%", + variant="surface", ), - width="20%", - variant="surface", - color_scheme="red", - ), - width="100%", - )) + rx.button( + rx.icon( + tag="trash", + on_click=State.delete_chat, + stroke_width=1, + ), + width="20%", + variant="surface", + color_scheme="red", + ), + width="100%", + ) + ) def sidebar(trigger) -> rx.Component: @@ -81,13 +87,16 @@ def navbar(): return rx.box( rx.hstack( rx.hstack( - rx.avatar(fallback="RC", variant="solid"), - rx.heading("Reflex Chat"), + rx.avatar(fallback="CC", variant="solid"), + rx.heading("Coca Cola"), rx.desktop_only( rx.badge( - State.current_chat, - rx.tooltip(rx.icon("info", size=14), content="The current selected chat."), - variant="soft" + State.current_chat, + rx.tooltip( + rx.icon("info", size=14), + content="The current selected chat.", + ), + variant="soft", ) ), align_items="center", @@ -103,15 +112,15 @@ def navbar(): background_color=rx.color("mauve", 6), ) ), - rx.desktop_only( - rx.button( - rx.icon( - tag="sliders-horizontal", - color=rx.color("mauve", 12), - ), - background_color=rx.color("mauve", 6), - ) - ), + # rx.desktop_only( + # rx.button( + # rx.icon( + # tag="sliders-horizontal", + # color=rx.color("mauve", 12), + # ), + # background_color=rx.color("mauve", 6), + # ) + # ), align_items="center", ), justify_content="space-between", diff --git a/webui/webui/layout.py b/webui/webui/layout.py new file mode 100644 index 0000000..44296f3 --- /dev/null +++ b/webui/webui/layout.py @@ -0,0 +1,49 @@ +import reflex as rx + + +def container(*children, **props): + """A fixed container based on a 960px grid.""" + # Enable override of default props. + props = ( + dict( + width="100%", + max_width="960px", + background=rx.color("mauve", 1), + height="100%", + px="9", + margin="0 auto", + position="relative", + ) + | props + ) + return rx.stack(*children, **props) + + +def auth_layout(*args): + """The shared layout for the login and sign up pages.""" + return rx.box( + container( + rx.vstack( + rx.heading("Welcome to your EY Interactive Chatbot!", size="8"), + rx.heading("Enter your password to get started.", size="5"), + align="center", + spacing="4", + ), + *args, + border_top_radius="10px", + box_shadow="0 4px 60px 0 rgba(0, 0, 0, 0.08), 0 4px 16px 0 rgba(0, 0, 0, 0.08)", + display="flex", + flex_direction="column", + align_items="center", + padding_top="52px", + padding_bottom="24px", + padding_x="24px", + spacing="4", + ), + height="100vh", + padding_x="50px", + padding_y="50px", + background="url(bg.svg)", + background_repeat="no-repeat", + background_size="cover", + ) diff --git a/webui/webui/state.py b/webui/webui/state.py index 3959d58..141b676 100644 --- a/webui/webui/state.py +++ b/webui/webui/state.py @@ -2,12 +2,29 @@ import reflex as rx from openai import OpenAI -client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +_client = None + + +def get_openai_client(): + global _client + if _client is None: + _client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + return _client + + +assistant_id = os.getenv("ASSISTANT_ID") + +PASSWORD = os.getenv("PASSWORD") # Checking if the API key is set properly if not os.getenv("OPENAI_API_KEY"): raise Exception("Please set OPENAI_API_KEY environment variable.") +# Checking if the assistant key is set properly +if not os.getenv("ASSISTANT_ID"): + raise Exception("Please set ASSISTANT_ID environment variable.") + class QA(rx.Base): """A question and answer pair.""" @@ -17,15 +34,19 @@ class QA(rx.Base): DEFAULT_CHATS = { - "Intros": [], + "Intros": {"id": "", "messages": []}, } class State(rx.State): """The app state.""" + password: str = "" + + correct_password: bool = False + # A dict from the chat name to the list of questions and answers. - chats: dict[str, list[QA]] = DEFAULT_CHATS + chats: dict[str, dict[str, list[QA]]] = DEFAULT_CHATS # The current chat name. current_chat = "Intros" @@ -39,11 +60,17 @@ class State(rx.State): # The name of the new chat. new_chat_name: str = "" + def check_password(self): + if PASSWORD == self.password: + self.correct_password = True + else: + return rx.window_alert("Invalid password.") + def create_chat(self): """Create a new chat.""" # Add the new chat to the list of chats. self.current_chat = self.new_chat_name - self.chats[self.new_chat_name] = [] + self.chats[self.new_chat_name] = {"id": "", "messages": []} def delete_chat(self): """Delete the current chat.""" @@ -91,44 +118,61 @@ async def openai_process_question(self, question: str): # Add the question to the list of questions. qa = QA(question=question, answer="") - self.chats[self.current_chat].append(qa) + self.chats[self.current_chat]["messages"].append(qa) # Clear the input and start the processing. self.processing = True yield - # Build the messages. - messages = [ - {"role": "system", "content": "You are a friendly chatbot named Reflex. Respond in markdown."} - ] - for qa in self.chats[self.current_chat]: - messages.append({"role": "user", "content": qa.question}) - messages.append({"role": "assistant", "content": qa.answer}) - - # Remove the last mock answer. - messages = messages[:-1] - - # Start a new session to answer the question. - session = client.chat.completions.create( - model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"), - messages=messages, - stream=True, + if self.chats[self.current_chat]["id"] == "": + thread = get_openai_client().beta.threads.create() + self.chats[self.current_chat]["id"] = thread.id + else: + thread = get_openai_client().beta.threads.retrieve( + thread_id=self.chats[self.current_chat]["id"] + ) + + get_openai_client().beta.threads.messages.create( + thread_id=thread.id, + role="user", + content=qa.question, ) - # Stream the results, yielding after every word. - for item in session: - if hasattr(item.choices[0].delta, "content"): - answer_text = item.choices[0].delta.content - # Ensure answer_text is not None before concatenation - if answer_text is not None: - self.chats[self.current_chat][-1].answer += answer_text - else: - # Handle the case where answer_text is None, perhaps log it or assign a default value - # For example, assigning an empty string if answer_text is None - answer_text = "" - self.chats[self.current_chat][-1].answer += answer_text - self.chats = self.chats - yield + run = get_openai_client().beta.threads.runs.create( + thread_id=thread.id, assistant_id=assistant_id + ) + + # Periodically retrieve the Run to check status and see if it has completed + while run.status != "completed": + keep_retrieving_run = get_openai_client().beta.threads.runs.retrieve( + thread_id=thread.id, run_id=run.id + ) + + if keep_retrieving_run.status == "completed": + break + + if keep_retrieving_run.status == "failed": + self.processing = False + yield rx.window_alert("OpenAI Request Failed! Try asking again.") + return + + # Retrieve messages added by the Assistant to the thread + all_messages = get_openai_client().beta.threads.messages.list( + thread_id=thread.id + ) + + answer_text = all_messages.data[0].content[0].text.value + + if answer_text is not None: + self.chats[self.current_chat]["messages"][-1].answer += answer_text + else: + # Handle the case where answer_text is None, perhaps log it or assign a default value + # For example, assigning an empty string if answer_text is None + answer_text = "" + self.chats[self.current_chat]["messages"][-1].answer += answer_text + + self.chats = self.chats + yield # Toggle the processing flag. self.processing = False diff --git a/webui/webui/webui.py b/webui/webui/webui.py index 7a8dcbf..5c8fa8a 100644 --- a/webui/webui/webui.py +++ b/webui/webui/webui.py @@ -1,12 +1,39 @@ """The main Chat app.""" +import os import reflex as rx from webui.components import chat, navbar +from webui.state import State +from webui.layout import auth_layout -def index() -> rx.Component: - """The main app.""" - return rx.chakra.vstack( +def login() -> rx.Component: + return auth_layout( + rx.box( + rx.vstack( + rx.input( + type="password", + placeholder="Password", + on_blur=State.set_password, + size="3", + ), + rx.button( + "Log in", on_click=State.check_password, size="3", + ), + align="center", + spacing="4", + ), + background=rx.color("mauve", 1), + border="1px solid #eaeaea", + padding="16px", + width="400px", + border_radius="8px", + ), + ) + + +def chatapp() -> rx.Component: + return rx.vstack( navbar(), chat.chat(), chat.action_bar(), @@ -18,11 +45,20 @@ def index() -> rx.Component: ) +def index() -> rx.Component: + """The main app.""" + if not os.getenv("PASSWORD"): + return chatapp() + + else: + return rx.cond(State.correct_password, chatapp(), login()) + + # Add state and page to the app. app = rx.App( theme=rx.theme( - appearance="dark", - accent_color="violet", + appearance="light", + accent_color="red", ), ) app.add_page(index)