Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update backend API interfaces to be agnostic to messages list (i.e. MessageGraph) #294

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions API.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ We can check the thread, and see that it is currently empty:
```python
import requests
requests.get(
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/messages',
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/state',
cookies= {"opengpts_user_id": "foo"}
).content
```
Expand All @@ -90,9 +90,9 @@ Let's add a message to the thread!
```python
import requests
requests.post(
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/messages',
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/state',
cookies= {"opengpts_user_id": "foo"}, json={
"messages": [{
"values": [{
"content": "hi! my name is bob",
"type": "human",
}]
Expand All @@ -105,12 +105,12 @@ If we now run the command to see the thread, we can see that there is now a mess
```python
import requests
requests.get(
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/messages',
'http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/state',
cookies= {"opengpts_user_id": "foo"}
).content
```
```shell
b'{"messages":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false}]}'
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false}],"next":[]}'
```

## Run the assistant on that thread
Expand All @@ -133,10 +133,10 @@ If we now check the thread, we can see (after a bit) that there is a message fro

```python
import requests
requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/messages', cookies= {"opengpts_user_id": "foo"}).content
requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/state', cookies= {"opengpts_user_id": "foo"}).content
```
```shell
b'{"messages":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false}]}'
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false}],"next":[]}'
```

## Run the assistant on the thread with new messages
Expand All @@ -153,8 +153,7 @@ requests.post('http://127.0.0.1:8100/runs', cookies= {"opengpts_user_id": "foo"}
"messages": [{
"content": "whats my name? respond in spanish",
"type": "human",
}
]
}]
}
}).content
```
Expand All @@ -163,11 +162,11 @@ Then, if we call the threads endpoint after a bit we can see the human message -

```python
import requests
requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/messages', cookies= {"opengpts_user_id": "foo"}).content
requests.get('http://127.0.0.1:8100/threads/231dc7f3-33ee-4040-98fe-27f6e2aa8b2b/state', cookies= {"opengpts_user_id": "foo"}).content
```

```shell
b'{"messages":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false},{"content":"whats my name? respond in spanish","additional_kwargs":{},"type":"human","example":false},{"content":"Tu nombre es Bob.","additional_kwargs":{"agent":{"return_values":{"output":"Tu nombre es Bob."},"log":"Tu nombre es Bob.","type":"AgentFinish"}},"type":"ai","example":false}]}'
b'{"values":[{"content":"hi! my name is bob","additional_kwargs":{},"type":"human","example":false},{"content":"Hello, Bob! How can I assist you today?","additional_kwargs":{"agent":{"return_values":{"output":"Hello, Bob! How can I assist you today?"},"log":"Hello, Bob! How can I assist you today?","type":"AgentFinish"}},"type":"ai","example":false},{"content":"whats my name? respond in spanish","additional_kwargs":{},"type":"human","example":false},{"content":"Tu nombre es Bob.","additional_kwargs":{"agent":{"return_values":{"output":"Tu nombre es Bob."},"log":"Tu nombre es Bob.","type":"AgentFinish"}},"type":"ai","example":false}],"next":[]}'
```

## Stream
Expand Down
6 changes: 4 additions & 2 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Any, Dict, Optional, Sequence, Union

import langsmith.client
from fastapi import APIRouter, BackgroundTasks, HTTPException
Expand All @@ -24,7 +24,9 @@ class CreateRunPayload(BaseModel):
"""Payload for creating a run."""

thread_id: str
input: Optional[Sequence[AnyMessage]] = Field(default_factory=list)
input: Optional[Union[Sequence[AnyMessage], Dict[str, Any]]] = Field(
default_factory=dict
)
config: Optional[RunnableConfig] = None


Expand Down
26 changes: 13 additions & 13 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, List, Sequence
from typing import Annotated, Any, Dict, List, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
Expand All @@ -21,10 +21,10 @@ class ThreadPutRequest(BaseModel):
assistant_id: str = Field(..., description="The ID of the assistant to use.")


class ThreadMessagesPostRequest(BaseModel):
"""Payload for adding messages to a thread."""
class ThreadPostRequest(BaseModel):
"""Payload for adding state to a thread."""

messages: Sequence[AnyMessage]
values: Union[Sequence[AnyMessage], Dict[str, Any]]


@router.get("/")
Expand All @@ -33,23 +33,23 @@ async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]:
return await storage.list_threads(opengpts_user_id)


