Skip to content

Commit

Permalink
Start of frontend changes
Browse files Browse the repository at this point in the history
Time travel, use checkpoint as primary source of truth

Refactor state management for chat window

Add support for state graph

Fixes

Pare down unneeded functionality, frontend updates

Fix repeated history fetches

Add basic state graph support, many other fixes

Revise state graph time travel flow

Use message graph as default

Fix flashing messages in UI on send

Allow adding and deleting tool calls

Hacks!

Only accept module paths

More logs

add env

add built ui files

Build ui files

Update cli

Delete .github/workflows/build_deploy_image.yml

Update path

Update ui files

Move migrations

Move ui files

0.0.5

Allow resume execution for tool messages (#2)

Undo

Undo

Remove cli

Undo

Undo

Update storage/threads

Undo ui

Undo

Lint

Undo

Rm

Undo

Rm

Update api

Undo

WIP
  • Loading branch information
jacoblee93 authored and nfcampos committed Apr 3, 2024
1 parent dd4b9f7 commit d9fbd71
Show file tree
Hide file tree
Showing 23 changed files with 1,375 additions and 257 deletions.
7 changes: 3 additions & 4 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional, Sequence
from typing import Dict, Optional, Sequence, Union

import langsmith.client
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
Expand All @@ -26,8 +26,8 @@ class CreateRunPayload(BaseModel):

assistant_id: str
thread_id: str
input: Optional[Sequence[AnyMessage]] = Field(default_factory=list)
config: Optional[RunnableConfig] = None
input: Optional[Union[Sequence[AnyMessage], Dict]] = Field(default_factory=list)


async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
Expand All @@ -41,7 +41,7 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser
config: RunnableConfig = {
**assistant["config"],
"configurable": {
**assistant["config"]["configurable"],
**assistant["config"].get("configurable", {}),
**(body.get("config", {}).get("configurable") or {}),
"user_id": opengpts_user_id,
"thread_id": body["thread_id"],
Expand Down Expand Up @@ -81,7 +81,6 @@ async def stream_run(
):
"""Create a run."""
input_, config = await _run_input_and_config(request, opengpts_user_id)

return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))


Expand Down
23 changes: 13 additions & 10 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, List, Sequence
from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
Expand All @@ -21,10 +21,11 @@ class ThreadPutRequest(BaseModel):
assistant_id: str = Field(..., description="The ID of the assistant to use.")


class ThreadMessagesPostRequest(BaseModel):
class ThreadPostRequest(BaseModel):
"""Payload for adding messages to a thread."""

messages: Sequence[AnyMessage]
values: Optional[Union[Dict[str, Any], Sequence[AnyMessage]]]
config: Optional[Dict[str, Any]] = None


@router.get("/")
Expand All @@ -33,23 +34,25 @@ async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]:
return await storage.list_threads(opengpts_user_id)


@router.get("/{tid}/messages")
async def get_thread_messages(
@router.get("/{tid}/state")
async def get_thread_state(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
):
"""Get all messages for a thread."""
return await storage.get_thread_messages(opengpts_user_id, tid)
return await storage.get_thread_state(opengpts_user_id, tid)


@router.post("/{tid}/messages")
async def add_thread_messages(
@router.post("/{tid}/state")
async def update_thread_state(
payload: ThreadPostRequest,
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
payload: ThreadMessagesPostRequest,
):
"""Add messages to a thread."""
return await storage.post_thread_messages(opengpts_user_id, tid, payload.messages)
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}}, payload.values
)


@router.get("/{tid}/history")
Expand Down
4 changes: 2 additions & 2 deletions backend/app/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
import pickle
from datetime import datetime
from typing import AsyncIterator, Optional

from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import Checkpoint, CheckpointTuple, CheckpointThreadTs
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple

from app.lifespan import get_pg_pool

Expand Down
25 changes: 12 additions & 13 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from datetime import datetime, timezone
from typing import List, Optional, Sequence
from typing import Any, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig

from app.agent import AgentType, get_agent_executor
from app.lifespan import get_pg_pool
from app.schema import Assistant, Thread
from app.stream import map_chunk_to_msg


async def list_assistants(user_id: str) -> List[Assistant]:
Expand Down Expand Up @@ -99,37 +99,36 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
)


async def get_thread_messages(user_id: str, thread_id: str):
async def get_thread_state(user_id: str, thread_id: str):
"""Get all messages for a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
state = await app.aget_state({"configurable": {"thread_id": thread_id}})
return {
"messages": [map_chunk_to_msg(msg) for msg in state.values],
"resumeable": bool(state.next),
"values": state.values,
"next": state.next,
}


async def post_thread_messages(
user_id: str, thread_id: str, messages: Sequence[AnyMessage]
async def update_thread_state(
config: RunnableConfig, messages: Union[Sequence[AnyMessage], dict[str, Any]]
):
"""Add messages to a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
await app.aupdate_state({"configurable": {"thread_id": thread_id}}, messages)
return await app.aupdate_state(config, messages)


