From f6b1d52e202eef96a86ab311ecd53a34ba0113b4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 23 Sep 2024 19:58:13 -0700 Subject: [PATCH 01/15] WIP: Implement types for Pregel.stream using recursive type --- libs/langgraph/src/pregel/index.ts | 28 ++++++++++++------------ libs/langgraph/src/pregel/loop.ts | 4 ++-- libs/langgraph/src/pregel/types.ts | 34 ++++++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index b4d4afb3..dcfd1f7f 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -49,8 +49,10 @@ import { PregelExecutableTask, PregelInterface, PregelParams, + SingleStreamMode, StateSnapshot, StreamMode, + StreamOutput, } from "./types.js"; import { GraphRecursionError, @@ -177,10 +179,11 @@ export class Channel { */ export interface PregelOptions< Nn extends StrRecord, - Cc extends StrRecord + Cc extends StrRecord, + Sm extends StreamMode = StreamMode > extends RunnableConfig { /** The stream mode for the graph run. Default is ["values"]. */ - streamMode?: StreamMode | StreamMode[]; + streamMode?: Sm; inputKeys?: keyof Cc | Array; /** The output keys to retrieve from the graph run. */ outputKeys?: keyof Cc | Array; @@ -222,7 +225,7 @@ export class Pregel< autoValidate: boolean = true; - streamMode: StreamMode[] = ["values"]; + streamMode: SingleStreamMode = "values"; streamChannels?: keyof Cc | Array; @@ -243,15 +246,10 @@ export class Pregel< constructor(fields: PregelParams) { super(fields); - let { streamMode } = fields; - if (streamMode != null && !Array.isArray(streamMode)) { - streamMode = [streamMode]; - } - this.nodes = fields.nodes; this.channels = fields.channels; this.autoValidate = fields.autoValidate ?? this.autoValidate; - this.streamMode = streamMode ?? this.streamMode; + this.streamMode = fields.streamMode ?? this.streamMode; this.inputChannels = fields.inputChannels; this.outputChannels = fields.outputChannels; this.streamChannels = fields.streamChannels ?? this.streamChannels; @@ -565,7 +563,7 @@ export class Pregel< _defaults(config: PregelOptions): [ boolean, // debug - StreamMode[], // stream mode + SingleStreamMode[], // stream mode string | string[], // input keys string | string[], // output keys RunnableConfig, // config without pregel keys @@ -603,11 +601,11 @@ export class Pregel< const defaultInterruptAfter = interruptAfter ?? this.interruptAfter ?? []; - let defaultStreamMode: StreamMode[]; + let defaultStreamMode: SingleStreamMode[]; if (streamMode !== undefined) { defaultStreamMode = Array.isArray(streamMode) ? streamMode : [streamMode]; } else { - defaultStreamMode = this.streamMode; + defaultStreamMode = [this.streamMode]; } let defaultCheckpointer: BaseCheckpointSaver | undefined; @@ -654,10 +652,10 @@ export class Pregel< * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async stream( + override async stream( input: PregelInputType, - options?: Partial> - ): Promise> { + options?: Partial> + ): Promise>> { return super.stream(input, options); } diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index 54a5953c..e993b54d 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -17,7 +17,7 @@ import { createCheckpoint, emptyChannels, } from "../channels/base.js"; -import { PregelExecutableTask, StreamMode } from "./types.js"; +import { PregelExecutableTask, SingleStreamMode } from "./types.js"; import { CONFIG_KEY_READ, CONFIG_KEY_RESUMING, @@ -137,7 +137,7 @@ export class PregelLoop { tasks: PregelExecutableTask[] = []; // eslint-disable-next-line @typescript-eslint/no-explicit-any - stream: Deque<[StreamMode, any]> = new Deque(); + stream: Deque<[SingleStreamMode, any]> = new Deque(); checkpointerPromises: Promise[] = []; diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 7ce7d675..0017eaac 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -12,7 +12,37 @@ import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; import { type ManagedValueSpec } from "../managed/base.js"; -export type StreamMode = "values" | "updates" | "debug"; +export type DebugOutput = unknown; // TODO + +export type SingleStreamMode = "values" | "updates" | "debug"; + +export type StreamMode = + | SingleStreamMode + | [SingleStreamMode] + | [SingleStreamMode, SingleStreamMode] + | [SingleStreamMode, SingleStreamMode, SingleStreamMode]; + +export type StreamOutput< + S extends StreamMode, + Nn extends StrRecord, + Schema +> = S extends "values" + ? Schema + : S extends "updates" + ? { [K in keyof Nn]: Partial } + : S extends "debug" + ? DebugOutput + : S extends [SingleStreamMode] + ? StreamOutput + : S extends [SingleStreamMode, SingleStreamMode] + ? [StreamOutput, StreamOutput] + : S extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] + ? [ + StreamOutput, + StreamOutput, + StreamOutput + ] + : never; /** * Construct a type with a set of properties K of type T @@ -37,7 +67,7 @@ export interface PregelInterface< /** * @default "values" */ - streamMode?: StreamMode | StreamMode[]; + streamMode?: SingleStreamMode; inputChannels: keyof Cc | Array; From 7b24c925ddae7454cc1a89c2ddc5486362566792 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 13:43:45 -0700 Subject: [PATCH 02/15] no longer excessively deep nor possibly infinite --- libs/langgraph/src/pregel/types.ts | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 0017eaac..544cf6d2 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -22,9 +22,9 @@ export type StreamMode = | [SingleStreamMode, SingleStreamMode] | [SingleStreamMode, SingleStreamMode, SingleStreamMode]; -export type StreamOutput< - S extends StreamMode, - Nn extends StrRecord, +type SingleStreamModeOutput< + S extends SingleStreamMode, + Nn extends Record, Schema > = S extends "values" ? Schema @@ -32,16 +32,27 @@ export type StreamOutput< ? { [K in keyof Nn]: Partial } : S extends "debug" ? DebugOutput + : never; + +export type StreamOutput< + S extends StreamMode, + Nn extends Record, + Schema +> = S extends SingleStreamMode + ? SingleStreamModeOutput : S extends [SingleStreamMode] - ? StreamOutput + ? SingleStreamModeOutput : S extends [SingleStreamMode, SingleStreamMode] - ? [StreamOutput, StreamOutput] + ? [ + SingleStreamModeOutput, + SingleStreamModeOutput + ] : S extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - StreamOutput, - StreamOutput, - StreamOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput + ] : never; /** From 92184d13b37023bc294bfc5ef2ec15e9f75607d0 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:04:02 -0700 Subject: [PATCH 03/15] specified debug output type --- libs/langgraph/src/pregel/types.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 544cf6d2..57e0b4ed 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -12,7 +12,18 @@ import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; import { type ManagedValueSpec } from "../managed/base.js"; -export type DebugOutput = unknown; // TODO +export type DebugOutput = { + type: string; + timestamp: string; + step: number; + payload: { + id: string; + name: any; + result: PendingWrite[]; + config: RunnableConfig; + metadata?: CheckpointMetadata; + }; +}; export type SingleStreamMode = "values" | "updates" | "debug"; From 24eb24170f604d3f16bdc1e61cd3bf1ad126b1c2 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:12:58 -0700 Subject: [PATCH 04/15] all tests passing! --- libs/langgraph/src/pregel/types.ts | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 57e0b4ed..f5ff490a 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -12,14 +12,16 @@ import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; import { type ManagedValueSpec } from "../managed/base.js"; -export type DebugOutput = { +export type DebugOutput< + Cc extends Record +> = { type: string; timestamp: string; step: number; payload: { id: string; - name: any; - result: PendingWrite[]; + name: string; + result: PendingWrite[]; config: RunnableConfig; metadata?: CheckpointMetadata; }; @@ -36,33 +38,35 @@ export type StreamMode = type SingleStreamModeOutput< S extends SingleStreamMode, Nn extends Record, + Cc extends Record, Schema > = S extends "values" ? Schema : S extends "updates" ? { [K in keyof Nn]: Partial } : S extends "debug" - ? DebugOutput + ? DebugOutput : never; export type StreamOutput< S extends StreamMode, Nn extends Record, + Cc extends Record, Schema > = S extends SingleStreamMode - ? SingleStreamModeOutput + ? SingleStreamModeOutput : S extends [SingleStreamMode] - ? SingleStreamModeOutput + ? SingleStreamModeOutput : S extends [SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput + SingleStreamModeOutput, + SingleStreamModeOutput ] : S extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput ] : never; From 278517c0a4be8c10b9a9f8d55a468b815c7e9873 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:20:05 -0700 Subject: [PATCH 05/15] all tests passing! --- libs/langgraph/src/pregel/index.ts | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index dcfd1f7f..e5c09f34 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -178,8 +178,8 @@ export class Channel { * Config for executing the graph. */ export interface PregelOptions< - Nn extends StrRecord, - Cc extends StrRecord, + Nn extends Record, + Cc extends Record, Sm extends StreamMode = StreamMode > extends RunnableConfig { /** The stream mode for the graph run. Default is ["values"]. */ @@ -202,9 +202,9 @@ export type PregelInputType = any; export type PregelOutputType = any; export class Pregel< - Nn extends StrRecord, - Cc extends StrRecord - > + Nn extends StrRecord, + Cc extends StrRecord +> extends Runnable> implements PregelInterface { @@ -655,7 +655,9 @@ export class Pregel< override async stream( input: PregelInputType, options?: Partial> - ): Promise>> { + ): Promise< + IterableReadableStream> + > { return super.stream(input, options); } From a387043e0f1a25f61b29e3e436805d3c2954d12f Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:23:18 -0700 Subject: [PATCH 06/15] formatting --- libs/langgraph/src/pregel/index.ts | 6 +++--- libs/langgraph/src/pregel/types.ts | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index e5c09f34..5a3e0764 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -202,9 +202,9 @@ export type PregelInputType = any; export type PregelOutputType = any; export class Pregel< - Nn extends StrRecord, - Cc extends StrRecord -> + Nn extends StrRecord, + Cc extends StrRecord + > extends Runnable> implements PregelInterface { diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index f5ff490a..3d5cf38e 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -59,15 +59,15 @@ export type StreamOutput< ? SingleStreamModeOutput : S extends [SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput + ] : S extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput + ] : never; /** From 6cb87666dab8badf85c5c4f76bb3c6930d0e37a1 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:40:09 -0700 Subject: [PATCH 07/15] cleared out StrRecord type --- libs/langgraph/src/pregel/algo.ts | 21 +++----- libs/langgraph/src/pregel/index.ts | 11 ++-- libs/langgraph/src/pregel/types.ts | 87 ++++++++++++++---------------- 3 files changed, 51 insertions(+), 68 deletions(-) diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 529fe71c..a8b2eee0 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -43,13 +43,6 @@ import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; import { _getIdMetadata, getNullChannelVersion } from "./utils.js"; import { ManagedValueMapping } from "../managed/base.js"; -/** - * Construct a type with a set of properties K of type T - */ -export type StrRecord = { - [P in K]: T; -}; - export type WritesProtocol = { name: string; writes: PendingWrite[]; @@ -329,8 +322,8 @@ export type NextTaskExtraFields = { }; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nn extends Record, + Cc extends Record >( checkpoint: ReadonlyCheckpoint, processes: Nn, @@ -342,8 +335,8 @@ export function _prepareNextTasks< ): PregelTaskDescription[]; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nn extends Record, + Cc extends Record >( checkpoint: ReadonlyCheckpoint, processes: Nn, @@ -355,8 +348,8 @@ export function _prepareNextTasks< ): PregelExecutableTask[]; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nn extends Record, + Cc extends Record >( checkpoint: ReadonlyCheckpoint, processes: Nn, @@ -573,7 +566,7 @@ function _procInput( step: number, proc: PregelNode, managed: ManagedValueMapping, - channels: StrRecord, + channels: Record, forExecution: boolean ) { // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 5a3e0764..c0d85e5e 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -60,12 +60,7 @@ import { InvalidUpdateError, isGraphInterrupt, } from "../errors.js"; -import { - _prepareNextTasks, - _localRead, - _applyWrites, - StrRecord, -} from "./algo.js"; +import { _prepareNextTasks, _localRead, _applyWrites } from "./algo.js"; import { _coerceToDict, getNewChannelVersions, RetryPolicy } from "./utils.js"; import { PregelLoop } from "./loop.js"; import { executeTasksWithRetry } from "./retry.js"; @@ -202,8 +197,8 @@ export type PregelInputType = any; export type PregelOutputType = any; export class Pregel< - Nn extends StrRecord, - Cc extends StrRecord + Nn extends Record, + Cc extends Record > extends Runnable> implements PregelInterface diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 3d5cf38e..12c550da 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -12,16 +12,18 @@ import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; import { type ManagedValueSpec } from "../managed/base.js"; -export type DebugOutput< - Cc extends Record -> = { +export type ChannelsType = Record; + +export type NodesType = Record; + +export type DebugOutput = { type: string; timestamp: string; step: number; payload: { id: string; name: string; - result: PendingWrite[]; + result: PendingWrite[]; config: RunnableConfig; metadata?: CheckpointMetadata; }; @@ -36,54 +38,47 @@ export type StreamMode = | [SingleStreamMode, SingleStreamMode, SingleStreamMode]; type SingleStreamModeOutput< - S extends SingleStreamMode, - Nn extends Record, - Cc extends Record, + Mode extends SingleStreamMode, + Nodes extends NodesType, + Channels extends ChannelsType, Schema -> = S extends "values" +> = Mode extends "values" ? Schema - : S extends "updates" - ? { [K in keyof Nn]: Partial } - : S extends "debug" - ? DebugOutput + : Mode extends "updates" + ? { [K in keyof Nodes]: Partial } + : Mode extends "debug" + ? DebugOutput : never; export type StreamOutput< - S extends StreamMode, - Nn extends Record, - Cc extends Record, + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType, Schema -> = S extends SingleStreamMode - ? SingleStreamModeOutput - : S extends [SingleStreamMode] - ? SingleStreamModeOutput - : S extends [SingleStreamMode, SingleStreamMode] +> = Mode extends SingleStreamMode + ? SingleStreamModeOutput + : Mode extends [SingleStreamMode] + ? SingleStreamModeOutput + : Mode extends [SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput + SingleStreamModeOutput, + SingleStreamModeOutput ] - : S extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] + : Mode extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput ] : never; -/** - * Construct a type with a set of properties K of type T - */ -type StrRecord = { - [P in K]: T; -}; - export interface PregelInterface< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + Channels extends ChannelsType > { - nodes: Nn; + nodes: Nodes; - channels: Cc; + channels: Channels; /** * @default true @@ -95,23 +90,23 @@ export interface PregelInterface< */ streamMode?: SingleStreamMode; - inputChannels: keyof Cc | Array; + inputChannels: keyof Channels | Array; - outputChannels: keyof Cc | Array; + outputChannels: keyof Channels | Array; /** * @default [] */ - interruptAfter?: Array | All; + interruptAfter?: Array | All; /** * @default [] */ - interruptBefore?: Array | All; + interruptBefore?: Array | All; - streamChannels?: keyof Cc | Array; + streamChannels?: keyof Channels | Array; - get streamChannelsAsIs(): keyof Cc | Array; + get streamChannelsAsIs(): keyof Channels | Array; /** * @default undefined @@ -134,9 +129,9 @@ export interface PregelInterface< } export type PregelParams< - Nn extends StrRecord, - Cc extends StrRecord -> = Omit, "streamChannelsAsIs">; + Nodes extends NodesType, + Channels extends ChannelsType +> = Omit, "streamChannelsAsIs">; export interface PregelTaskDescription { readonly id: string; From 7773c049f43d1492cf3f77dfea51acc72f1e8548 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Tue, 24 Sep 2024 14:55:51 -0700 Subject: [PATCH 08/15] clearer type parameters for channels and nodes --- libs/langgraph/src/channels/base.ts | 10 +-- libs/langgraph/src/pregel/algo.ts | 106 ++++++++++++++++---------- libs/langgraph/src/pregel/index.ts | 69 +++++++++-------- libs/langgraph/src/pregel/validate.ts | 32 ++++---- 4 files changed, 122 insertions(+), 95 deletions(-) diff --git a/libs/langgraph/src/channels/base.ts b/libs/langgraph/src/channels/base.ts index d583dcfb..e2c3b5c1 100644 --- a/libs/langgraph/src/channels/base.ts +++ b/libs/langgraph/src/channels/base.ts @@ -78,15 +78,15 @@ export abstract class BaseChannel< } } -export function emptyChannels>( - channels: Cc, +export function emptyChannels>( + channels: Channels, checkpoint: ReadonlyCheckpoint -): Cc { +): Channels { const filteredChannels = Object.fromEntries( Object.entries(channels).filter(([, value]) => isBaseChannel(value)) - ) as Cc; + ) as Channels; - const newChannels = {} as Cc; + const newChannels = {} as Channels; for (const k in filteredChannels) { if (Object.prototype.hasOwnProperty.call(filteredChannels, k)) { const channelValue = checkpoint.channel_values[k]; diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index a8b2eee0..9b51eaa6 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -38,7 +38,12 @@ import { TAG_HIDDEN, TASKS, } from "../constants.js"; -import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; +import { + ChannelsType, + NodesType, + PregelExecutableTask, + PregelTaskDescription, +} from "./types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; import { _getIdMetadata, getNullChannelVersion } from "./utils.js"; import { ManagedValueMapping } from "../managed/base.js"; @@ -84,17 +89,17 @@ export function shouldInterrupt( return anyChannelUpdated && anyTriggeredNodeInInterruptNodes; } -export function _localRead>( +export function _localRead>( step: number, checkpoint: ReadonlyCheckpoint, - channels: Cc, + channels: LocalChannels, managed: ManagedValueMapping, - task: WritesProtocol, - select: Array | keyof Cc, + task: WritesProtocol, + select: Array | keyof LocalChannels, fresh: boolean = false ): Record | unknown { - let managedKeys: Array = []; - let updated = new Set(); + let managedKeys: Array = []; + let updated = new Set(); if (!Array.isArray(select)) { for (const [c] of task.writes) { @@ -106,9 +111,11 @@ export function _localRead>( updated = updated || new Set(); } else { managedKeys = select.filter((k) => managed.get(k as string)) as Array< - keyof Cc + keyof LocalChannels + >; + select = select.filter((k) => !managed.get(k as string)) as Array< + keyof LocalChannels >; - select = select.filter((k) => !managed.get(k as string)) as Array; updated = new Set( select.filter((c) => task.writes.some(([key, _]) => key === c)) ); @@ -118,11 +125,20 @@ export function _localRead>( if (fresh && updated.size > 0) { const localChannels = Object.fromEntries( - Object.entries(channels).filter(([k, _]) => updated.has(k as keyof Cc)) - ) as Partial; - - const newCheckpoint = createCheckpoint(checkpoint, localChannels as Cc, -1); - const newChannels = emptyChannels(localChannels as Cc, newCheckpoint); + Object.entries(channels).filter(([k, _]) => + updated.has(k as keyof LocalChannels) + ) + ) as Partial; + + const newCheckpoint = createCheckpoint( + checkpoint, + localChannels as LocalChannels, + -1 + ); + const newChannels = emptyChannels( + localChannels as LocalChannels, + newCheckpoint + ); _applyWrites(copyCheckpoint(newCheckpoint), newChannels, [task]); values = readChannels({ ...channels, ...newChannels }, select); @@ -176,17 +192,17 @@ export function _localWrite( commit(writes); } -export function _applyWrites>( +export function _applyWrites>( checkpoint: Checkpoint, - channels: Cc, - tasks: WritesProtocol[], + channels: LocalChannels, + tasks: WritesProtocol[], // eslint-disable-next-line @typescript-eslint/no-explicit-any getNextVersion?: (version: any, channel: BaseChannel) => any ): Record { // Filter out non instances of BaseChannel const onlyChannels = Object.fromEntries( Object.entries(channels).filter(([_, value]) => isBaseChannel(value)) - ) as Cc; + ) as LocalChannels; // Update seen versions for (const task of tasks) { if (checkpoint.versions_seen[task.name] === undefined) { @@ -233,10 +249,13 @@ export function _applyWrites>( // Group writes by channel const pendingWriteValuesByChannel = {} as Record< - keyof Cc, + keyof LocalChannels, + PendingWriteValue[] + >; + const pendingWritesByManaged = {} as Record< + keyof LocalChannels, PendingWriteValue[] >; - const pendingWritesByManaged = {} as Record; for (const task of tasks) { for (const [chan, val] of task.writes) { if (chan === TASKS) { @@ -322,12 +341,12 @@ export type NextTaskExtraFields = { }; export function _prepareNextTasks< - Nn extends Record, - Cc extends Record + Nodes extends NodesType, + Channels extends ChannelsType >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: Channels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: false, @@ -335,32 +354,35 @@ export function _prepareNextTasks< ): PregelTaskDescription[]; export function _prepareNextTasks< - Nn extends Record, - Cc extends Record + Nodes extends NodesType, + Channels extends ChannelsType >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: Channels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: true, extra: NextTaskExtraFields -): PregelExecutableTask[]; +): PregelExecutableTask[]; export function _prepareNextTasks< - Nn extends Record, - Cc extends Record + Nodes extends NodesType, + LocalChannels extends Record >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: LocalChannels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: boolean, extra: NextTaskExtraFields -): PregelTaskDescription[] | PregelExecutableTask[] { +): + | PregelTaskDescription[] + | PregelExecutableTask[] { const parentNamespace = config.configurable?.checkpoint_ns ?? ""; - const tasks: Array> = []; + const tasks: Array> = + []; const taskDescriptions: Array = []; const { step, isResuming = false, checkpointer, manager } = extra; @@ -397,7 +419,7 @@ export function _prepareNextTasks< const proc = processes[packet.node]; const node = proc.getNode(); if (node !== undefined) { - const writes: [keyof Cc, unknown][] = []; + const writes: [keyof LocalChannels, unknown][] = []; managed.replaceRuntimePlaceholders(step, packet.args); tasks.push({ name: packet.node, @@ -417,14 +439,15 @@ export function _prepareNextTasks< [CONFIG_KEY_SEND]: (writes_: [string, any][]) => _localWrite( step, - (items: [keyof Cc, unknown][]) => writes.push(...items), + (items: [keyof LocalChannels, unknown][]) => + writes.push(...items), processes, channels, managed, writes_ ), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof LocalChannels, fresh_: boolean = false ) => _localRead( @@ -503,7 +526,7 @@ export function _prepareNextTasks< if (forExecution) { const node = proc.getNode(); if (node !== undefined) { - const writes: [keyof Cc, unknown][] = []; + const writes: [keyof LocalChannels, unknown][] = []; tasks.push({ name, input: val, @@ -520,14 +543,15 @@ export function _prepareNextTasks< [CONFIG_KEY_SEND]: (writes_: [string, any][]) => _localWrite( step, - (items: [keyof Cc, unknown][]) => writes.push(...items), + (items: [keyof LocalChannels, unknown][]) => + writes.push(...items), processes, channels, managed, writes_ ), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof LocalChannels, fresh_: boolean = false ) => _localRead( diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index c0d85e5e..312315a9 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -46,6 +46,8 @@ import { INTERRUPT, } from "../constants.js"; import { + NodesType, + ChannelsType, PregelExecutableTask, PregelInterface, PregelParams, @@ -173,19 +175,19 @@ export class Channel { * Config for executing the graph. */ export interface PregelOptions< - Nn extends Record, - Cc extends Record, - Sm extends StreamMode = StreamMode + Nodes extends NodesType, + Channels extends ChannelsType, + Mode extends StreamMode = StreamMode > extends RunnableConfig { /** The stream mode for the graph run. Default is ["values"]. */ - streamMode?: Sm; - inputKeys?: keyof Cc | Array; + streamMode?: Mode; + inputKeys?: keyof Channels | Array; /** The output keys to retrieve from the graph run. */ - outputKeys?: keyof Cc | Array; + outputKeys?: keyof Channels | Array; /** The nodes to interrupt the graph run before. */ - interruptBefore?: All | Array; + interruptBefore?: All | Array; /** The nodes to interrupt the graph run after. */ - interruptAfter?: All | Array; + interruptAfter?: All | Array; /** Enable debug mode for the graph run. */ debug?: boolean; } @@ -196,12 +198,13 @@ export type PregelInputType = any; // eslint-disable-next-line @typescript-eslint/no-explicit-any export type PregelOutputType = any; -export class Pregel< - Nn extends Record, - Cc extends Record +export class Pregel + extends Runnable< + PregelInputType, + PregelOutputType, + PregelOptions > - extends Runnable> - implements PregelInterface + implements PregelInterface { static lc_name() { return "LangGraph"; @@ -210,23 +213,23 @@ export class Pregel< // Because Pregel extends `Runnable`. lc_namespace = ["langgraph", "pregel"]; - nodes: Nn; + nodes: Nodes; - channels: Cc; + channels: Channels; - inputChannels: keyof Cc | Array; + inputChannels: keyof Channels | Array; - outputChannels: keyof Cc | Array; + outputChannels: keyof Channels | Array; autoValidate: boolean = true; streamMode: SingleStreamMode = "values"; - streamChannels?: keyof Cc | Array; + streamChannels?: keyof Channels | Array; - interruptAfter?: Array | All; + interruptAfter?: Array | All; - interruptBefore?: Array | All; + interruptBefore?: Array | All; stepTimeout?: number; @@ -238,7 +241,7 @@ export class Pregel< store?: BaseStore; - constructor(fields: PregelParams) { + constructor(fields: PregelParams) { super(fields); this.nodes = fields.nodes; @@ -262,7 +265,7 @@ export class Pregel< } validate(): this { - validateGraph({ + validateGraph({ nodes: this.nodes, channels: this.channels, outputChannels: this.outputChannels, @@ -275,7 +278,7 @@ export class Pregel< return this; } - get streamChannelsList(): Array { + get streamChannelsList(): Array { if (Array.isArray(this.streamChannels)) { return this.streamChannels; } else if (this.streamChannels) { @@ -285,7 +288,7 @@ export class Pregel< } } - get streamChannelsAsIs(): keyof Cc | Array { + get streamChannelsAsIs(): keyof Channels | Array { if (this.streamChannels) { return this.streamChannels; } else { @@ -384,7 +387,7 @@ export class Pregel< async updateState( config: RunnableConfig, values: Record | unknown, - asNode?: keyof Nn + asNode?: keyof Nodes ): Promise { if (!this.checkpointer) { throw new GraphValueError("No checkpointer set"); @@ -486,7 +489,7 @@ export class Pregel< `No writers found for node "${asNode.toString()}"` ); } - const task: PregelExecutableTask = { + const task: PregelExecutableTask = { name: asNode, input: values, proc: @@ -503,10 +506,10 @@ export class Pregel< patchConfig(config, { runName: config.runName ?? `${this.getName()}UpdateState`, configurable: { - [CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) => + [CONFIG_KEY_SEND]: (items: [keyof Channels, unknown][]) => task.writes.push(...items), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof Channels, fresh_: boolean = false ) => _localRead( @@ -556,7 +559,7 @@ export class Pregel< ); } - _defaults(config: PregelOptions): [ + _defaults(config: PregelOptions): [ boolean, // debug SingleStreamMode[], // stream mode string | string[], // input keys @@ -649,9 +652,9 @@ export class Pregel< */ override async stream( input: PregelInputType, - options?: Partial> + options?: Partial> ): Promise< - IterableReadableStream> + IterableReadableStream> > { return super.stream(input, options); } @@ -720,7 +723,7 @@ export class Pregel< override async *_streamIterator( input: PregelInputType, - options?: Partial> + options?: Partial> ): AsyncGenerator { const inputConfig = ensureConfig(options); if ( @@ -915,7 +918,7 @@ export class Pregel< */ override async invoke( input: PregelInputType, - options?: Partial> + options?: Partial> ): Promise { const streamMode = options?.streamMode ?? "values"; const config = { diff --git a/libs/langgraph/src/pregel/validate.ts b/libs/langgraph/src/pregel/validate.ts index 7f86e161..b1637a7b 100644 --- a/libs/langgraph/src/pregel/validate.ts +++ b/libs/langgraph/src/pregel/validate.ts @@ -1,8 +1,7 @@ import { All } from "@langchain/langgraph-checkpoint"; -import { BaseChannel } from "../channels/index.js"; import { INTERRUPT } from "../constants.js"; import { PregelNode } from "./read.js"; -import { type ManagedValueSpec } from "../managed/base.js"; +import { ChannelsType, NodesType } from "./types.js"; export class GraphValidationError extends Error { constructor(message?: string) { @@ -12,8 +11,8 @@ export class GraphValidationError extends Error { } export function validateGraph< - Nn extends Record, - Cc extends Record + Nodes extends NodesType, + Channels extends ChannelsType >({ nodes, channels, @@ -23,20 +22,20 @@ export function validateGraph< interruptAfterNodes, interruptBeforeNodes, }: { - nodes: Nn; - channels: Cc; - inputChannels: keyof Cc | Array; - outputChannels: keyof Cc | Array; - streamChannels?: keyof Cc | Array; - interruptAfterNodes?: Array | All; - interruptBeforeNodes?: Array | All; + nodes: Nodes; + channels: Channels; + inputChannels: keyof Channels | Array; + outputChannels: keyof Channels | Array; + streamChannels?: keyof Channels | Array; + interruptAfterNodes?: Array | All; + interruptBeforeNodes?: Array | All; }): void { if (!channels) { throw new GraphValidationError("Channels not provided"); } - const subscribedChannels = new Set(); - const allOutputChannels = new Set(); + const subscribedChannels = new Set(); + const allOutputChannels = new Set(); for (const [name, node] of Object.entries(nodes)) { if (name === INTERRUPT) { @@ -114,9 +113,10 @@ export function validateGraph< } } -export function validateKeys< - Cc extends Record ->(keys: keyof Cc | Array, channels: Cc): void { +export function validateKeys( + keys: keyof Channels | Array, + channels: Channels +): void { if (Array.isArray(keys)) { for (const key of keys) { if (!(key in channels)) { From 7423be54cfe7d937bf1a3162e359b0a4a55d35ff Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Thu, 26 Sep 2024 10:35:31 -0700 Subject: [PATCH 09/15] need to type the values stored in pregel loop --- libs/langgraph/src/graph/graph.ts | 107 ++++++++++-------- libs/langgraph/src/graph/state.ts | 172 +++++++++++++++++++---------- libs/langgraph/src/pregel/index.ts | 46 ++++---- libs/langgraph/src/pregel/types.ts | 110 +++++++++++++++--- 4 files changed, 284 insertions(+), 151 deletions(-) diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 55ac5ca7..9da28c82 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -53,9 +53,9 @@ export class Branch { this.condition = options.path; this.ends = Array.isArray(options.pathMap) ? options.pathMap.reduce((acc, n) => { - acc[n] = n; - return acc; - }, {} as Record) + acc[n] = n; + return acc; + }, {} as Record) : options.pathMap; } @@ -74,7 +74,7 @@ export class Branch { if (e.name === NodeInterrupt.unminifiable_name) { console.warn( "[WARN]: 'NodeInterrupt' thrown in conditional edge. This is likely a bug in your graph implementation.\n" + - "NodeInterrupt should only be thrown inside a node, not in edge conditions." + "NodeInterrupt should only be thrown inside a node, not in edge conditions." ); } throw e; @@ -119,7 +119,7 @@ export type NodeSpec = { export type AddNodeOptions = { metadata?: Record }; export class Graph< - N extends string = typeof END, + NodeNames extends string = typeof END, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunInput = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -129,18 +129,18 @@ export class Graph< RunOutput > > { - nodes: Record; + nodes: Record; - edges: Set<[N | typeof START, N | typeof END]>; + edges: Set<[NodeNames | typeof START, NodeNames | typeof END]>; - branches: Record>>; + branches: Record>>; entryPoint?: string; compiled = false; constructor() { - this.nodes = {} as Record; + this.nodes = {} as Record; this.edges = new Set(); this.branches = {}; } @@ -159,7 +159,7 @@ export class Graph< key: K, action: RunnableLike, options?: AddNodeOptions - ): Graph { + ): Graph { if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) { throw new Error( `"${CHECKPOINT_NAMESPACE_SEPARATOR}" is a reserved character and is not allowed in node names.` @@ -176,7 +176,7 @@ export class Graph< throw new Error(`Node \`${key}\` is reserved.`); } - this.nodes[key as unknown as N] = { + this.nodes[key as unknown as NodeNames] = { runnable: _coerceToRunnable( // Account for arbitrary state due to Send API action as RunnableLike @@ -184,10 +184,13 @@ export class Graph< metadata: options?.metadata, } as NodeSpecType; - return this as Graph; + return this as Graph; } - addEdge(startKey: N | typeof START, endKey: N | typeof END): this { + addEdge( + startKey: NodeNames | typeof START, + endKey: NodeNames | typeof END + ): this { this.warnIfCompiled( `Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph.` ); @@ -212,20 +215,20 @@ export class Graph< return this; } - addConditionalEdges(source: BranchOptions): this; + addConditionalEdges(source: BranchOptions): this; addConditionalEdges( - source: N, - path: Branch["condition"], - pathMap?: BranchOptions["pathMap"] + source: NodeNames, + path: Branch["condition"], + pathMap?: BranchOptions["pathMap"] ): this; addConditionalEdges( - source: N | BranchOptions, - path?: Branch["condition"], - pathMap?: BranchOptions["pathMap"] + source: NodeNames | BranchOptions, + path?: Branch["condition"], + pathMap?: BranchOptions["pathMap"] ): this { - const options: BranchOptions = + const options: BranchOptions = typeof source === "object" ? source : { source, path: path!, pathMap }; this.warnIfCompiled( "Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph." @@ -249,7 +252,7 @@ export class Graph< /** * @deprecated use `addEdge(START, key)` instead */ - setEntryPoint(key: N): this { + setEntryPoint(key: NodeNames): this { this.warnIfCompiled( "Setting the entry point of a graph that has already been compiled. This will not be reflected in the compiled graph." ); @@ -260,7 +263,7 @@ export class Graph< /** * @deprecated use `addEdge(key, END)` instead */ - setFinishPoint(key: N): this { + setFinishPoint(key: NodeNames): this { this.warnIfCompiled( "Setting a finish point of a graph that has already been compiled. This will not be reflected in the compiled graph." ); @@ -274,9 +277,9 @@ export class Graph< interruptAfter, }: { checkpointer?: BaseCheckpointSaver; - interruptBefore?: N[] | All; - interruptAfter?: N[] | All; - } = {}): CompiledGraph { + interruptBefore?: NodeNames[] | All; + interruptAfter?: NodeNames[] | All; + } = {}): CompiledGraph { // validate the graph this.validate([ ...(Array.isArray(interruptBefore) ? interruptBefore : []), @@ -284,20 +287,23 @@ export class Graph< ]); // create empty compiled graph - const compiled = new CompiledGraph({ + const compiled = new CompiledGraph({ builder: this, checkpointer, interruptAfter, interruptBefore, autoValidate: false, - nodes: {} as Record>, + nodes: {} as Record< + NodeNames | typeof START, + PregelNode + >, channels: { [START]: new EphemeralValue(), [END]: new EphemeralValue(), - } as Record, + } as Record, inputChannels: START, outputChannels: END, - streamChannels: [] as N[], + streamChannels: [] as NodeNames[], streamMode: "values", }); @@ -305,14 +311,14 @@ export class Graph< for (const [key, node] of Object.entries>( this.nodes )) { - compiled.attachNode(key as N, node); + compiled.attachNode(key as NodeNames, node); } for (const [start, end] of this.edges) { compiled.attachEdge(start, end); } for (const [start, branches] of Object.entries(this.branches)) { for (const [name, branch] of Object.entries(branches)) { - compiled.attachBranch(start as N, name, branch); + compiled.attachBranch(start as NodeNames, name, branch); } } @@ -375,35 +381,35 @@ export class Graph< } export class CompiledGraph< - N extends string, + NodeNames extends string, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunInput = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunOutput = any > extends Pregel< - Record>, - Record + Record>, + Record > { - declare NodeType: N; + declare NodeType: NodeNames; declare RunInput: RunInput; declare RunOutput: RunOutput; - builder: Graph; + builder: Graph; constructor({ builder, ...rest - }: { builder: Graph } & PregelParams< - Record>, - Record + }: { builder: Graph } & PregelParams< + Record>, + Record >) { super(rest); this.builder = builder; } - attachNode(key: N, node: NodeSpec): void { + attachNode(key: NodeNames, node: NodeSpec): void { this.channels[key] = new EphemeralValue(); this.nodes[key] = new PregelNode({ channels: [], @@ -414,10 +420,13 @@ export class CompiledGraph< .pipe( new ChannelWrite([{ channel: key, value: PASSTHROUGH }], [TAG_HIDDEN]) ); - (this.streamChannels as N[]).push(key); + (this.streamChannels as NodeNames[]).push(key); } - attachEdge(start: N | typeof START, end: N | typeof END): void { + attachEdge( + start: NodeNames | typeof START, + end: NodeNames | typeof END + ): void { if (end === END) { if (start === START) { throw new Error("Cannot have an edge from START to END"); @@ -432,9 +441,9 @@ export class CompiledGraph< } attachBranch( - start: N | typeof START, + start: NodeNames | typeof START, name: string, - branch: Branch + branch: Branch ) { // add hidden start node if (start === START && this.nodes[START]) { @@ -460,7 +469,7 @@ export class CompiledGraph< // attach branch readers const ends = branch.ends ? Object.values(branch.ends) - : (Object.keys(this.nodes) as N[]); + : (Object.keys(this.nodes) as NodeNames[]); for (const end of ends) { if (end !== END) { const channelName = `branch:${start}:${name}:${end}`; @@ -502,9 +511,9 @@ export class CompiledGraph< if (config?.xray) { const subgraph = isCompiledGraph(node) ? node.getGraph({ - ...config, - xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, - }) + ...config, + xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, + }) : node.runnable.getGraph(config); subgraph.trimFirstNode(); subgraph.trimLastNode(); diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index c8482190..187d258d 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -53,10 +53,10 @@ export type ChannelReducers = { export interface StateGraphArgs { channels: Channels extends object - ? Channels extends unknown[] - ? ChannelReducers<{ __root__: Channels }> - : ChannelReducers - : ChannelReducers<{ __root__: Channels }>; + ? Channels extends unknown[] + ? ChannelReducers<{ __root__: Channels }> + : ChannelReducers + : ChannelReducers<{ __root__: Channels }>; } export type StateGraphNodeSpec = NodeSpec< @@ -156,26 +156,37 @@ export type StateGraphArgsWithInputOutputSchemas< */ export class StateGraph< SD extends StateDefinition | unknown, - S = SD extends StateDefinition ? StateType : SD, - U = SD extends StateDefinition ? UpdateType : Partial, - N extends string = typeof START, - I extends StateDefinition = SD extends StateDefinition ? SD : StateDefinition, - O extends StateDefinition = SD extends StateDefinition ? SD : StateDefinition, - C extends StateDefinition = StateDefinition -> extends Graph> { + InputStateSchema = SD extends StateDefinition ? StateType : SD, + OutputStateSchema = SD extends StateDefinition + ? UpdateType + : Partial, + NodeNames extends string = typeof START, + InputDefinition extends StateDefinition = SD extends StateDefinition + ? SD + : StateDefinition, + OutputDefinition extends StateDefinition = SD extends StateDefinition + ? SD + : StateDefinition, + ConfigSchema extends StateDefinition = StateDefinition +> extends Graph< + NodeNames, + InputStateSchema, + OutputStateSchema, + StateGraphNodeSpec +> { channels: Record = {}; // TODO: this doesn't dedupe edges as in py, so worth fixing at some point - waitingEdges: Set<[N[], N]> = new Set(); + waitingEdges: Set<[NodeNames[], NodeNames]> = new Set(); /** @internal */ _schemaDefinition: StateDefinition; /** @internal */ - _inputDefinition: I; + _inputDefinition: InputDefinition; /** @internal */ - _outputDefinition: O; + _outputDefinition: OutputDefinition; /** * Map schemas to managed values @@ -184,35 +195,35 @@ export class StateGraph< _schemaDefinitions = new Map(); /** @internal Used only for typing. */ - _configSchema: C | undefined; + _configSchema: ConfigSchema | undefined; constructor( fields: SD extends StateDefinition ? - | SD - | AnnotationRoot - | StateGraphArgs - | StateGraphArgsWithStateSchema - | StateGraphArgsWithInputOutputSchemas - : StateGraphArgs, - configSchema?: AnnotationRoot + | SD + | AnnotationRoot + | StateGraphArgs + | StateGraphArgsWithStateSchema + | StateGraphArgsWithInputOutputSchemas + : StateGraphArgs, + configSchema?: AnnotationRoot ) { super(); if ( isStateGraphArgsWithInputOutputSchemas< SD extends StateDefinition ? SD : never, - O + OutputDefinition >(fields) ) { this._schemaDefinition = fields.input.spec; - this._inputDefinition = fields.input.spec as unknown as I; + this._inputDefinition = fields.input.spec as unknown as InputDefinition; this._outputDefinition = fields.output.spec; } else if (isStateGraphArgsWithStateSchema(fields)) { this._schemaDefinition = fields.stateSchema.spec; this._inputDefinition = (fields.input?.spec ?? - this._schemaDefinition) as I; + this._schemaDefinition) as InputDefinition; this._outputDefinition = (fields.output?.spec ?? - this._schemaDefinition) as O; + this._schemaDefinition) as OutputDefinition; } else if (isStateDefinition(fields) || isAnnotationRoot(fields)) { const spec = isAnnotationRoot(fields) ? fields.spec : fields; this._schemaDefinition = spec; @@ -269,15 +280,25 @@ export class StateGraph< } } - addNode( + addNode( key: K, action: RunnableLike< NodeInput, // eslint-disable-next-line @typescript-eslint/no-explicit-any - U extends object ? U & Record : U + OutputStateSchema extends object + ? OutputStateSchema & Record + : OutputStateSchema >, options?: StateGraphAddNodeOptions - ): StateGraph { + ): StateGraph< + SD, + InputStateSchema, + OutputStateSchema, + NodeNames | K, + InputDefinition, + OutputDefinition, + ConfigSchema + > { if (key in this.channels) { throw new Error( `${key} is already being used as a state attribute (a.k.a. a channel), cannot also be used as a node name.` @@ -303,19 +324,33 @@ export class StateGraph< if (options?.input !== undefined) { this._addSchema(options.input.spec); } - const nodeSpec: StateGraphNodeSpec = { - runnable: _coerceToRunnable(action) as unknown as Runnable, + const nodeSpec: StateGraphNodeSpec = { + runnable: _coerceToRunnable(action) as unknown as Runnable< + InputStateSchema, + OutputStateSchema + >, retryPolicy: options?.retryPolicy, metadata: options?.metadata, input: options?.input?.spec ?? this._schemaDefinition, }; - this.nodes[key as unknown as N] = nodeSpec; - - return this as StateGraph; + this.nodes[key as unknown as NodeNames] = nodeSpec; + + return this as StateGraph< + SD, + InputStateSchema, + OutputStateSchema, + NodeNames | K, + InputDefinition, + OutputDefinition, + ConfigSchema + >; } - addEdge(startKey: typeof START | N | N[], endKey: N | typeof END): this { + addEdge( + startKey: typeof START | NodeNames | NodeNames[], + endKey: NodeNames | typeof END + ): this { if (typeof startKey === "string") { return super.addEdge(startKey, endKey); } @@ -323,7 +358,7 @@ export class StateGraph< if (this.compiled) { console.warn( "Adding an edge to a graph that has already been compiled. This will " + - "not be reflected in the compiled graph." + "not be reflected in the compiled graph." ); } @@ -355,9 +390,16 @@ export class StateGraph< }: { checkpointer?: BaseCheckpointSaver; store?: BaseStore; - interruptBefore?: N[] | All; - interruptAfter?: N[] | All; - } = {}): CompiledStateGraph { + interruptBefore?: NodeNames[] | All; + interruptAfter?: NodeNames[] | All; + } = {}): CompiledStateGraph< + InputStateSchema, + OutputStateSchema, + NodeNames, + InputDefinition, + OutputDefinition, + ConfigSchema + > { // validate the graph this.validate([ ...(Array.isArray(interruptBefore) ? interruptBefore : []), @@ -376,17 +418,27 @@ export class StateGraph< streamKeys.length === 1 && streamKeys[0] === ROOT ? ROOT : streamKeys; // create empty compiled graph - const compiled = new CompiledStateGraph({ + const compiled = new CompiledStateGraph< + InputStateSchema, + OutputStateSchema, + NodeNames, + InputDefinition, + OutputDefinition, + ConfigSchema + >({ builder: this, checkpointer, interruptAfter, interruptBefore, autoValidate: false, - nodes: {} as Record>, + nodes: {} as Record< + NodeNames | typeof START, + PregelNode + >, channels: { ...this.channels, [START]: new EphemeralValue(), - } as Record, + } as Record, inputChannels: START, outputChannels, streamChannels, @@ -396,10 +448,10 @@ export class StateGraph< // attach nodes, edges and branches compiled.attachNode(START); - for (const [key, node] of Object.entries>( - this.nodes - )) { - compiled.attachNode(key as N, node); + for (const [key, node] of Object.entries< + StateGraphNodeSpec + >(this.nodes)) { + compiled.attachNode(key as NodeNames, node); } for (const [start, end] of this.edges) { compiled.attachEdge(start, end); @@ -409,7 +461,7 @@ export class StateGraph< } for (const [start, branches] of Object.entries(this.branches)) { for (const [name, branch] of Object.entries(branches)) { - compiled.attachBranch(start as N, name, branch); + compiled.attachBranch(start as NodeNames, name, branch); } } @@ -467,14 +519,14 @@ export class CompiledStateGraph< key === ROOT ? { channel: key, value: PASSTHROUGH, skipNone: true } : { - channel: key, - value: PASSTHROUGH, - mapper: new RunnableCallable({ - func: getStateKey.bind(null, key as keyof U), - trace: false, - recurse: false, - }), - } + channel: key, + value: PASSTHROUGH, + mapper: new RunnableCallable({ + func: getStateKey.bind(null, key as keyof U), + trace: false, + recurse: false, + }), + } ); // add node and output channel @@ -509,11 +561,11 @@ export class CompiledStateGraph< mapper: isSingleInput ? undefined : // eslint-disable-next-line @typescript-eslint/no-explicit-any - (input: Record) => { - return Object.fromEntries( - Object.entries(input).filter(([k]) => k in inputValues) - ); - }, + (input: Record) => { + return Object.fromEntries( + Object.entries(input).filter(([k]) => k in inputValues) + ); + }, bound: node?.runnable, metadata: node?.metadata, retryPolicy: node?.retryPolicy, diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 312315a9..39d6499d 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -55,6 +55,8 @@ import { StateSnapshot, StreamMode, StreamOutput, + ChannelsStateType, + AllStreamOutputTypes, } from "./types.js"; import { GraphRecursionError, @@ -192,16 +194,10 @@ export interface PregelOptions< debug?: boolean; } -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type PregelInputType = any; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type PregelOutputType = any; - export class Pregel extends Runnable< - PregelInputType, - PregelOutputType, + any, // Partial>, // input type + any, // AllStreamOutputTypes, // output type PregelOptions > implements PregelInterface @@ -650,13 +646,13 @@ export class Pregel * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async stream( - input: PregelInputType, - options?: Partial> - ): Promise< - IterableReadableStream> - > { - return super.stream(input, options); + override async stream( + input: Partial>, + options?: Partial> + ): Promise>> { + return super.stream(input, options) as Promise< + IterableReadableStream> + >; } protected async prepareSpecs( @@ -721,10 +717,10 @@ export class Pregel }; } - override async *_streamIterator( - input: PregelInputType, - options?: Partial> - ): AsyncGenerator { + override async *_streamIterator( + input: Partial>, + options?: Partial> + ): AsyncGenerator> { const inputConfig = ensureConfig(options); if ( inputConfig.recursionLimit === undefined || @@ -916,17 +912,17 @@ export class Pregel * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async invoke( - input: PregelInputType, - options?: Partial> - ): Promise { - const streamMode = options?.streamMode ?? "values"; + override async invoke( + input: ChannelsStateType, + options?: Partial> + ): Promise> { + const streamMode = options?.streamMode ?? ("values" as const); const config = { ...ensureConfig(options), outputKeys: options?.outputKeys ?? this.outputChannels, streamMode, }; - const chunks = []; + const chunks: StreamOutput[] = []; const stream = await this.stream(input, config); for await (const chunk of stream) { chunks.push(chunk); diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 12c550da..be2081e4 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -10,12 +10,43 @@ import type { PregelNode } from "./read.js"; import { RetryPolicy } from "./utils.js"; import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; -import { type ManagedValueSpec } from "../managed/base.js"; - +import { + ConfiguredManagedValue, + ManagedValue, + type ManagedValueSpec, +} from "../managed/base.js"; export type ChannelsType = Record; export type NodesType = Record; +type ExtractChannelValueType = Channel extends BaseChannel + ? Channel["ValueType"] + : Channel extends ManagedValueSpec + ? Channel extends ConfiguredManagedValue + ? V + : Channel extends ManagedValue + ? V + : never + : never; + +export type ChannelsStateType = { + [Key in keyof Channels]: ExtractChannelValueType; +}; + +type ExtractChannelUpdateType = Channel extends BaseChannel + ? Channel["UpdateType"] + : Channel extends ManagedValueSpec + ? Channel extends ConfiguredManagedValue + ? V + : Channel extends ManagedValue + ? V + : never + : never; + +export type ChannelsUpdateType = { + [Key in keyof Channels]?: ExtractChannelUpdateType; +}; + export type DebugOutput = { type: string; timestamp: string; @@ -40,12 +71,11 @@ export type StreamMode = type SingleStreamModeOutput< Mode extends SingleStreamMode, Nodes extends NodesType, - Channels extends ChannelsType, - Schema + Channels extends ChannelsType > = Mode extends "values" - ? Schema + ? ChannelsUpdateType : Mode extends "updates" - ? { [K in keyof Nodes]: Partial } + ? { [K in keyof Nodes]: ChannelsUpdateType } : Mode extends "debug" ? DebugOutput : never; @@ -53,25 +83,71 @@ type SingleStreamModeOutput< export type StreamOutput< Mode extends StreamMode, Nodes extends NodesType, - Channels extends ChannelsType, - Schema + Channels extends ChannelsType > = Mode extends SingleStreamMode - ? SingleStreamModeOutput + ? SingleStreamModeOutput : Mode extends [SingleStreamMode] - ? SingleStreamModeOutput + ? SingleStreamModeOutput : Mode extends [SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput + ] : Mode extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput + ] : never; +// Gross hack to avoid recursion limits & define the narrowest type that covers all possible stream output types +// This is what gets passed to the lower abstraction layers (such as `Runnable`) and the precise output type is +// narrowed at higher levels such as `StateGraph.stream()` to a specific one once the `StreamMode` is known +export type AllStreamOutputTypes< + Nodes extends NodesType, + Channels extends ChannelsType +> = + | StreamOutput<["values"], Nodes, Channels> + | StreamOutput<["updates"], Nodes, Channels> + | StreamOutput<["debug"], Nodes, Channels> + | StreamOutput<["values", "updates"], Nodes, Channels> + | StreamOutput<["values", "values"], Nodes, Channels> + | StreamOutput<["values", "debug"], Nodes, Channels> + | StreamOutput<["updates", "updates"], Nodes, Channels> + | StreamOutput<["updates", "values"], Nodes, Channels> + | StreamOutput<["updates", "debug"], Nodes, Channels> + | StreamOutput<["debug", "updates"], Nodes, Channels> + | StreamOutput<["debug", "values"], Nodes, Channels> + | StreamOutput<["debug", "debug"], Nodes, Channels> + | StreamOutput<["updates", "updates", "updates"], Nodes, Channels> + | StreamOutput<["updates", "updates", "values"], Nodes, Channels> + | StreamOutput<["updates", "updates", "debug"], Nodes, Channels> + | StreamOutput<["updates", "values", "updates"], Nodes, Channels> + | StreamOutput<["updates", "values", "values"], Nodes, Channels> + | StreamOutput<["updates", "values", "debug"], Nodes, Channels> + | StreamOutput<["updates", "debug", "updates"], Nodes, Channels> + | StreamOutput<["updates", "debug", "values"], Nodes, Channels> + | StreamOutput<["updates", "debug", "debug"], Nodes, Channels> + | StreamOutput<["values", "updates", "updates"], Nodes, Channels> + | StreamOutput<["values", "updates", "values"], Nodes, Channels> + | StreamOutput<["values", "updates", "debug"], Nodes, Channels> + | StreamOutput<["values", "values", "updates"], Nodes, Channels> + | StreamOutput<["values", "values", "values"], Nodes, Channels> + | StreamOutput<["values", "values", "debug"], Nodes, Channels> + | StreamOutput<["values", "debug", "updates"], Nodes, Channels> + | StreamOutput<["values", "debug", "values"], Nodes, Channels> + | StreamOutput<["values", "debug", "debug"], Nodes, Channels> + | StreamOutput<["debug", "updates", "updates"], Nodes, Channels> + | StreamOutput<["debug", "updates", "values"], Nodes, Channels> + | StreamOutput<["debug", "updates", "debug"], Nodes, Channels> + | StreamOutput<["debug", "values", "updates"], Nodes, Channels> + | StreamOutput<["debug", "values", "values"], Nodes, Channels> + | StreamOutput<["debug", "values", "debug"], Nodes, Channels> + | StreamOutput<["debug", "debug", "updates"], Nodes, Channels> + | StreamOutput<["debug", "debug", "values"], Nodes, Channels> + | StreamOutput<["debug", "debug", "debug"], Nodes, Channels>; + export interface PregelInterface< Nodes extends NodesType, Channels extends ChannelsType From 75af108605c6a0bcc4995f0e461a5e72f9567dcf Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Thu, 26 Sep 2024 12:59:06 -0700 Subject: [PATCH 10/15] formatting --- libs/langgraph/src/graph/graph.ts | 14 +++---- libs/langgraph/src/graph/state.ts | 62 +++++++++++++++--------------- libs/langgraph/src/pregel/types.ts | 30 +++++++-------- 3 files changed, 53 insertions(+), 53 deletions(-) diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 9da28c82..b7f80120 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -53,9 +53,9 @@ export class Branch { this.condition = options.path; this.ends = Array.isArray(options.pathMap) ? options.pathMap.reduce((acc, n) => { - acc[n] = n; - return acc; - }, {} as Record) + acc[n] = n; + return acc; + }, {} as Record) : options.pathMap; } @@ -74,7 +74,7 @@ export class Branch { if (e.name === NodeInterrupt.unminifiable_name) { console.warn( "[WARN]: 'NodeInterrupt' thrown in conditional edge. This is likely a bug in your graph implementation.\n" + - "NodeInterrupt should only be thrown inside a node, not in edge conditions." + "NodeInterrupt should only be thrown inside a node, not in edge conditions." ); } throw e; @@ -511,9 +511,9 @@ export class CompiledGraph< if (config?.xray) { const subgraph = isCompiledGraph(node) ? node.getGraph({ - ...config, - xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, - }) + ...config, + xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, + }) : node.runnable.getGraph(config); subgraph.trimFirstNode(); subgraph.trimLastNode(); diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index 187d258d..c1d40072 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -53,10 +53,10 @@ export type ChannelReducers = { export interface StateGraphArgs { channels: Channels extends object - ? Channels extends unknown[] - ? ChannelReducers<{ __root__: Channels }> - : ChannelReducers - : ChannelReducers<{ __root__: Channels }>; + ? Channels extends unknown[] + ? ChannelReducers<{ __root__: Channels }> + : ChannelReducers + : ChannelReducers<{ __root__: Channels }>; } export type StateGraphNodeSpec = NodeSpec< @@ -158,15 +158,15 @@ export class StateGraph< SD extends StateDefinition | unknown, InputStateSchema = SD extends StateDefinition ? StateType : SD, OutputStateSchema = SD extends StateDefinition - ? UpdateType - : Partial, + ? UpdateType + : Partial, NodeNames extends string = typeof START, InputDefinition extends StateDefinition = SD extends StateDefinition - ? SD - : StateDefinition, + ? SD + : StateDefinition, OutputDefinition extends StateDefinition = SD extends StateDefinition - ? SD - : StateDefinition, + ? SD + : StateDefinition, ConfigSchema extends StateDefinition = StateDefinition > extends Graph< NodeNames, @@ -200,11 +200,11 @@ export class StateGraph< constructor( fields: SD extends StateDefinition ? - | SD - | AnnotationRoot - | StateGraphArgs - | StateGraphArgsWithStateSchema - | StateGraphArgsWithInputOutputSchemas + | SD + | AnnotationRoot + | StateGraphArgs + | StateGraphArgsWithStateSchema + | StateGraphArgsWithInputOutputSchemas : StateGraphArgs, configSchema?: AnnotationRoot ) { @@ -286,8 +286,8 @@ export class StateGraph< NodeInput, // eslint-disable-next-line @typescript-eslint/no-explicit-any OutputStateSchema extends object - ? OutputStateSchema & Record - : OutputStateSchema + ? OutputStateSchema & Record + : OutputStateSchema >, options?: StateGraphAddNodeOptions ): StateGraph< @@ -358,7 +358,7 @@ export class StateGraph< if (this.compiled) { console.warn( "Adding an edge to a graph that has already been compiled. This will " + - "not be reflected in the compiled graph." + "not be reflected in the compiled graph." ); } @@ -519,14 +519,14 @@ export class CompiledStateGraph< key === ROOT ? { channel: key, value: PASSTHROUGH, skipNone: true } : { - channel: key, - value: PASSTHROUGH, - mapper: new RunnableCallable({ - func: getStateKey.bind(null, key as keyof U), - trace: false, - recurse: false, - }), - } + channel: key, + value: PASSTHROUGH, + mapper: new RunnableCallable({ + func: getStateKey.bind(null, key as keyof U), + trace: false, + recurse: false, + }), + } ); // add node and output channel @@ -561,11 +561,11 @@ export class CompiledStateGraph< mapper: isSingleInput ? undefined : // eslint-disable-next-line @typescript-eslint/no-explicit-any - (input: Record) => { - return Object.fromEntries( - Object.entries(input).filter(([k]) => k in inputValues) - ); - }, + (input: Record) => { + return Object.fromEntries( + Object.entries(input).filter(([k]) => k in inputValues) + ); + }, bound: node?.runnable, metadata: node?.metadata, retryPolicy: node?.retryPolicy, diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index be2081e4..5c73cfbf 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -23,10 +23,10 @@ type ExtractChannelValueType = Channel extends BaseChannel ? Channel["ValueType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsStateType = { @@ -37,10 +37,10 @@ type ExtractChannelUpdateType = Channel extends BaseChannel ? Channel["UpdateType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsUpdateType = { @@ -90,15 +90,15 @@ export type StreamOutput< ? SingleStreamModeOutput : Mode extends [SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput + ] : Mode extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput - ] + SingleStreamModeOutput, + SingleStreamModeOutput, + SingleStreamModeOutput + ] : never; // Gross hack to avoid recursion limits & define the narrowest type that covers all possible stream output types From ab4e684dd72663e3c2347be5ac65fd1b6b36be8d Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Fri, 27 Sep 2024 08:23:37 -0700 Subject: [PATCH 11/15] Only invoke remains! --- libs/langgraph/src/pregel/index.ts | 29 +++++++++------ libs/langgraph/src/pregel/types.ts | 57 ++++++++++++++++++------------ 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 39d6499d..9daab9aa 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -56,7 +56,7 @@ import { StreamMode, StreamOutput, ChannelsStateType, - AllStreamOutputTypes, + AllStreamInvokeOutputTypes, } from "./types.js"; import { GraphRecursionError, @@ -197,7 +197,7 @@ export interface PregelOptions< export class Pregel extends Runnable< any, // Partial>, // input type - any, // AllStreamOutputTypes, // output type + AllStreamInvokeOutputTypes, // output type PregelOptions > implements PregelInterface @@ -647,7 +647,7 @@ export class Pregel * @param options.debug Whether to print debug information during execution. */ override async stream( - input: Partial>, + input: any, //Partial>, options?: Partial> ): Promise>> { return super.stream(input, options) as Promise< @@ -796,9 +796,9 @@ export class Pregel } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput; } } } @@ -839,9 +839,13 @@ export class Pregel } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput< + Mode, + Nodes, + Channels + >; } } } @@ -865,9 +869,9 @@ export class Pregel } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput; } } } @@ -913,9 +917,12 @@ export class Pregel * @param options.debug Whether to print debug information during execution. */ override async invoke( - input: ChannelsStateType, + input: any, //ChannelsStateType, options?: Partial> - ): Promise> { + ): Promise< + | StreamOutput + | Array> + > { const streamMode = options?.streamMode ?? ("values" as const); const config = { ...ensureConfig(options), diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 5c73cfbf..6b75c49c 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -23,10 +23,10 @@ type ExtractChannelValueType = Channel extends BaseChannel ? Channel["ValueType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsStateType = { @@ -37,10 +37,10 @@ type ExtractChannelUpdateType = Channel extends BaseChannel ? Channel["UpdateType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsUpdateType = { @@ -68,7 +68,7 @@ export type StreamMode = | [SingleStreamMode, SingleStreamMode] | [SingleStreamMode, SingleStreamMode, SingleStreamMode]; -type SingleStreamModeOutput< +export type SingleStreamModeOutput< Mode extends SingleStreamMode, Nodes extends NodesType, Channels extends ChannelsType @@ -80,27 +80,29 @@ type SingleStreamModeOutput< ? DebugOutput : never; +export type StreamModeOutputTuple< + Mode extends SingleStreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = [Mode, SingleStreamModeOutput]; + export type StreamOutput< Mode extends StreamMode, Nodes extends NodesType, Channels extends ChannelsType > = Mode extends SingleStreamMode ? SingleStreamModeOutput - : Mode extends [SingleStreamMode] - ? SingleStreamModeOutput - : Mode extends [SingleStreamMode, SingleStreamMode] - ? [ - SingleStreamModeOutput, - SingleStreamModeOutput - ] - : Mode extends [SingleStreamMode, SingleStreamMode, SingleStreamMode] - ? [ - SingleStreamModeOutput, - SingleStreamModeOutput, - SingleStreamModeOutput - ] + : Mode[number] extends SingleStreamMode + ? StreamModeOutputTuple : never; +export type InvokeOutput< + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends "values" + ? StreamOutput + : Array>; // Gross hack to avoid recursion limits & define the narrowest type that covers all possible stream output types // This is what gets passed to the lower abstraction layers (such as `Runnable`) and the precise output type is // narrowed at higher levels such as `StateGraph.stream()` to a specific one once the `StreamMode` is known @@ -108,6 +110,10 @@ export type AllStreamOutputTypes< Nodes extends NodesType, Channels extends ChannelsType > = + | StreamOutput<"values" | "updates", Nodes, Channels> + | StreamOutput<"values" | "debug", Nodes, Channels> + | StreamOutput<"updates" | "debug", Nodes, Channels> + | StreamOutput<"values" | "updates" | "debug", Nodes, Channels> | StreamOutput<["values"], Nodes, Channels> | StreamOutput<["updates"], Nodes, Channels> | StreamOutput<["debug"], Nodes, Channels> @@ -148,6 +154,13 @@ export type AllStreamOutputTypes< | StreamOutput<["debug", "debug", "values"], Nodes, Channels> | StreamOutput<["debug", "debug", "debug"], Nodes, Channels>; +export type AllStreamInvokeOutputTypes< + Nodes extends NodesType, + Channels extends ChannelsType +> = + | AllStreamOutputTypes + | Array>; + export interface PregelInterface< Nodes extends NodesType, Channels extends ChannelsType From 9dafb44bce46469f2ecd6ab0cb32878ffb6e3600 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Fri, 27 Sep 2024 12:37:31 -0700 Subject: [PATCH 12/15] invoke maybe array return type --- libs/langgraph/src/pregel/index.ts | 10 ++++------ libs/langgraph/src/pregel/types.ts | 7 +++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 9daab9aa..ce4b8086 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -57,6 +57,7 @@ import { StreamOutput, ChannelsStateType, AllStreamInvokeOutputTypes, + InvokeOutputType, } from "./types.js"; import { GraphRecursionError, @@ -916,13 +917,10 @@ export class Pregel * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async invoke( + override async invoke( input: any, //ChannelsStateType, options?: Partial> - ): Promise< - | StreamOutput - | Array> - > { + ): Promise> { const streamMode = options?.streamMode ?? ("values" as const); const config = { ...ensureConfig(options), @@ -932,7 +930,7 @@ export class Pregel const chunks: StreamOutput[] = []; const stream = await this.stream(input, config); for await (const chunk of stream) { - chunks.push(chunk); + chunks.push(chunk as StreamOutput); } if (streamMode === "values") { return chunks[chunks.length - 1]; diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 6b75c49c..8656dd23 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -160,6 +160,13 @@ export type AllStreamInvokeOutputTypes< > = | AllStreamOutputTypes | Array>; +export type InvokeOutputType< + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends "values" + ? StreamOutput + : Array>; export interface PregelInterface< Nodes extends NodesType, From aed73d951e43c42541b86b6c452bf790355c0c44 Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Fri, 27 Sep 2024 12:50:22 -0700 Subject: [PATCH 13/15] invoke return type needs narrowing still --- libs/langgraph/src/pregel/index.ts | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index ce4b8086..6c552339 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -57,7 +57,6 @@ import { StreamOutput, ChannelsStateType, AllStreamInvokeOutputTypes, - InvokeOutputType, } from "./types.js"; import { GraphRecursionError, @@ -920,18 +919,23 @@ export class Pregel override async invoke( input: any, //ChannelsStateType, options?: Partial> - ): Promise> { + ): Promise< + Mode extends "values" + ? StreamOutput + : Array> + > { const streamMode = options?.streamMode ?? ("values" as const); const config = { ...ensureConfig(options), outputKeys: options?.outputKeys ?? this.outputChannels, streamMode, }; - const chunks: StreamOutput[] = []; + const chunks: Array> = []; const stream = await this.stream(input, config); for await (const chunk of stream) { chunks.push(chunk as StreamOutput); } + if (streamMode === "values") { return chunks[chunks.length - 1]; } From 9dd5bc5e9cbde51ae317fa18e42e739292dc353a Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Fri, 27 Sep 2024 13:42:38 -0700 Subject: [PATCH 14/15] Invoke return types are happy --- libs/langgraph/src/pregel/index.ts | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 6c552339..4e6d06ea 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -937,8 +937,12 @@ export class Pregel } if (streamMode === "values") { - return chunks[chunks.length - 1]; + return chunks[chunks.length - 1] as Mode extends "values" + ? StreamOutput + : Array>; } - return chunks; + return chunks as Mode extends "values" + ? StreamOutput + : Array>; } } From cd5a085ff7ee49cba13144b0f256b752eacdd74d Mon Sep 17 00:00:00 2001 From: Allan Deutsch Date: Fri, 27 Sep 2024 14:34:25 -0700 Subject: [PATCH 15/15] Swapped return type to use ChannelStateType --- libs/langgraph/src/pregel/index.ts | 4 ++-- libs/langgraph/src/pregel/types.ts | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 4e6d06ea..327ce055 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -921,8 +921,8 @@ export class Pregel options?: Partial> ): Promise< Mode extends "values" - ? StreamOutput - : Array> + ? StreamOutput + : Array> > { const streamMode = options?.streamMode ?? ("values" as const); const config = { diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 8656dd23..6bf6f6fb 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -23,10 +23,10 @@ type ExtractChannelValueType = Channel extends BaseChannel ? Channel["ValueType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsStateType = { @@ -37,10 +37,10 @@ type ExtractChannelUpdateType = Channel extends BaseChannel ? Channel["UpdateType"] : Channel extends ManagedValueSpec ? Channel extends ConfiguredManagedValue - ? V - : Channel extends ManagedValue - ? V - : never + ? V + : Channel extends ManagedValue + ? V + : never : never; export type ChannelsUpdateType = { @@ -73,9 +73,9 @@ export type SingleStreamModeOutput< Nodes extends NodesType, Channels extends ChannelsType > = Mode extends "values" - ? ChannelsUpdateType + ? ChannelsStateType : Mode extends "updates" - ? { [K in keyof Nodes]: ChannelsUpdateType } + ? { [K in keyof Nodes]: ChannelsStateType } : Mode extends "debug" ? DebugOutput : never; @@ -160,6 +160,7 @@ export type AllStreamInvokeOutputTypes< > = | AllStreamOutputTypes | Array>; + export type InvokeOutputType< Mode extends StreamMode, Nodes extends NodesType,