@router.get("/{tid}/messages")
async def get_thread_messages(
@router.get("/{tid}/state")
async def get_thread_state(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
):
"""Get all messages for a thread."""
return await storage.get_thread_messages(opengpts_user_id, tid)
"""Get state for a thread."""
return await storage.get_thread_state(opengpts_user_id, tid)


@router.post("/{tid}/messages")
async def add_thread_messages(
@router.post("/{tid}/state")
async def add_thread_state(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
payload: ThreadMessagesPostRequest,
payload: ThreadPostRequest,
):
"""Add messages to a thread."""
return await storage.post_thread_messages(opengpts_user_id, tid, payload.messages)
"""Add state to a thread."""
return await storage.update_thread_state(opengpts_user_id, tid, payload.values)


@router.get("/{tid}/history")
Expand Down
20 changes: 10 additions & 10 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
from typing import List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage

Expand Down Expand Up @@ -98,22 +98,22 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
)


async def get_thread_messages(user_id: str, thread_id: str):
"""Get all messages for a thread."""
async def get_thread_state(user_id: str, thread_id: str):
"""Get state for a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
state = await app.aget_state({"configurable": {"thread_id": thread_id}})
return {
"messages": state.values,
"resumeable": bool(state.next),
"values": state.values,
"next": state.next,
}


async def post_thread_messages(
user_id: str, thread_id: str, messages: Sequence[AnyMessage]
async def update_thread_state(
user_id: str, thread_id: str, values: Union[Sequence[AnyMessage], Dict[str, Any]]
):
"""Add messages to a thread."""
"""Add state to a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
await app.aupdate_state({"configurable": {"thread_id": thread_id}}, messages)
await app.aupdate_state({"configurable": {"thread_id": thread_id}}, values)


async def get_thread_history(user_id: str, thread_id: str):
Expand All @@ -122,7 +122,7 @@ async def get_thread_history(user_id: str, thread_id: str):
return [
{
"values": c.values,
"resumeable": bool(c.next),
"next": c.next,
"config": c.config,
"parent": c.parent_config,
}
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ async def test_threads() -> None:
)
assert response.status_code == 200, response.text

response = await client.get(f"/threads/{tid}/messages", headers=headers)
response = await client.get(f"/threads/{tid}/state", headers=headers)
assert response.status_code == 200
assert response.json() == {"messages": [], "resumeable": False}
assert response.json() == {"values": [], "next": []}

response = await client.get("/threads/", headers=headers)

Expand Down
4 changes: 2 additions & 2 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function usePrevious<T>(value: T): T | undefined {

export function Chat(props: ChatProps) {
const { chatId } = useParams();
const { messages, resumeable } = useChatMessages(
const { messages, next } = useChatMessages(
chatId ?? null,
props.stream,
props.stopStream,
Expand Down Expand Up @@ -71,7 +71,7 @@ export function Chat(props: ChatProps) {
An error has occurred. Please try again.
</div>
)}
{resumeable && props.stream?.status !== "inflight" && (
{next.length > 0 && props.stream?.status !== "inflight" && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-yellow-600/20 cursor-pointer"
onClick={() => props.startStream(null, currentChat.thread_id)}
Expand Down
37 changes: 17 additions & 20 deletions frontend/src/hooks/useChatMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ import { useEffect, useMemo, useRef, useState } from "react";
import { Message } from "./useChatList";
import { StreamState, mergeMessagesById } from "./useStreamState";

async function getMessages(threadId: string) {
const { messages, resumeable } = await fetch(
`/threads/${threadId}/messages`,
{
headers: {
Accept: "application/json",
},
async function getState(threadId: string) {
const { values, next } = await fetch(`/threads/${threadId}/state`, {
headers: {
Accept: "application/json",
},
).then((r) => r.json());
return { messages, resumeable };
}).then((r) => r.json());
return { values, next };
}

function usePrevious<T>(value: T): T | undefined {
Expand All @@ -26,17 +23,17 @@ export function useChatMessages(
threadId: string | null,
stream: StreamState | null,
stopStream?: (clear?: boolean) => void,
): { messages: Message[] | null; resumeable: boolean } {
): { messages: Message[] | null; next: string[] } {
const [messages, setMessages] = useState<Message[] | null>(null);
const [resumeable, setResumeable] = useState(false);
const [next, setNext] = useState<string[]>([]);
const prevStreamStatus = usePrevious(stream?.status);

useEffect(() => {
async function fetchMessages() {
if (threadId) {
const { messages, resumeable } = await getMessages(threadId);
setMessages(messages);
setResumeable(resumeable);
const { values, next } = await getState(threadId);
setMessages(values);
setNext(next);
}
}

Expand All @@ -50,15 +47,15 @@ export function useChatMessages(
useEffect(() => {
async function fetchMessages() {
if (threadId) {
const { messages, resumeable } = await getMessages(threadId);
setMessages(messages);
setResumeable(resumeable);
const { values, next } = await getState(threadId);
setMessages(values);
setNext(next);
stopStream?.(true);
}
}

if (prevStreamStatus === "inflight" && stream?.status !== "inflight") {
setResumeable(false);
setNext([]);
fetchMessages();
}

Expand All @@ -68,8 +65,8 @@ export function useChatMessages(
return useMemo(
() => ({
messages: mergeMessagesById(messages, stream?.messages),
resumeable,
next,
}),
[messages, stream?.messages, resumeable],
[messages, stream?.messages, next],
);
}
Loading