From f875a71d433f70deb679d2b145fd773f817d2d29 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 6 Dec 2024 01:24:35 +0100 Subject: [PATCH] feat(tasks): add result to tasks --- libs/langgraph/src/pregel/debug.ts | 29 ++++++++++++++++++++++--- libs/langgraph/src/pregel/index.ts | 7 +++++- libs/langgraph/src/pregel/loop.ts | 3 ++- libs/langgraph/src/pregel/types.ts | 1 + libs/langgraph/src/tests/pregel.test.ts | 1 + 5 files changed, 36 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/src/pregel/debug.ts b/libs/langgraph/src/pregel/debug.ts index 75b51ffa2..a35ea0bdc 100644 --- a/libs/langgraph/src/pregel/debug.ts +++ b/libs/langgraph/src/pregel/debug.ts @@ -151,7 +151,8 @@ export function* mapDebugCheckpoint< metadata: CheckpointMetadata, tasks: readonly PregelExecutableTask[], pendingWrites: CheckpointPendingWrite[], - parentConfig: RunnableConfig | undefined + parentConfig: RunnableConfig | undefined, + outputKeys: string | string[] ) { function formatConfig(config: RunnableConfig) { // https://stackoverflow.com/a/78298178 @@ -214,7 +215,7 @@ export function* mapDebugCheckpoint< values: readChannels(channels, streamChannels), metadata, next: tasks.map((task) => task.name), - tasks: tasksWithWrites(tasks, pendingWrites, taskStates), + tasks: tasksWithWrites(tasks, pendingWrites, taskStates, outputKeys), parentConfig: parentConfig ? formatConfig(parentConfig) : undefined, }, }; @@ -223,7 +224,8 @@ export function* mapDebugCheckpoint< export function tasksWithWrites( tasks: PregelTaskDescription[] | readonly PregelExecutableTask[], pendingWrites: CheckpointPendingWrite[], - states?: Record + states: Record | undefined, + outputKeys: string | string[] ): PregelTaskDescription[] { return tasks.map((task): PregelTaskDescription => { const error = pendingWrites.find( @@ -246,12 +248,33 @@ export function tasksWithWrites( interrupts, }; } + + let result: Record | unknown | undefined; + if ( + pendingWrites.some( + ([id, n]) => id === task.id && n !== ERROR && n !== INTERRUPT + ) + ) { + if (typeof outputKeys === "string") { + result = pendingWrites.find( + ([id, chan]) => id === task.id && chan === outputKeys + )?.[2]; + } else { + result = Object.fromEntries( + pendingWrites + .filter(([id, chan]) => id === task.id && outputKeys.includes(chan)) + .map(([chan, , val]) => [chan, val]) + ); + } + } + return { id: task.id, name: task.name as string, path: task.path, interrupts, state: states?.[task.id], + result, }; }); } diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 053366c1f..35aa57769 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -461,7 +461,12 @@ export class Pregel< this.streamChannelsAsIs as string | string[] ), next: nextTasks.map((task) => task.name), - tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? [], taskStates), + tasks: tasksWithWrites( + nextTasks, + saved?.pendingWrites ?? [], + taskStates, + this.streamChannelsAsIs as string | string[] + ), metadata: saved.metadata, config: patchCheckpointMap(saved.config, saved.metadata), createdAt: saved.checkpoint.ts, diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index 0c86ccc63..7e3c039f5 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -626,7 +626,8 @@ export class PregelLoop { this.checkpointMetadata, Object.values(this.tasks), this.checkpointPendingWrites, - this.prevCheckpointConfig + this.prevCheckpointConfig, + this.outputKeys ), "debug" ) diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index f2197ca8f..ddb8102a7 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -173,6 +173,7 @@ export interface PregelTaskDescription { readonly error?: unknown; readonly interrupts: Interrupt[]; readonly state?: LangGraphRunnableConfig | StateSnapshot; + readonly result?: Record | unknown; readonly path?: [string, ...(string | number)[]]; } diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 36cbcc81d..1d541dae6 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -1904,6 +1904,7 @@ export function runPregelTests( name: "one", interrupts: [], path: [PULL, "one"], + result: { value: 2 }, }, { id: expect.any(String),