diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py
index 70a72968..63b7c5db 100644
--- a/backend/app/api/runs.py
+++ b/backend/app/api/runs.py
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence, Union
 
 import langsmith.client
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
@@ -26,8 +26,8 @@ class CreateRunPayload(BaseModel):
 
     assistant_id: str
     thread_id: str
-    input: Optional[Sequence[AnyMessage]] = Field(default_factory=list)
     config: Optional[RunnableConfig] = None
+    input: Optional[Union[Sequence[AnyMessage], Dict]] = Field(default_factory=list)
 
 
 async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
@@ -41,7 +41,7 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser
     config: RunnableConfig = {
         **assistant["config"],
         "configurable": {
-            **assistant["config"]["configurable"],
+            **assistant["config"].get("configurable", {}),
             **(body.get("config", {}).get("configurable") or {}),
             "user_id": opengpts_user_id,
             "thread_id": body["thread_id"],
@@ -81,7 +81,6 @@ async def stream_run(
 ):
     """Create a run."""
     input_, config = await _run_input_and_config(request, opengpts_user_id)
-
     return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))
 
 
diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py
index 31fe6584..5087f895 100644
--- a/backend/app/api/threads.py
+++ b/backend/app/api/threads.py
@@ -1,4 +1,4 @@
-from typing import Annotated, List, Sequence
+from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
 from uuid import uuid4
 
 from fastapi import APIRouter, HTTPException, Path
@@ -21,10 +21,11 @@ class ThreadPutRequest(BaseModel):
     assistant_id: str = Field(..., description="The ID of the assistant to use.")
 
 
-class ThreadMessagesPostRequest(BaseModel):
+class ThreadPostRequest(BaseModel):
     """Payload for adding messages to a thread."""
 
-    messages: Sequence[AnyMessage]
+    values: Optional[Union[Dict[str, Any], Sequence[AnyMessage]]]
+    config: Optional[Dict[str, Any]] = None
 
 
 @router.get("/")
@@ -33,23 +34,25 @@ async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]:
     return await storage.list_threads(opengpts_user_id)
 
 
-@router.get("/{tid}/messages")
-async def get_thread_messages(
+@router.get("/{tid}/state")
+async def get_thread_state(
     opengpts_user_id: OpengptsUserId,
     tid: ThreadID,
 ):
     """Get all messages for a thread."""
-    return await storage.get_thread_messages(opengpts_user_id, tid)
+    return await storage.get_thread_state(opengpts_user_id, tid)
 
 
-@router.post("/{tid}/messages")
-async def add_thread_messages(
+@router.post("/{tid}/state")
+async def update_thread_state(
+    payload: ThreadPostRequest,
     opengpts_user_id: OpengptsUserId,
     tid: ThreadID,
-    payload: ThreadMessagesPostRequest,
 ):
     """Add messages to a thread."""
-    return await storage.post_thread_messages(opengpts_user_id, tid, payload.messages)
+    return await storage.update_thread_state(
+        payload.config or {"configurable": {"thread_id": tid}}, payload.values
+    )
 
 
 @router.get("/{tid}/history")
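Note on the API surface above: `GET /threads/{tid}/state` now returns the raw LangGraph state (`values` plus the `next` nodes to run), and the POST accepts either a list of messages or a state dict, with an optional per-request config. A minimal client sketch against these routes (the base URL, cookie value, and thread ID are illustrative, not part of this diff):

```python
import httpx

BASE_URL = "http://localhost:8100"  # illustrative; use your deployment's URL
COOKIES = {"opengpts_user_id": "some-user"}  # auth cookie the API expects

with httpx.Client(base_url=BASE_URL, cookies=COOKIES) as client:
    tid = "existing-thread-id"  # hypothetical thread ID

    # New GET shape: {"values": <checkpointed state>, "next": [<node name>, ...]}
    state = client.get(f"/threads/{tid}/state").json()
    print(state["values"], state["next"])

    # POST takes `values` (messages or a state dict) and an optional `config`;
    # when config is omitted, the server builds one from the thread ID.
    client.post(
        f"/threads/{tid}/state",
        json={"values": [{"type": "human", "content": "hello"}]},
    )
```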
diff --git a/backend/app/checkpoint.py b/backend/app/checkpoint.py
index 88df9433..9e4681b4 100644
--- a/backend/app/checkpoint.py
+++ b/backend/app/checkpoint.py
@@ -1,10 +1,10 @@
-from datetime import datetime
 import pickle
+from datetime import datetime
 from typing import AsyncIterator, Optional
 
 from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
 from langgraph.checkpoint import BaseCheckpointSaver
-from langgraph.checkpoint.base import Checkpoint, CheckpointTuple, CheckpointThreadTs
+from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple
 
 from app.lifespan import get_pg_pool
diff --git a/backend/app/storage.py b/backend/app/storage.py
index bd87bfaa..d8c401b1 100644
--- a/backend/app/storage.py
+++ b/backend/app/storage.py
@@ -1,12 +1,12 @@
 from datetime import datetime, timezone
-from typing import List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Union
 
 from langchain_core.messages import AnyMessage
+from langchain_core.runnables import RunnableConfig
 
 from app.agent import AgentType, get_agent_executor
 from app.lifespan import get_pg_pool
 from app.schema import Assistant, Thread
-from app.stream import map_chunk_to_msg
 
 
 async def list_assistants(user_id: str) -> List[Assistant]:
@@ -99,37 +99,36 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
     )
 
 
-async def get_thread_messages(user_id: str, thread_id: str):
+async def get_thread_state(user_id: str, thread_id: str):
     """Get all messages for a thread."""
     app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
     state = await app.aget_state({"configurable": {"thread_id": thread_id}})
     return {
-        "messages": [map_chunk_to_msg(msg) for msg in state.values],
-        "resumeable": bool(state.next),
+        "values": state.values,
+        "next": state.next,
     }
 
 
-async def post_thread_messages(
-    user_id: str, thread_id: str, messages: Sequence[AnyMessage]
+async def update_thread_state(
+    config: RunnableConfig, messages: Union[Sequence[AnyMessage], dict[str, Any]]
 ):
     """Add messages to a thread."""
     app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
-    await app.aupdate_state({"configurable": {"thread_id": thread_id}}, messages)
+    return await app.aupdate_state(config, messages)
 
 
 async def get_thread_history(user_id: str, thread_id: str):
     """Get the history of a thread."""
     app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
+    config = {"configurable": {"thread_id": thread_id}}
    return [
         {
-            "values": [map_chunk_to_msg(msg) for msg in c.values],
-            "resumeable": bool(c.next),
+            "values": c.values,
+            "next": c.next,
             "config": c.config,
             "parent": c.parent_config,
         }
-        async for c in app.aget_state_history(
-            {"configurable": {"thread_id": thread_id}}
-        )
+        async for c in app.aget_state_history(config)
     ]
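These storage helpers now delegate entirely to LangGraph's checkpoint state rather than post-processing message chunks. A rough sketch of the underlying calls, assuming `app` is a compiled graph with a checkpointer attached, as `get_agent_executor` returns:

```python
from langchain_core.messages import HumanMessage

async def inspect_and_update(app, thread_id: str) -> None:
    # Sketch only: `app` stands in for a compiled LangGraph graph with a
    # checkpointer, e.g. get_agent_executor([], AgentType.GPT_35_TURBO, "", False).
    config = {"configurable": {"thread_id": thread_id}}

    # aget_state returns a snapshot: .values is the checkpointed state and
    # .next names the nodes that would run if the thread were resumed.
    snapshot = await app.aget_state(config)
    print(snapshot.values, snapshot.next)

    # Appending messages is just a state update against the same config,
    # which is exactly what update_thread_state above does.
    await app.aupdate_state(config, [HumanMessage(content="continue")])
```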
diff --git a/backend/app/stream.py b/backend/app/stream.py
index 9629cdb9..a894dc75 100644
--- a/backend/app/stream.py
+++ b/backend/app/stream.py
@@ -2,20 +2,7 @@ from typing import AsyncIterator, Optional, Sequence, Union
 
 import orjson
-from langchain_core.messages import (
-    AIMessage,
-    AIMessageChunk,
-    AnyMessage,
-    BaseMessage,
-    BaseMessageChunk,
-    ChatMessage,
-    ChatMessageChunk,
-    FunctionMessage,
-    FunctionMessageChunk,
-    HumanMessage,
-    HumanMessageChunk,
-    SystemMessage,
-)
+from langchain_core.messages import AnyMessage, SystemMessage, message_chunk_to_message
 from langchain_core.runnables import Runnable, RunnableConfig
 from langserve.serialization import WellKnownLCSerializer
@@ -76,22 +63,6 @@ async def astream_messages(
     yield last_messages_list
 
 
-def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
-    if not isinstance(chunk, BaseMessageChunk):
-        return chunk
-    args = {k: v for k, v in chunk.__dict__.items() if k != "type"}
-    if isinstance(chunk, HumanMessageChunk):
-        return HumanMessage(**args)
-    elif isinstance(chunk, AIMessageChunk):
-        return AIMessage(**args)
-    elif isinstance(chunk, FunctionMessageChunk):
-        return FunctionMessage(**args)
-    elif isinstance(chunk, ChatMessageChunk):
-        return ChatMessage(**args)
-    else:
-        raise ValueError(f"Unknown chunk type: {chunk}")
-
-
 _serializer = WellKnownLCSerializer()
@@ -111,7 +82,7 @@ async def to_sse(messages_stream: MessagesStream) -> AsyncIterator[dict]:
             yield {
                 "event": "data",
                 "data": _serializer.dumps(
-                    [map_chunk_to_msg(msg) for msg in chunk]
+                    [message_chunk_to_message(msg) for msg in chunk]
                 ).decode(),
             }
     except Exception:
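The removed `map_chunk_to_msg` duplicated a helper `langchain_core` already exposes: `message_chunk_to_message` (imported in the hunk above) converts an accumulated streaming chunk into its concrete message class and passes non-chunk messages through unchanged. A quick illustration:

```python
from langchain_core.messages import AIMessageChunk, message_chunk_to_message

# Streaming yields additive chunks; adding them concatenates their content.
chunk = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")

# Convert the accumulated chunk into a plain message for serialization.
msg = message_chunk_to_message(chunk)
print(type(msg).__name__, msg.content)  # AIMessage Hello, world
```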
"sha256:ebb308f0afeffda6cc929c466aacfca73a7b561d4ba3fc190202c7bd3905702d"}, - {file = "langgraph-0.0.29.tar.gz", hash = "sha256:12be459587ca7e753f06e94917028d54c7359664e0a04a8cedb0a7b595c4c996"}, + {file = "langgraph-0.0.30-py3-none-any.whl", hash = "sha256:835d34d66d1ec1cad1825eee682f0c8a85edfd591ab552706a4f4211a3119b88"}, + {file = "langgraph-0.0.30.tar.gz", hash = "sha256:58bb9faf081fdb4490cda8cef4421b01b64b241df92c5556c327b2c83cfd5e0f"}, ] [package.dependencies] @@ -4166,4 +4165,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9.0,<3.12" -content-hash = "6aba3d05838348cb038b98803ed0d66e70caaa08a56d5a16fdbea3c5c63ec073" +content-hash = "0c5b10e8d1e0f0d40cf099327c335613c7a640dccdf8931a8a82977856e4f064" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 148ef410..dd16ad2a 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,10 +20,10 @@ orjson = "^3.9.10" python-multipart = "^0.0.6" tiktoken = "^0.5.1" langchain = ">=0.0.338" -langgraph = "^0.0.29" +langgraph = "^0.0.30" pydantic = "<2.0" python-magic = "^0.4.27" -langchain-openai = "^0.0.8" +langchain-openai = "^0.1.1" beautifulsoup4 = "^4.12.3" boto3 = "^1.34.28" duckduckgo-search = "^4.2" @@ -42,7 +42,7 @@ unstructured = {extras = ["doc", "docx"], version = "^0.12.5"} pgvector = "^0.2.5" psycopg2-binary = "^2.9.9" asyncpg = "^0.29.0" -langchain-core = "^0.1.33" +langchain-core = "^0.1.36" [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2" diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index f2bfdc6c..d23ccb67 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -110,9 +110,9 @@ async def test_threads() -> None: ) assert response.status_code == 200, response.text - response = await client.get(f"/threads/{tid}/messages", headers=headers) + response = await client.get(f"/threads/{tid}/state", headers=headers) assert response.status_code == 200 - assert response.json() == {"messages": [], "resumeable": False} + assert response.json() == {"values": [], "resumeable": False} response = await client.get("/threads/", headers=headers) diff --git a/frontend/package.json b/frontend/package.json index cc40c377..42c0db9e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -2,6 +2,7 @@ "name": "frontend", "private": true, "version": "0.0.0", + "packageManager": "yarn@1.22.19", "type": "module", "scripts": { "dev": "vite --host", @@ -11,9 +12,12 @@ "format": "prettier -w src" }, "dependencies": { + "@emotion/react": "^11.11.4", + "@emotion/styled": "^11.11.0", "@headlessui/react": "^1.7.17", "@heroicons/react": "^2.0.18", "@microsoft/fetch-event-source": "^2.0.1", + "@mui/material": "^5.15.14", "@tailwindcss/forms": "^0.5.6", "@tailwindcss/typography": "^0.5.10", "clsx": "^2.0.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index df449b19..c06ba8a1 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -4,7 +4,11 @@ import { Chat } from "./components/Chat"; import { ChatList } from "./components/ChatList"; import { Layout } from "./components/Layout"; import { NewChat } from "./components/NewChat"; -import { Chat as ChatType, useChatList } from "./hooks/useChatList"; +import { + Chat as ChatType, + Message as MessageType, + useChatList, +} from "./hooks/useChatList"; import { useSchemas } from "./hooks/useSchemas"; import { useStreamState } from "./hooks/useStreamState"; import { useConfigList } from "./hooks/useConfigList"; @@ -17,7 +21,13 @@ 
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index df449b19..c06ba8a1 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -4,7 +4,11 @@ import { Chat } from "./components/Chat";
 import { ChatList } from "./components/ChatList";
 import { Layout } from "./components/Layout";
 import { NewChat } from "./components/NewChat";
-import { Chat as ChatType, useChatList } from "./hooks/useChatList";
+import {
+  Chat as ChatType,
+  Message as MessageType,
+  useChatList,
+} from "./hooks/useChatList";
 import { useSchemas } from "./hooks/useSchemas";
 import { useStreamState } from "./hooks/useStreamState";
 import { useConfigList } from "./hooks/useConfigList";
@@ -17,7 +21,13 @@ function App() {
   const { configSchema, configDefaults } = useSchemas();
   const { chats, currentChat, createChat, enterChat } = useChatList();
   const { configs, currentConfig, saveConfig, enterConfig } = useConfigList();
-  const { startStream, stopStream, stream } = useStreamState();
+  const {
+    startStream,
+    stopStream,
+    stream,
+    streamErrorMessage,
+    setStreamErrorMessage,
+  } = useStreamState();
   const [isDocumentRetrievalActive, setIsDocumentRetrievalActive] =
     useState(false);
@@ -47,12 +57,24 @@ function App() {
   }, [currentConfig, currentChat, configs]);
 
   const startTurn = useCallback(
-    async (message?: MessageWithFiles, chat: ChatType | null = currentChat) => {
+    async (props: {
+      message?: MessageWithFiles;
+      previousMessages?: MessageType[];
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      config?: Record<string, any>;
+      chat?: ChatType | null;
+    }) => {
+      const {
+        message,
+        previousMessages = [],
+        chat = currentChat,
+        config,
+      } = props;
       if (!chat) return;
-      const config = configs?.find(
+      const defaultConfig = configs?.find(
         (c) => c.assistant_id === chat.assistant_id,
       )?.config;
-      if (!config) return;
+      if (!defaultConfig) return;
       const files = message?.files || [];
       if (files.length > 0) {
         const formData = files.reduce((formData, file) => {
@@ -68,9 +90,10 @@ function App() {
           body: formData,
         });
       }
-      await startStream(
-        message
+      await startStream({
+        input: message
           ? [
+              ...previousMessages,
              {
                content: message.message,
                additional_kwargs: {},
                type: "human",
                example: false,
              },
            ]
-          : null,
-        chat.assistant_id,
-        chat.thread_id,
-      );
+          : previousMessages.length
+            ? [...previousMessages]
+            : null,
+        assistant_id: chat.assistant_id,
+        thread_id: chat.thread_id,
+        config,
+      });
     },
     [currentChat, startStream, configs],
   );
@@ -93,7 +119,7 @@ function App() {
         message.message,
         currentConfig.assistant_id,
       );
-      return startTurn(message, chat);
+      return startTurn({ message, chat });
     },
     [createChat, startTurn, currentConfig],
   );
@@ -130,6 +156,8 @@ function App() {
           stopStream={stopStream}
           stream={stream}
           isDocumentRetrievalActive={isDocumentRetrievalActive}
+          streamErrorMessage={streamErrorMessage}
+          setStreamErrorMessage={setStreamErrorMessage}
         />
       ) : currentConfig ? (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
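`startStream` now posts a single object rather than positional arguments, mirroring the relaxed `CreateRunPayload` in `runs.py`: `input` may be a message list (optionally replaying `previousMessages`) or a dict, and a per-run `config` can override the assistant's. A hedged Python equivalent of the request the frontend sends (URL, cookie, and IDs are illustrative):

```python
import httpx

payload = {
    "assistant_id": "assistant-uuid",  # hypothetical IDs
    "thread_id": "thread-uuid",
    "input": [
        {"type": "human", "content": "Hi", "additional_kwargs": {}, "example": False}
    ],
    # "config": {...},  # optional; merged over the assistant's config server-side
}

with httpx.stream(
    "POST",
    "http://localhost:8100/runs/stream",
    json=payload,
    cookies={"opengpts_user_id": "some-user"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        print(line)  # server-sent events: "event: ..." / "data: ..." pairs
```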
diff --git a/frontend/src/components/AutosizeTextarea.tsx b/frontend/src/components/AutosizeTextarea.tsx
new file mode 100644
index 00000000..ee0b0b89
--- /dev/null
+++ b/frontend/src/components/AutosizeTextarea.tsx
@@ -0,0 +1,60 @@
+import { Ref } from "react";
+import { cn } from "../utils/cn";
+
+const COMMON_CLS = cn(
+  "text-sm col-[1] row-[1] m-0 resize-none overflow-hidden whitespace-pre-wrap break-words bg-transparent px-2 py-1 rounded shadow-none",
+);
+
+export function AutosizeTextarea(props: {
+  id?: string;
+  inputRef?: Ref<HTMLTextAreaElement>;
+  value?: string | null | undefined;
+  placeholder?: string;
+  className?: string;
+  onChange?: (e: string) => void;
+  onFocus?: () => void;
+  onBlur?: () => void;
+  onKeyDown?: (e: React.KeyboardEvent) => void;
+  autoFocus?: boolean;
+  readOnly?: boolean;
+  cursorPointer?: boolean;
+  disabled?: boolean;
+  fullHeight?: boolean;
+}) {
+  return (
+
+