diff --git a/libs/langgraph/src/channels/base.ts b/libs/langgraph/src/channels/base.ts index d583dcfb..e2c3b5c1 100644 --- a/libs/langgraph/src/channels/base.ts +++ b/libs/langgraph/src/channels/base.ts @@ -78,15 +78,15 @@ export abstract class BaseChannel< } } -export function emptyChannels>( - channels: Cc, +export function emptyChannels>( + channels: Channels, checkpoint: ReadonlyCheckpoint -): Cc { +): Channels { const filteredChannels = Object.fromEntries( Object.entries(channels).filter(([, value]) => isBaseChannel(value)) - ) as Cc; + ) as Channels; - const newChannels = {} as Cc; + const newChannels = {} as Channels; for (const k in filteredChannels) { if (Object.prototype.hasOwnProperty.call(filteredChannels, k)) { const channelValue = checkpoint.channel_values[k]; diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 55ac5ca7..b7f80120 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -119,7 +119,7 @@ export type NodeSpec = { export type AddNodeOptions = { metadata?: Record }; export class Graph< - N extends string = typeof END, + NodeNames extends string = typeof END, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunInput = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -129,18 +129,18 @@ export class Graph< RunOutput > > { - nodes: Record; + nodes: Record; - edges: Set<[N | typeof START, N | typeof END]>; + edges: Set<[NodeNames | typeof START, NodeNames | typeof END]>; - branches: Record>>; + branches: Record>>; entryPoint?: string; compiled = false; constructor() { - this.nodes = {} as Record; + this.nodes = {} as Record; this.edges = new Set(); this.branches = {}; } @@ -159,7 +159,7 @@ export class Graph< key: K, action: RunnableLike, options?: AddNodeOptions - ): Graph { + ): Graph { if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) { throw new Error( `"${CHECKPOINT_NAMESPACE_SEPARATOR}" is a reserved character and is not allowed in node names.` @@ -176,7 +176,7 @@ export class Graph< throw new Error(`Node \`${key}\` is reserved.`); } - this.nodes[key as unknown as N] = { + this.nodes[key as unknown as NodeNames] = { runnable: _coerceToRunnable( // Account for arbitrary state due to Send API action as RunnableLike @@ -184,10 +184,13 @@ export class Graph< metadata: options?.metadata, } as NodeSpecType; - return this as Graph; + return this as Graph; } - addEdge(startKey: N | typeof START, endKey: N | typeof END): this { + addEdge( + startKey: NodeNames | typeof START, + endKey: NodeNames | typeof END + ): this { this.warnIfCompiled( `Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph.` ); @@ -212,20 +215,20 @@ export class Graph< return this; } - addConditionalEdges(source: BranchOptions): this; + addConditionalEdges(source: BranchOptions): this; addConditionalEdges( - source: N, - path: Branch["condition"], - pathMap?: BranchOptions["pathMap"] + source: NodeNames, + path: Branch["condition"], + pathMap?: BranchOptions["pathMap"] ): this; addConditionalEdges( - source: N | BranchOptions, - path?: Branch["condition"], - pathMap?: BranchOptions["pathMap"] + source: NodeNames | BranchOptions, + path?: Branch["condition"], + pathMap?: BranchOptions["pathMap"] ): this { - const options: BranchOptions = + const options: BranchOptions = typeof source === "object" ? source : { source, path: path!, pathMap }; this.warnIfCompiled( "Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph." @@ -249,7 +252,7 @@ export class Graph< /** * @deprecated use `addEdge(START, key)` instead */ - setEntryPoint(key: N): this { + setEntryPoint(key: NodeNames): this { this.warnIfCompiled( "Setting the entry point of a graph that has already been compiled. This will not be reflected in the compiled graph." ); @@ -260,7 +263,7 @@ export class Graph< /** * @deprecated use `addEdge(key, END)` instead */ - setFinishPoint(key: N): this { + setFinishPoint(key: NodeNames): this { this.warnIfCompiled( "Setting a finish point of a graph that has already been compiled. This will not be reflected in the compiled graph." ); @@ -274,9 +277,9 @@ export class Graph< interruptAfter, }: { checkpointer?: BaseCheckpointSaver; - interruptBefore?: N[] | All; - interruptAfter?: N[] | All; - } = {}): CompiledGraph { + interruptBefore?: NodeNames[] | All; + interruptAfter?: NodeNames[] | All; + } = {}): CompiledGraph { // validate the graph this.validate([ ...(Array.isArray(interruptBefore) ? interruptBefore : []), @@ -284,20 +287,23 @@ export class Graph< ]); // create empty compiled graph - const compiled = new CompiledGraph({ + const compiled = new CompiledGraph({ builder: this, checkpointer, interruptAfter, interruptBefore, autoValidate: false, - nodes: {} as Record>, + nodes: {} as Record< + NodeNames | typeof START, + PregelNode + >, channels: { [START]: new EphemeralValue(), [END]: new EphemeralValue(), - } as Record, + } as Record, inputChannels: START, outputChannels: END, - streamChannels: [] as N[], + streamChannels: [] as NodeNames[], streamMode: "values", }); @@ -305,14 +311,14 @@ export class Graph< for (const [key, node] of Object.entries>( this.nodes )) { - compiled.attachNode(key as N, node); + compiled.attachNode(key as NodeNames, node); } for (const [start, end] of this.edges) { compiled.attachEdge(start, end); } for (const [start, branches] of Object.entries(this.branches)) { for (const [name, branch] of Object.entries(branches)) { - compiled.attachBranch(start as N, name, branch); + compiled.attachBranch(start as NodeNames, name, branch); } } @@ -375,35 +381,35 @@ export class Graph< } export class CompiledGraph< - N extends string, + NodeNames extends string, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunInput = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunOutput = any > extends Pregel< - Record>, - Record + Record>, + Record > { - declare NodeType: N; + declare NodeType: NodeNames; declare RunInput: RunInput; declare RunOutput: RunOutput; - builder: Graph; + builder: Graph; constructor({ builder, ...rest - }: { builder: Graph } & PregelParams< - Record>, - Record + }: { builder: Graph } & PregelParams< + Record>, + Record >) { super(rest); this.builder = builder; } - attachNode(key: N, node: NodeSpec): void { + attachNode(key: NodeNames, node: NodeSpec): void { this.channels[key] = new EphemeralValue(); this.nodes[key] = new PregelNode({ channels: [], @@ -414,10 +420,13 @@ export class CompiledGraph< .pipe( new ChannelWrite([{ channel: key, value: PASSTHROUGH }], [TAG_HIDDEN]) ); - (this.streamChannels as N[]).push(key); + (this.streamChannels as NodeNames[]).push(key); } - attachEdge(start: N | typeof START, end: N | typeof END): void { + attachEdge( + start: NodeNames | typeof START, + end: NodeNames | typeof END + ): void { if (end === END) { if (start === START) { throw new Error("Cannot have an edge from START to END"); @@ -432,9 +441,9 @@ export class CompiledGraph< } attachBranch( - start: N | typeof START, + start: NodeNames | typeof START, name: string, - branch: Branch + branch: Branch ) { // add hidden start node if (start === START && this.nodes[START]) { @@ -460,7 +469,7 @@ export class CompiledGraph< // attach branch readers const ends = branch.ends ? Object.values(branch.ends) - : (Object.keys(this.nodes) as N[]); + : (Object.keys(this.nodes) as NodeNames[]); for (const end of ends) { if (end !== END) { const channelName = `branch:${start}:${name}:${end}`; diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index c8482190..c1d40072 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -156,26 +156,37 @@ export type StateGraphArgsWithInputOutputSchemas< */ export class StateGraph< SD extends StateDefinition | unknown, - S = SD extends StateDefinition ? StateType : SD, - U = SD extends StateDefinition ? UpdateType : Partial, - N extends string = typeof START, - I extends StateDefinition = SD extends StateDefinition ? SD : StateDefinition, - O extends StateDefinition = SD extends StateDefinition ? SD : StateDefinition, - C extends StateDefinition = StateDefinition -> extends Graph> { + InputStateSchema = SD extends StateDefinition ? StateType : SD, + OutputStateSchema = SD extends StateDefinition + ? UpdateType + : Partial, + NodeNames extends string = typeof START, + InputDefinition extends StateDefinition = SD extends StateDefinition + ? SD + : StateDefinition, + OutputDefinition extends StateDefinition = SD extends StateDefinition + ? SD + : StateDefinition, + ConfigSchema extends StateDefinition = StateDefinition +> extends Graph< + NodeNames, + InputStateSchema, + OutputStateSchema, + StateGraphNodeSpec +> { channels: Record = {}; // TODO: this doesn't dedupe edges as in py, so worth fixing at some point - waitingEdges: Set<[N[], N]> = new Set(); + waitingEdges: Set<[NodeNames[], NodeNames]> = new Set(); /** @internal */ _schemaDefinition: StateDefinition; /** @internal */ - _inputDefinition: I; + _inputDefinition: InputDefinition; /** @internal */ - _outputDefinition: O; + _outputDefinition: OutputDefinition; /** * Map schemas to managed values @@ -184,35 +195,35 @@ export class StateGraph< _schemaDefinitions = new Map(); /** @internal Used only for typing. */ - _configSchema: C | undefined; + _configSchema: ConfigSchema | undefined; constructor( fields: SD extends StateDefinition ? | SD | AnnotationRoot - | StateGraphArgs - | StateGraphArgsWithStateSchema - | StateGraphArgsWithInputOutputSchemas - : StateGraphArgs, - configSchema?: AnnotationRoot + | StateGraphArgs + | StateGraphArgsWithStateSchema + | StateGraphArgsWithInputOutputSchemas + : StateGraphArgs, + configSchema?: AnnotationRoot ) { super(); if ( isStateGraphArgsWithInputOutputSchemas< SD extends StateDefinition ? SD : never, - O + OutputDefinition >(fields) ) { this._schemaDefinition = fields.input.spec; - this._inputDefinition = fields.input.spec as unknown as I; + this._inputDefinition = fields.input.spec as unknown as InputDefinition; this._outputDefinition = fields.output.spec; } else if (isStateGraphArgsWithStateSchema(fields)) { this._schemaDefinition = fields.stateSchema.spec; this._inputDefinition = (fields.input?.spec ?? - this._schemaDefinition) as I; + this._schemaDefinition) as InputDefinition; this._outputDefinition = (fields.output?.spec ?? - this._schemaDefinition) as O; + this._schemaDefinition) as OutputDefinition; } else if (isStateDefinition(fields) || isAnnotationRoot(fields)) { const spec = isAnnotationRoot(fields) ? fields.spec : fields; this._schemaDefinition = spec; @@ -269,15 +280,25 @@ export class StateGraph< } } - addNode( + addNode( key: K, action: RunnableLike< NodeInput, // eslint-disable-next-line @typescript-eslint/no-explicit-any - U extends object ? U & Record : U + OutputStateSchema extends object + ? OutputStateSchema & Record + : OutputStateSchema >, options?: StateGraphAddNodeOptions - ): StateGraph { + ): StateGraph< + SD, + InputStateSchema, + OutputStateSchema, + NodeNames | K, + InputDefinition, + OutputDefinition, + ConfigSchema + > { if (key in this.channels) { throw new Error( `${key} is already being used as a state attribute (a.k.a. a channel), cannot also be used as a node name.` @@ -303,19 +324,33 @@ export class StateGraph< if (options?.input !== undefined) { this._addSchema(options.input.spec); } - const nodeSpec: StateGraphNodeSpec = { - runnable: _coerceToRunnable(action) as unknown as Runnable, + const nodeSpec: StateGraphNodeSpec = { + runnable: _coerceToRunnable(action) as unknown as Runnable< + InputStateSchema, + OutputStateSchema + >, retryPolicy: options?.retryPolicy, metadata: options?.metadata, input: options?.input?.spec ?? this._schemaDefinition, }; - this.nodes[key as unknown as N] = nodeSpec; - - return this as StateGraph; + this.nodes[key as unknown as NodeNames] = nodeSpec; + + return this as StateGraph< + SD, + InputStateSchema, + OutputStateSchema, + NodeNames | K, + InputDefinition, + OutputDefinition, + ConfigSchema + >; } - addEdge(startKey: typeof START | N | N[], endKey: N | typeof END): this { + addEdge( + startKey: typeof START | NodeNames | NodeNames[], + endKey: NodeNames | typeof END + ): this { if (typeof startKey === "string") { return super.addEdge(startKey, endKey); } @@ -355,9 +390,16 @@ export class StateGraph< }: { checkpointer?: BaseCheckpointSaver; store?: BaseStore; - interruptBefore?: N[] | All; - interruptAfter?: N[] | All; - } = {}): CompiledStateGraph { + interruptBefore?: NodeNames[] | All; + interruptAfter?: NodeNames[] | All; + } = {}): CompiledStateGraph< + InputStateSchema, + OutputStateSchema, + NodeNames, + InputDefinition, + OutputDefinition, + ConfigSchema + > { // validate the graph this.validate([ ...(Array.isArray(interruptBefore) ? interruptBefore : []), @@ -376,17 +418,27 @@ export class StateGraph< streamKeys.length === 1 && streamKeys[0] === ROOT ? ROOT : streamKeys; // create empty compiled graph - const compiled = new CompiledStateGraph({ + const compiled = new CompiledStateGraph< + InputStateSchema, + OutputStateSchema, + NodeNames, + InputDefinition, + OutputDefinition, + ConfigSchema + >({ builder: this, checkpointer, interruptAfter, interruptBefore, autoValidate: false, - nodes: {} as Record>, + nodes: {} as Record< + NodeNames | typeof START, + PregelNode + >, channels: { ...this.channels, [START]: new EphemeralValue(), - } as Record, + } as Record, inputChannels: START, outputChannels, streamChannels, @@ -396,10 +448,10 @@ export class StateGraph< // attach nodes, edges and branches compiled.attachNode(START); - for (const [key, node] of Object.entries>( - this.nodes - )) { - compiled.attachNode(key as N, node); + for (const [key, node] of Object.entries< + StateGraphNodeSpec + >(this.nodes)) { + compiled.attachNode(key as NodeNames, node); } for (const [start, end] of this.edges) { compiled.attachEdge(start, end); @@ -409,7 +461,7 @@ export class StateGraph< } for (const [start, branches] of Object.entries(this.branches)) { for (const [name, branch] of Object.entries(branches)) { - compiled.attachBranch(start as N, name, branch); + compiled.attachBranch(start as NodeNames, name, branch); } } diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 529fe71c..9b51eaa6 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -38,18 +38,16 @@ import { TAG_HIDDEN, TASKS, } from "../constants.js"; -import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; +import { + ChannelsType, + NodesType, + PregelExecutableTask, + PregelTaskDescription, +} from "./types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; import { _getIdMetadata, getNullChannelVersion } from "./utils.js"; import { ManagedValueMapping } from "../managed/base.js"; -/** - * Construct a type with a set of properties K of type T - */ -export type StrRecord = { - [P in K]: T; -}; - export type WritesProtocol = { name: string; writes: PendingWrite[]; @@ -91,17 +89,17 @@ export function shouldInterrupt( return anyChannelUpdated && anyTriggeredNodeInInterruptNodes; } -export function _localRead>( +export function _localRead>( step: number, checkpoint: ReadonlyCheckpoint, - channels: Cc, + channels: LocalChannels, managed: ManagedValueMapping, - task: WritesProtocol, - select: Array | keyof Cc, + task: WritesProtocol, + select: Array | keyof LocalChannels, fresh: boolean = false ): Record | unknown { - let managedKeys: Array = []; - let updated = new Set(); + let managedKeys: Array = []; + let updated = new Set(); if (!Array.isArray(select)) { for (const [c] of task.writes) { @@ -113,9 +111,11 @@ export function _localRead>( updated = updated || new Set(); } else { managedKeys = select.filter((k) => managed.get(k as string)) as Array< - keyof Cc + keyof LocalChannels + >; + select = select.filter((k) => !managed.get(k as string)) as Array< + keyof LocalChannels >; - select = select.filter((k) => !managed.get(k as string)) as Array; updated = new Set( select.filter((c) => task.writes.some(([key, _]) => key === c)) ); @@ -125,11 +125,20 @@ export function _localRead>( if (fresh && updated.size > 0) { const localChannels = Object.fromEntries( - Object.entries(channels).filter(([k, _]) => updated.has(k as keyof Cc)) - ) as Partial; - - const newCheckpoint = createCheckpoint(checkpoint, localChannels as Cc, -1); - const newChannels = emptyChannels(localChannels as Cc, newCheckpoint); + Object.entries(channels).filter(([k, _]) => + updated.has(k as keyof LocalChannels) + ) + ) as Partial; + + const newCheckpoint = createCheckpoint( + checkpoint, + localChannels as LocalChannels, + -1 + ); + const newChannels = emptyChannels( + localChannels as LocalChannels, + newCheckpoint + ); _applyWrites(copyCheckpoint(newCheckpoint), newChannels, [task]); values = readChannels({ ...channels, ...newChannels }, select); @@ -183,17 +192,17 @@ export function _localWrite( commit(writes); } -export function _applyWrites>( +export function _applyWrites>( checkpoint: Checkpoint, - channels: Cc, - tasks: WritesProtocol[], + channels: LocalChannels, + tasks: WritesProtocol[], // eslint-disable-next-line @typescript-eslint/no-explicit-any getNextVersion?: (version: any, channel: BaseChannel) => any ): Record { // Filter out non instances of BaseChannel const onlyChannels = Object.fromEntries( Object.entries(channels).filter(([_, value]) => isBaseChannel(value)) - ) as Cc; + ) as LocalChannels; // Update seen versions for (const task of tasks) { if (checkpoint.versions_seen[task.name] === undefined) { @@ -240,10 +249,13 @@ export function _applyWrites>( // Group writes by channel const pendingWriteValuesByChannel = {} as Record< - keyof Cc, + keyof LocalChannels, + PendingWriteValue[] + >; + const pendingWritesByManaged = {} as Record< + keyof LocalChannels, PendingWriteValue[] >; - const pendingWritesByManaged = {} as Record; for (const task of tasks) { for (const [chan, val] of task.writes) { if (chan === TASKS) { @@ -329,12 +341,12 @@ export type NextTaskExtraFields = { }; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + Channels extends ChannelsType >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: Channels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: false, @@ -342,32 +354,35 @@ export function _prepareNextTasks< ): PregelTaskDescription[]; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + Channels extends ChannelsType >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: Channels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: true, extra: NextTaskExtraFields -): PregelExecutableTask[]; +): PregelExecutableTask[]; export function _prepareNextTasks< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + LocalChannels extends Record >( checkpoint: ReadonlyCheckpoint, - processes: Nn, - channels: Cc, + processes: Nodes, + channels: LocalChannels, managed: ManagedValueMapping, config: RunnableConfig, forExecution: boolean, extra: NextTaskExtraFields -): PregelTaskDescription[] | PregelExecutableTask[] { +): + | PregelTaskDescription[] + | PregelExecutableTask[] { const parentNamespace = config.configurable?.checkpoint_ns ?? ""; - const tasks: Array> = []; + const tasks: Array> = + []; const taskDescriptions: Array = []; const { step, isResuming = false, checkpointer, manager } = extra; @@ -404,7 +419,7 @@ export function _prepareNextTasks< const proc = processes[packet.node]; const node = proc.getNode(); if (node !== undefined) { - const writes: [keyof Cc, unknown][] = []; + const writes: [keyof LocalChannels, unknown][] = []; managed.replaceRuntimePlaceholders(step, packet.args); tasks.push({ name: packet.node, @@ -424,14 +439,15 @@ export function _prepareNextTasks< [CONFIG_KEY_SEND]: (writes_: [string, any][]) => _localWrite( step, - (items: [keyof Cc, unknown][]) => writes.push(...items), + (items: [keyof LocalChannels, unknown][]) => + writes.push(...items), processes, channels, managed, writes_ ), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof LocalChannels, fresh_: boolean = false ) => _localRead( @@ -510,7 +526,7 @@ export function _prepareNextTasks< if (forExecution) { const node = proc.getNode(); if (node !== undefined) { - const writes: [keyof Cc, unknown][] = []; + const writes: [keyof LocalChannels, unknown][] = []; tasks.push({ name, input: val, @@ -527,14 +543,15 @@ export function _prepareNextTasks< [CONFIG_KEY_SEND]: (writes_: [string, any][]) => _localWrite( step, - (items: [keyof Cc, unknown][]) => writes.push(...items), + (items: [keyof LocalChannels, unknown][]) => + writes.push(...items), processes, channels, managed, writes_ ), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof LocalChannels, fresh_: boolean = false ) => _localRead( @@ -573,7 +590,7 @@ function _procInput( step: number, proc: PregelNode, managed: ManagedValueMapping, - channels: StrRecord, + channels: Record, forExecution: boolean ) { // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index b4d4afb3..327ce055 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -46,11 +46,17 @@ import { INTERRUPT, } from "../constants.js"; import { + NodesType, + ChannelsType, PregelExecutableTask, PregelInterface, PregelParams, + SingleStreamMode, StateSnapshot, StreamMode, + StreamOutput, + ChannelsStateType, + AllStreamInvokeOutputTypes, } from "./types.js"; import { GraphRecursionError, @@ -58,12 +64,7 @@ import { InvalidUpdateError, isGraphInterrupt, } from "../errors.js"; -import { - _prepareNextTasks, - _localRead, - _applyWrites, - StrRecord, -} from "./algo.js"; +import { _prepareNextTasks, _localRead, _applyWrites } from "./algo.js"; import { _coerceToDict, getNewChannelVersions, RetryPolicy } from "./utils.js"; import { PregelLoop } from "./loop.js"; import { executeTasksWithRetry } from "./retry.js"; @@ -176,34 +177,30 @@ export class Channel { * Config for executing the graph. */ export interface PregelOptions< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + Channels extends ChannelsType, + Mode extends StreamMode = StreamMode > extends RunnableConfig { /** The stream mode for the graph run. Default is ["values"]. */ - streamMode?: StreamMode | StreamMode[]; - inputKeys?: keyof Cc | Array; + streamMode?: Mode; + inputKeys?: keyof Channels | Array; /** The output keys to retrieve from the graph run. */ - outputKeys?: keyof Cc | Array; + outputKeys?: keyof Channels | Array; /** The nodes to interrupt the graph run before. */ - interruptBefore?: All | Array; + interruptBefore?: All | Array; /** The nodes to interrupt the graph run after. */ - interruptAfter?: All | Array; + interruptAfter?: All | Array; /** Enable debug mode for the graph run. */ debug?: boolean; } -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type PregelInputType = any; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type PregelOutputType = any; - -export class Pregel< - Nn extends StrRecord, - Cc extends StrRecord +export class Pregel + extends Runnable< + any, // Partial>, // input type + AllStreamInvokeOutputTypes, // output type + PregelOptions > - extends Runnable> - implements PregelInterface + implements PregelInterface { static lc_name() { return "LangGraph"; @@ -212,23 +209,23 @@ export class Pregel< // Because Pregel extends `Runnable`. lc_namespace = ["langgraph", "pregel"]; - nodes: Nn; + nodes: Nodes; - channels: Cc; + channels: Channels; - inputChannels: keyof Cc | Array; + inputChannels: keyof Channels | Array; - outputChannels: keyof Cc | Array; + outputChannels: keyof Channels | Array; autoValidate: boolean = true; - streamMode: StreamMode[] = ["values"]; + streamMode: SingleStreamMode = "values"; - streamChannels?: keyof Cc | Array; + streamChannels?: keyof Channels | Array; - interruptAfter?: Array | All; + interruptAfter?: Array | All; - interruptBefore?: Array | All; + interruptBefore?: Array | All; stepTimeout?: number; @@ -240,18 +237,13 @@ export class Pregel< store?: BaseStore; - constructor(fields: PregelParams) { + constructor(fields: PregelParams) { super(fields); - let { streamMode } = fields; - if (streamMode != null && !Array.isArray(streamMode)) { - streamMode = [streamMode]; - } - this.nodes = fields.nodes; this.channels = fields.channels; this.autoValidate = fields.autoValidate ?? this.autoValidate; - this.streamMode = streamMode ?? this.streamMode; + this.streamMode = fields.streamMode ?? this.streamMode; this.inputChannels = fields.inputChannels; this.outputChannels = fields.outputChannels; this.streamChannels = fields.streamChannels ?? this.streamChannels; @@ -269,7 +261,7 @@ export class Pregel< } validate(): this { - validateGraph({ + validateGraph({ nodes: this.nodes, channels: this.channels, outputChannels: this.outputChannels, @@ -282,7 +274,7 @@ export class Pregel< return this; } - get streamChannelsList(): Array { + get streamChannelsList(): Array { if (Array.isArray(this.streamChannels)) { return this.streamChannels; } else if (this.streamChannels) { @@ -292,7 +284,7 @@ export class Pregel< } } - get streamChannelsAsIs(): keyof Cc | Array { + get streamChannelsAsIs(): keyof Channels | Array { if (this.streamChannels) { return this.streamChannels; } else { @@ -391,7 +383,7 @@ export class Pregel< async updateState( config: RunnableConfig, values: Record | unknown, - asNode?: keyof Nn + asNode?: keyof Nodes ): Promise { if (!this.checkpointer) { throw new GraphValueError("No checkpointer set"); @@ -493,7 +485,7 @@ export class Pregel< `No writers found for node "${asNode.toString()}"` ); } - const task: PregelExecutableTask = { + const task: PregelExecutableTask = { name: asNode, input: values, proc: @@ -510,10 +502,10 @@ export class Pregel< patchConfig(config, { runName: config.runName ?? `${this.getName()}UpdateState`, configurable: { - [CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) => + [CONFIG_KEY_SEND]: (items: [keyof Channels, unknown][]) => task.writes.push(...items), [CONFIG_KEY_READ]: ( - select_: Array | keyof Cc, + select_: Array | keyof Channels, fresh_: boolean = false ) => _localRead( @@ -563,9 +555,9 @@ export class Pregel< ); } - _defaults(config: PregelOptions): [ + _defaults(config: PregelOptions): [ boolean, // debug - StreamMode[], // stream mode + SingleStreamMode[], // stream mode string | string[], // input keys string | string[], // output keys RunnableConfig, // config without pregel keys @@ -603,11 +595,11 @@ export class Pregel< const defaultInterruptAfter = interruptAfter ?? this.interruptAfter ?? []; - let defaultStreamMode: StreamMode[]; + let defaultStreamMode: SingleStreamMode[]; if (streamMode !== undefined) { defaultStreamMode = Array.isArray(streamMode) ? streamMode : [streamMode]; } else { - defaultStreamMode = this.streamMode; + defaultStreamMode = [this.streamMode]; } let defaultCheckpointer: BaseCheckpointSaver | undefined; @@ -654,11 +646,13 @@ export class Pregel< * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async stream( - input: PregelInputType, - options?: Partial> - ): Promise> { - return super.stream(input, options); + override async stream( + input: any, //Partial>, + options?: Partial> + ): Promise>> { + return super.stream(input, options) as Promise< + IterableReadableStream> + >; } protected async prepareSpecs( @@ -723,10 +717,10 @@ export class Pregel< }; } - override async *_streamIterator( - input: PregelInputType, - options?: Partial> - ): AsyncGenerator { + override async *_streamIterator( + input: Partial>, + options?: Partial> + ): AsyncGenerator> { const inputConfig = ensureConfig(options); if ( inputConfig.recursionLimit === undefined || @@ -802,9 +796,9 @@ export class Pregel< } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput; } } } @@ -845,9 +839,13 @@ export class Pregel< } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput< + Mode, + Nodes, + Channels + >; } } } @@ -871,9 +869,9 @@ export class Pregel< } if (streamMode.includes(nextItem[0])) { if (streamMode.length === 1) { - yield nextItem[1]; + yield nextItem[1] as StreamOutput; } else { - yield nextItem; + yield nextItem as unknown as StreamOutput; } } } @@ -918,24 +916,33 @@ export class Pregel< * @param options.interruptAfter Nodes to interrupt after. * @param options.debug Whether to print debug information during execution. */ - override async invoke( - input: PregelInputType, - options?: Partial> - ): Promise { - const streamMode = options?.streamMode ?? "values"; + override async invoke( + input: any, //ChannelsStateType, + options?: Partial> + ): Promise< + Mode extends "values" + ? StreamOutput + : Array> + > { + const streamMode = options?.streamMode ?? ("values" as const); const config = { ...ensureConfig(options), outputKeys: options?.outputKeys ?? this.outputChannels, streamMode, }; - const chunks = []; + const chunks: Array> = []; const stream = await this.stream(input, config); for await (const chunk of stream) { - chunks.push(chunk); + chunks.push(chunk as StreamOutput); } + if (streamMode === "values") { - return chunks[chunks.length - 1]; + return chunks[chunks.length - 1] as Mode extends "values" + ? StreamOutput + : Array>; } - return chunks; + return chunks as Mode extends "values" + ? StreamOutput + : Array>; } } diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index 54a5953c..e993b54d 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -17,7 +17,7 @@ import { createCheckpoint, emptyChannels, } from "../channels/base.js"; -import { PregelExecutableTask, StreamMode } from "./types.js"; +import { PregelExecutableTask, SingleStreamMode } from "./types.js"; import { CONFIG_KEY_READ, CONFIG_KEY_RESUMING, @@ -137,7 +137,7 @@ export class PregelLoop { tasks: PregelExecutableTask[] = []; // eslint-disable-next-line @typescript-eslint/no-explicit-any - stream: Deque<[StreamMode, any]> = new Deque(); + stream: Deque<[SingleStreamMode, any]> = new Deque(); checkpointerPromises: Promise[] = []; diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 7ce7d675..6bf6f6fb 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -10,24 +10,172 @@ import type { PregelNode } from "./read.js"; import { RetryPolicy } from "./utils.js"; import { Interrupt } from "../constants.js"; import { BaseStore } from "../store/base.js"; -import { type ManagedValueSpec } from "../managed/base.js"; +import { + ConfiguredManagedValue, + ManagedValue, + type ManagedValueSpec, +} from "../managed/base.js"; +export type ChannelsType = Record; -export type StreamMode = "values" | "updates" | "debug"; +export type NodesType = Record; -/** - * Construct a type with a set of properties K of type T - */ -type StrRecord = { - [P in K]: T; +type ExtractChannelValueType = Channel extends BaseChannel + ? Channel["ValueType"] + : Channel extends ManagedValueSpec + ? Channel extends ConfiguredManagedValue + ? V + : Channel extends ManagedValue + ? V + : never + : never; + +export type ChannelsStateType = { + [Key in keyof Channels]: ExtractChannelValueType; +}; + +type ExtractChannelUpdateType = Channel extends BaseChannel + ? Channel["UpdateType"] + : Channel extends ManagedValueSpec + ? Channel extends ConfiguredManagedValue + ? V + : Channel extends ManagedValue + ? V + : never + : never; + +export type ChannelsUpdateType = { + [Key in keyof Channels]?: ExtractChannelUpdateType; }; +export type DebugOutput = { + type: string; + timestamp: string; + step: number; + payload: { + id: string; + name: string; + result: PendingWrite[]; + config: RunnableConfig; + metadata?: CheckpointMetadata; + }; +}; + +export type SingleStreamMode = "values" | "updates" | "debug"; + +export type StreamMode = + | SingleStreamMode + | [SingleStreamMode] + | [SingleStreamMode, SingleStreamMode] + | [SingleStreamMode, SingleStreamMode, SingleStreamMode]; + +export type SingleStreamModeOutput< + Mode extends SingleStreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends "values" + ? ChannelsStateType + : Mode extends "updates" + ? { [K in keyof Nodes]: ChannelsStateType } + : Mode extends "debug" + ? DebugOutput + : never; + +export type StreamModeOutputTuple< + Mode extends SingleStreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = [Mode, SingleStreamModeOutput]; + +export type StreamOutput< + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends SingleStreamMode + ? SingleStreamModeOutput + : Mode[number] extends SingleStreamMode + ? StreamModeOutputTuple + : never; + +export type InvokeOutput< + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends "values" + ? StreamOutput + : Array>; +// Gross hack to avoid recursion limits & define the narrowest type that covers all possible stream output types +// This is what gets passed to the lower abstraction layers (such as `Runnable`) and the precise output type is +// narrowed at higher levels such as `StateGraph.stream()` to a specific one once the `StreamMode` is known +export type AllStreamOutputTypes< + Nodes extends NodesType, + Channels extends ChannelsType +> = + | StreamOutput<"values" | "updates", Nodes, Channels> + | StreamOutput<"values" | "debug", Nodes, Channels> + | StreamOutput<"updates" | "debug", Nodes, Channels> + | StreamOutput<"values" | "updates" | "debug", Nodes, Channels> + | StreamOutput<["values"], Nodes, Channels> + | StreamOutput<["updates"], Nodes, Channels> + | StreamOutput<["debug"], Nodes, Channels> + | StreamOutput<["values", "updates"], Nodes, Channels> + | StreamOutput<["values", "values"], Nodes, Channels> + | StreamOutput<["values", "debug"], Nodes, Channels> + | StreamOutput<["updates", "updates"], Nodes, Channels> + | StreamOutput<["updates", "values"], Nodes, Channels> + | StreamOutput<["updates", "debug"], Nodes, Channels> + | StreamOutput<["debug", "updates"], Nodes, Channels> + | StreamOutput<["debug", "values"], Nodes, Channels> + | StreamOutput<["debug", "debug"], Nodes, Channels> + | StreamOutput<["updates", "updates", "updates"], Nodes, Channels> + | StreamOutput<["updates", "updates", "values"], Nodes, Channels> + | StreamOutput<["updates", "updates", "debug"], Nodes, Channels> + | StreamOutput<["updates", "values", "updates"], Nodes, Channels> + | StreamOutput<["updates", "values", "values"], Nodes, Channels> + | StreamOutput<["updates", "values", "debug"], Nodes, Channels> + | StreamOutput<["updates", "debug", "updates"], Nodes, Channels> + | StreamOutput<["updates", "debug", "values"], Nodes, Channels> + | StreamOutput<["updates", "debug", "debug"], Nodes, Channels> + | StreamOutput<["values", "updates", "updates"], Nodes, Channels> + | StreamOutput<["values", "updates", "values"], Nodes, Channels> + | StreamOutput<["values", "updates", "debug"], Nodes, Channels> + | StreamOutput<["values", "values", "updates"], Nodes, Channels> + | StreamOutput<["values", "values", "values"], Nodes, Channels> + | StreamOutput<["values", "values", "debug"], Nodes, Channels> + | StreamOutput<["values", "debug", "updates"], Nodes, Channels> + | StreamOutput<["values", "debug", "values"], Nodes, Channels> + | StreamOutput<["values", "debug", "debug"], Nodes, Channels> + | StreamOutput<["debug", "updates", "updates"], Nodes, Channels> + | StreamOutput<["debug", "updates", "values"], Nodes, Channels> + | StreamOutput<["debug", "updates", "debug"], Nodes, Channels> + | StreamOutput<["debug", "values", "updates"], Nodes, Channels> + | StreamOutput<["debug", "values", "values"], Nodes, Channels> + | StreamOutput<["debug", "values", "debug"], Nodes, Channels> + | StreamOutput<["debug", "debug", "updates"], Nodes, Channels> + | StreamOutput<["debug", "debug", "values"], Nodes, Channels> + | StreamOutput<["debug", "debug", "debug"], Nodes, Channels>; + +export type AllStreamInvokeOutputTypes< + Nodes extends NodesType, + Channels extends ChannelsType +> = + | AllStreamOutputTypes + | Array>; + +export type InvokeOutputType< + Mode extends StreamMode, + Nodes extends NodesType, + Channels extends ChannelsType +> = Mode extends "values" + ? StreamOutput + : Array>; + export interface PregelInterface< - Nn extends StrRecord, - Cc extends StrRecord + Nodes extends NodesType, + Channels extends ChannelsType > { - nodes: Nn; + nodes: Nodes; - channels: Cc; + channels: Channels; /** * @default true @@ -37,25 +185,25 @@ export interface PregelInterface< /** * @default "values" */ - streamMode?: StreamMode | StreamMode[]; + streamMode?: SingleStreamMode; - inputChannels: keyof Cc | Array; + inputChannels: keyof Channels | Array; - outputChannels: keyof Cc | Array; + outputChannels: keyof Channels | Array; /** * @default [] */ - interruptAfter?: Array | All; + interruptAfter?: Array | All; /** * @default [] */ - interruptBefore?: Array | All; + interruptBefore?: Array | All; - streamChannels?: keyof Cc | Array; + streamChannels?: keyof Channels | Array; - get streamChannelsAsIs(): keyof Cc | Array; + get streamChannelsAsIs(): keyof Channels | Array; /** * @default undefined @@ -78,9 +226,9 @@ export interface PregelInterface< } export type PregelParams< - Nn extends StrRecord, - Cc extends StrRecord -> = Omit, "streamChannelsAsIs">; + Nodes extends NodesType, + Channels extends ChannelsType +> = Omit, "streamChannelsAsIs">; export interface PregelTaskDescription { readonly id: string; diff --git a/libs/langgraph/src/pregel/validate.ts b/libs/langgraph/src/pregel/validate.ts index 7f86e161..b1637a7b 100644 --- a/libs/langgraph/src/pregel/validate.ts +++ b/libs/langgraph/src/pregel/validate.ts @@ -1,8 +1,7 @@ import { All } from "@langchain/langgraph-checkpoint"; -import { BaseChannel } from "../channels/index.js"; import { INTERRUPT } from "../constants.js"; import { PregelNode } from "./read.js"; -import { type ManagedValueSpec } from "../managed/base.js"; +import { ChannelsType, NodesType } from "./types.js"; export class GraphValidationError extends Error { constructor(message?: string) { @@ -12,8 +11,8 @@ export class GraphValidationError extends Error { } export function validateGraph< - Nn extends Record, - Cc extends Record + Nodes extends NodesType, + Channels extends ChannelsType >({ nodes, channels, @@ -23,20 +22,20 @@ export function validateGraph< interruptAfterNodes, interruptBeforeNodes, }: { - nodes: Nn; - channels: Cc; - inputChannels: keyof Cc | Array; - outputChannels: keyof Cc | Array; - streamChannels?: keyof Cc | Array; - interruptAfterNodes?: Array | All; - interruptBeforeNodes?: Array | All; + nodes: Nodes; + channels: Channels; + inputChannels: keyof Channels | Array; + outputChannels: keyof Channels | Array; + streamChannels?: keyof Channels | Array; + interruptAfterNodes?: Array | All; + interruptBeforeNodes?: Array | All; }): void { if (!channels) { throw new GraphValidationError("Channels not provided"); } - const subscribedChannels = new Set(); - const allOutputChannels = new Set(); + const subscribedChannels = new Set(); + const allOutputChannels = new Set(); for (const [name, node] of Object.entries(nodes)) { if (name === INTERRUPT) { @@ -114,9 +113,10 @@ export function validateGraph< } } -export function validateKeys< - Cc extends Record ->(keys: keyof Cc | Array, channels: Cc): void { +export function validateKeys( + keys: keyof Channels | Array, + channels: Channels +): void { if (Array.isArray(keys)) { for (const key of keys) { if (!(key in channels)) {