From 0a892dc1bb7f8ade655027f2559ec3e84bcfa148 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 27 Nov 2023 23:28:49 +0100 Subject: [PATCH] Handle AIMessageChunk, attempt to gather non-dict values --- .../conversational_retrieval_chain/server.py | 2 +- examples/widgets/server.py | 2 +- .../ChatMessageTuplesControlRenderer.tsx | 20 ++++++++++++++++--- .../ChatMessagesControlRenderer.tsx | 18 ++++++++++++++--- langserve/playground/src/utils/messages.ts | 7 +++++++ langserve/playground/src/utils/path.ts | 9 ++++++++- 6 files changed, 49 insertions(+), 9 deletions(-) create mode 100644 langserve/playground/src/utils/messages.ts diff --git a/examples/conversational_retrieval_chain/server.py b/examples/conversational_retrieval_chain/server.py index d676498d..956f3520 100755 --- a/examples/conversational_retrieval_chain/server.py +++ b/examples/conversational_retrieval_chain/server.py @@ -87,7 +87,7 @@ class ChatHistory(BaseModel): chat_history: List[Tuple[str, str]] = Field( ..., - extra={"widget": {"type": "chat", "input": "question", "output": "output"}}, + extra={"widget": {"type": "chat", "input": "question"}}, ) question: str diff --git a/examples/widgets/server.py b/examples/widgets/server.py index 66f8dc2d..83219474 100755 --- a/examples/widgets/server.py +++ b/examples/widgets/server.py @@ -48,7 +48,7 @@ class ChatHistory(BaseModel): class ChatHistoryMessage(BaseModel): chat_history: List[BaseMessage] = Field( ..., - extra={"widget": {"type": "chat", "input": "location", "output": "output"}}, + extra={"widget": {"type": "chat", "input": "location"}}, ) location: str diff --git a/langserve/playground/src/components/ChatMessageTuplesControlRenderer.tsx b/langserve/playground/src/components/ChatMessageTuplesControlRenderer.tsx index 35627b0e..9ae40253 100644 --- a/langserve/playground/src/components/ChatMessageTuplesControlRenderer.tsx +++ b/langserve/playground/src/components/ChatMessageTuplesControlRenderer.tsx @@ -11,7 +11,8 @@ import { import { AutosizeTextarea } from "./AutosizeTextarea"; import { isJsonSchemaExtra } from "../utils/schema"; import { useStreamCallback } from "../useStreamCallback"; -import { traverseNaiveJsonPath } from "../utils/path"; +import { getNormalizedJsonPath, traverseNaiveJsonPath } from "../utils/path"; +import { getMessageContent } from "../utils/messages"; type MessageTuple = [string, string]; @@ -53,9 +54,22 @@ export const ChatMessageTuplesControlRenderer = withJsonFormsControlProps( const widget = props.schema.extra.widget; if (!("input" in widget) && !("output" in widget)) return; - const human = traverseNaiveJsonPath(ctx.input, widget.input ?? ""); - const ai = traverseNaiveJsonPath(ctx.output, widget.output ?? ""); + const inputPath = getNormalizedJsonPath(widget.input ?? ""); + const outputPath = getNormalizedJsonPath(widget.output ?? ""); + const isSingleOutputKey = + ctx.output != null && + Object.keys(ctx.output).length === 1 && + Object.keys(ctx.output)[0] === "output"; + + const human = traverseNaiveJsonPath(ctx.input, inputPath); + let ai = traverseNaiveJsonPath(ctx.output, outputPath); + + if (isSingleOutputKey) { + ai = traverseNaiveJsonPath(ai, ["output", ...outputPath]) ?? ai; + } + + ai = getMessageContent(ai); if (typeof human === "string" && typeof ai === "string") { props.handleChange(props.path, [...data, [human, ai]]); } diff --git a/langserve/playground/src/components/ChatMessagesControlRenderer.tsx b/langserve/playground/src/components/ChatMessagesControlRenderer.tsx index 34120352..cec9b48c 100644 --- a/langserve/playground/src/components/ChatMessagesControlRenderer.tsx +++ b/langserve/playground/src/components/ChatMessagesControlRenderer.tsx @@ -12,7 +12,7 @@ import { } from "@jsonforms/core"; import { AutosizeTextarea } from "./AutosizeTextarea"; import { useStreamCallback } from "../useStreamCallback"; -import { traverseNaiveJsonPath } from "../utils/path"; +import { getNormalizedJsonPath, traverseNaiveJsonPath } from "../utils/path"; import { isJsonSchemaExtra } from "../utils/schema"; import * as ToggleGroup from "@radix-ui/react-toggle-group"; @@ -122,8 +122,20 @@ export const ChatMessagesControlRenderer = withJsonFormsControlProps( const widget = props.schema.extra.widget; if (!("input" in widget) && !("output" in widget)) return; - const human = traverseNaiveJsonPath(ctx.input, widget.input ?? ""); - const ai = traverseNaiveJsonPath(ctx.output, widget.output ?? ""); + const inputPath = getNormalizedJsonPath(widget.input ?? ""); + const outputPath = getNormalizedJsonPath(widget.output ?? ""); + + const human = traverseNaiveJsonPath(ctx.input, inputPath); + let ai = traverseNaiveJsonPath(ctx.output, outputPath); + + const isSingleOutputKey = + ctx.output != null && + Object.keys(ctx.output).length === 1 && + Object.keys(ctx.output)[0] === "output"; + + if (isSingleOutputKey) { + ai = traverseNaiveJsonPath(ai, ["output", ...outputPath]) ?? ai; + } const humanMsg = constructMessage(human, "human"); const aiMsg = constructMessage(ai, "ai"); diff --git a/langserve/playground/src/utils/messages.ts b/langserve/playground/src/utils/messages.ts new file mode 100644 index 00000000..aaea838a --- /dev/null +++ b/langserve/playground/src/utils/messages.ts @@ -0,0 +1,7 @@ +export function getMessageContent(x: unknown) { + if (typeof x === "string") return x; + if (typeof x === "object" && x != null) { + if ("content" in x && typeof x.content === "string") return x.content; + } + return null; +} diff --git a/langserve/playground/src/utils/path.ts b/langserve/playground/src/utils/path.ts index 1def533b..16a8114d 100644 --- a/langserve/playground/src/utils/path.ts +++ b/langserve/playground/src/utils/path.ts @@ -2,11 +2,18 @@ function isAccessibleObject(x: unknown): x is Record { return typeof x === "object" && x != null; } +export function getNormalizedJsonPath( + path: string | number | Array +) { + return Array.isArray(path) ? path : [path]; +} + export function traverseNaiveJsonPath( x: unknown, path: string | number | Array ) { - const queue = Array.isArray(path) ? path : [path]; + const queue = getNormalizedJsonPath(path); + let tmp: unknown = x; while (queue.length > 0) { const first = queue.shift()!;