async def get_thread_history(user_id: str, thread_id: str):
"""Get the history of a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
config = {"configurable": {"thread_id": thread_id}}
return [
{
"values": [map_chunk_to_msg(msg) for msg in c.values],
"resumeable": bool(c.next),
"values": c.values,
"next": c.next,
"config": c.config,
"parent": c.parent_config,
}
async for c in app.aget_state_history(
{"configurable": {"thread_id": thread_id}}
)
async for c in app.aget_state_history(config)
]


Expand Down
33 changes: 2 additions & 31 deletions backend/app/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,7 @@
from typing import AsyncIterator, Optional, Sequence, Union

import orjson
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
AnyMessage,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.messages import AnyMessage, SystemMessage, message_chunk_to_message
from langchain_core.runnables import Runnable, RunnableConfig
from langserve.serialization import WellKnownLCSerializer

Expand Down Expand Up @@ -76,22 +63,6 @@ async def astream_messages(
yield last_messages_list


def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
args = {k: v for k, v in chunk.__dict__.items() if k != "type"}
if isinstance(chunk, HumanMessageChunk):
return HumanMessage(**args)
elif isinstance(chunk, AIMessageChunk):
return AIMessage(**args)
elif isinstance(chunk, FunctionMessageChunk):
return FunctionMessage(**args)
elif isinstance(chunk, ChatMessageChunk):
return ChatMessage(**args)
else:
raise ValueError(f"Unknown chunk type: {chunk}")


_serializer = WellKnownLCSerializer()


Expand All @@ -111,7 +82,7 @@ async def to_sse(messages_stream: MessagesStream) -> AsyncIterator[dict]:
yield {
"event": "data",
"data": _serializer.dumps(
[map_chunk_to_msg(msg) for msg in chunk]
[message_chunk_to_message(msg) for msg in chunk]
).decode(),
}
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion backend/app/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import os
from typing import Any, BinaryIO, List, Optional

from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_community.document_loaders.blob_loaders.schema import Blob
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.runnables import (
Expand All @@ -21,6 +20,7 @@
)
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from app.ingest import ingest_blob
from app.parsing import MIMETYPE_BASED_PARSER
Expand Down
25 changes: 12 additions & 13 deletions backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ orjson = "^3.9.10"
python-multipart = "^0.0.6"
tiktoken = "^0.5.1"
langchain = ">=0.0.338"
langgraph = "^0.0.29"
langgraph = "^0.0.30"
pydantic = "<2.0"
python-magic = "^0.4.27"
langchain-openai = "^0.0.8"
langchain-openai = "^0.1.1"
beautifulsoup4 = "^4.12.3"
boto3 = "^1.34.28"
duckduckgo-search = "^4.2"
Expand All @@ -42,7 +42,7 @@ unstructured = {extras = ["doc", "docx"], version = "^0.12.5"}
pgvector = "^0.2.5"
psycopg2-binary = "^2.9.9"
asyncpg = "^0.29.0"
langchain-core = "^0.1.33"
langchain-core = "^0.1.36"

[tool.poetry.group.dev.dependencies]
uvicorn = "^0.23.2"
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ async def test_threads() -> None:
)
assert response.status_code == 200, response.text

response = await client.get(f"/threads/{tid}/messages", headers=headers)
response = await client.get(f"/threads/{tid}/state", headers=headers)
assert response.status_code == 200
assert response.json() == {"messages": [], "resumeable": False}
assert response.json() == {"values": [], "resumeable": False}

response = await client.get("/threads/", headers=headers)

Expand Down
4 changes: 4 additions & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"name": "frontend",
"private": true,
"version": "0.0.0",
"packageManager": "[email protected]",
"type": "module",
"scripts": {
"dev": "vite --host",
Expand All @@ -11,9 +12,12 @@
"format": "prettier -w src"
},
"dependencies": {
"@emotion/react": "^11.11.4",
"@emotion/styled": "^11.11.0",
"@headlessui/react": "^1.7.17",
"@heroicons/react": "^2.0.18",
"@microsoft/fetch-event-source": "^2.0.1",
"@mui/material": "^5.15.14",
"@tailwindcss/forms": "^0.5.6",
"@tailwindcss/typography": "^0.5.10",
"clsx": "^2.0.0",
Expand Down
Loading

0 comments on commit d9fbd71

Please sign in to comment.