Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement types for Pregel.stream using recursive type #512

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions libs/langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ export abstract class BaseChannel<
}
}

export function emptyChannels<Cc extends Record<string, BaseChannel>>(
channels: Cc,
export function emptyChannels<Channels extends Record<string, BaseChannel>>(
channels: Channels,
checkpoint: ReadonlyCheckpoint
): Cc {
): Channels {
const filteredChannels = Object.fromEntries(
Object.entries(channels).filter(([, value]) => isBaseChannel(value))
) as Cc;
) as Channels;

const newChannels = {} as Cc;
const newChannels = {} as Channels;
for (const k in filteredChannels) {
if (Object.prototype.hasOwnProperty.call(filteredChannels, k)) {
const channelValue = checkpoint.channel_values[k];
Expand Down
115 changes: 66 additions & 49 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,16 @@ import {
TAG_HIDDEN,
TASKS,
} from "../constants.js";
import { PregelExecutableTask, PregelTaskDescription } from "./types.js";
import {
ChannelsType,
NodesType,
PregelExecutableTask,
PregelTaskDescription,
} from "./types.js";
import { EmptyChannelError, InvalidUpdateError } from "../errors.js";
import { _getIdMetadata, getNullChannelVersion } from "./utils.js";
import { ManagedValueMapping } from "../managed/base.js";

/**
* Construct a type with a set of properties K of type T
*/
export type StrRecord<K extends string, T> = {
[P in K]: T;
};

export type WritesProtocol<C = string> = {
name: string;
writes: PendingWrite<C>[];
Expand Down Expand Up @@ -91,17 +89,17 @@ export function shouldInterrupt<N extends PropertyKey, C extends PropertyKey>(
return anyChannelUpdated && anyTriggeredNodeInInterruptNodes;
}

export function _localRead<Cc extends Record<string, BaseChannel>>(
export function _localRead<LocalChannels extends Record<string, BaseChannel>>(
step: number,
checkpoint: ReadonlyCheckpoint,
channels: Cc,
channels: LocalChannels,
managed: ManagedValueMapping,
task: WritesProtocol<keyof Cc>,
select: Array<keyof Cc> | keyof Cc,
task: WritesProtocol<keyof LocalChannels>,
select: Array<keyof LocalChannels> | keyof LocalChannels,
fresh: boolean = false
): Record<string, unknown> | unknown {
let managedKeys: Array<keyof Cc> = [];
let updated = new Set<keyof Cc>();
let managedKeys: Array<keyof LocalChannels> = [];
let updated = new Set<keyof LocalChannels>();

if (!Array.isArray(select)) {
for (const [c] of task.writes) {
Expand All @@ -113,9 +111,11 @@ export function _localRead<Cc extends Record<string, BaseChannel>>(
updated = updated || new Set();
} else {
managedKeys = select.filter((k) => managed.get(k as string)) as Array<
keyof Cc
keyof LocalChannels
>;
select = select.filter((k) => !managed.get(k as string)) as Array<
keyof LocalChannels
>;
select = select.filter((k) => !managed.get(k as string)) as Array<keyof Cc>;
updated = new Set(
select.filter((c) => task.writes.some(([key, _]) => key === c))
);
Expand All @@ -125,11 +125,20 @@ export function _localRead<Cc extends Record<string, BaseChannel>>(

if (fresh && updated.size > 0) {
const localChannels = Object.fromEntries(
Object.entries(channels).filter(([k, _]) => updated.has(k as keyof Cc))
) as Partial<Cc>;

const newCheckpoint = createCheckpoint(checkpoint, localChannels as Cc, -1);
const newChannels = emptyChannels(localChannels as Cc, newCheckpoint);
Object.entries(channels).filter(([k, _]) =>
updated.has(k as keyof LocalChannels)
)
) as Partial<LocalChannels>;

const newCheckpoint = createCheckpoint(
checkpoint,
localChannels as LocalChannels,
-1
);
const newChannels = emptyChannels(
localChannels as LocalChannels,
newCheckpoint
);

_applyWrites(copyCheckpoint(newCheckpoint), newChannels, [task]);
values = readChannels({ ...channels, ...newChannels }, select);
Expand Down Expand Up @@ -183,17 +192,17 @@ export function _localWrite(
commit(writes);
}

export function _applyWrites<Cc extends Record<string, BaseChannel>>(
export function _applyWrites<LocalChannels extends Record<string, BaseChannel>>(
checkpoint: Checkpoint,
channels: Cc,
tasks: WritesProtocol<keyof Cc>[],
channels: LocalChannels,
tasks: WritesProtocol<keyof LocalChannels>[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getNextVersion?: (version: any, channel: BaseChannel) => any
): Record<string, PendingWriteValue[]> {
// Filter out non instances of BaseChannel
const onlyChannels = Object.fromEntries(
Object.entries(channels).filter(([_, value]) => isBaseChannel(value))
) as Cc;
) as LocalChannels;
// Update seen versions
for (const task of tasks) {
if (checkpoint.versions_seen[task.name] === undefined) {
Expand Down Expand Up @@ -240,10 +249,13 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(

// Group writes by channel
const pendingWriteValuesByChannel = {} as Record<
keyof Cc,
keyof LocalChannels,
PendingWriteValue[]
>;
const pendingWritesByManaged = {} as Record<
keyof LocalChannels,
PendingWriteValue[]
>;
const pendingWritesByManaged = {} as Record<keyof Cc, PendingWriteValue[]>;
for (const task of tasks) {
for (const [chan, val] of task.writes) {
if (chan === TASKS) {
Expand Down Expand Up @@ -329,45 +341,48 @@ export type NextTaskExtraFields = {
};

export function _prepareNextTasks<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
Nodes extends NodesType,
Channels extends ChannelsType
>(
checkpoint: ReadonlyCheckpoint,
processes: Nn,
channels: Cc,
processes: Nodes,
channels: Channels,
managed: ManagedValueMapping,
config: RunnableConfig,
forExecution: false,
extra: NextTaskExtraFields
): PregelTaskDescription[];

export function _prepareNextTasks<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
Nodes extends NodesType,
Channels extends ChannelsType
>(
checkpoint: ReadonlyCheckpoint,
processes: Nn,
channels: Cc,
processes: Nodes,
channels: Channels,
managed: ManagedValueMapping,
config: RunnableConfig,
forExecution: true,
extra: NextTaskExtraFields
): PregelExecutableTask<keyof Nn, keyof Cc>[];
): PregelExecutableTask<keyof Nodes, keyof Channels>[];

export function _prepareNextTasks<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
Nodes extends NodesType,
LocalChannels extends Record<string, BaseChannel>
>(
checkpoint: ReadonlyCheckpoint,
processes: Nn,
channels: Cc,
processes: Nodes,
channels: LocalChannels,
managed: ManagedValueMapping,
config: RunnableConfig,
forExecution: boolean,
extra: NextTaskExtraFields
): PregelTaskDescription[] | PregelExecutableTask<keyof Nn, keyof Cc>[] {
):
| PregelTaskDescription[]
| PregelExecutableTask<keyof Nodes, keyof LocalChannels>[] {
const parentNamespace = config.configurable?.checkpoint_ns ?? "";
const tasks: Array<PregelExecutableTask<keyof Nn, keyof Cc>> = [];
const tasks: Array<PregelExecutableTask<keyof Nodes, keyof LocalChannels>> =
[];
const taskDescriptions: Array<PregelTaskDescription> = [];
const { step, isResuming = false, checkpointer, manager } = extra;

Expand Down Expand Up @@ -404,7 +419,7 @@ export function _prepareNextTasks<
const proc = processes[packet.node];
const node = proc.getNode();
if (node !== undefined) {
const writes: [keyof Cc, unknown][] = [];
const writes: [keyof LocalChannels, unknown][] = [];
managed.replaceRuntimePlaceholders(step, packet.args);
tasks.push({
name: packet.node,
Expand All @@ -424,14 +439,15 @@ export function _prepareNextTasks<
[CONFIG_KEY_SEND]: (writes_: [string, any][]) =>
_localWrite(
step,
(items: [keyof Cc, unknown][]) => writes.push(...items),
(items: [keyof LocalChannels, unknown][]) =>
writes.push(...items),
processes,
channels,
managed,
writes_
),
[CONFIG_KEY_READ]: (
select_: Array<keyof Cc> | keyof Cc,
select_: Array<keyof LocalChannels> | keyof LocalChannels,
fresh_: boolean = false
) =>
_localRead(
Expand Down Expand Up @@ -510,7 +526,7 @@ export function _prepareNextTasks<
if (forExecution) {
const node = proc.getNode();
if (node !== undefined) {
const writes: [keyof Cc, unknown][] = [];
const writes: [keyof LocalChannels, unknown][] = [];
tasks.push({
name,
input: val,
Expand All @@ -527,14 +543,15 @@ export function _prepareNextTasks<
[CONFIG_KEY_SEND]: (writes_: [string, any][]) =>
_localWrite(
step,
(items: [keyof Cc, unknown][]) => writes.push(...items),
(items: [keyof LocalChannels, unknown][]) =>
writes.push(...items),
processes,
channels,
managed,
writes_
),
[CONFIG_KEY_READ]: (
select_: Array<keyof Cc> | keyof Cc,
select_: Array<keyof LocalChannels> | keyof LocalChannels,
fresh_: boolean = false
) =>
_localRead(
Expand Down Expand Up @@ -573,7 +590,7 @@ function _procInput(
step: number,
proc: PregelNode,
managed: ManagedValueMapping,
channels: StrRecord<string, BaseChannel>,
channels: Record<string, BaseChannel>,
forExecution: boolean
) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
Loading
Loading