From 86eab64832da536cd6e5db9be7b753f350945b58 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 6 Jan 2025 16:43:12 -0800 Subject: [PATCH 1/2] feat: Allow llm model to be configurable --- package.json | 5 +- scripts/generate-demo-post.ts | 11 ---- src/agents/generate-post/constants.ts | 2 + .../generate-post/generate-post-state.ts | 12 ++++- .../generate-post/nodes/condense-post.ts | 8 +-- .../nodes/generate-content-report.ts | 7 ++- .../nodes/geterate-post/index.ts | 5 +- .../generate-post/nodes/human-node/index.ts | 3 +- .../nodes/human-node/route-response.ts | 8 +-- .../generate-post/nodes/rewrite-post.ts | 7 ++- .../nodes/schedule-post/index.ts | 7 ++- .../nodes/update-scheduled-date.ts | 13 +++-- src/agents/ingest-data/ingest-data-graph.ts | 7 ++- src/agents/ingest-data/ingest-data-state.ts | 15 +++++- src/agents/reflection/index.ts | 5 +- src/agents/shared/nodes/verify-general.ts | 20 +++++--- src/agents/shared/nodes/verify-github.ts | 11 ++-- src/agents/shared/nodes/verify-reddit.ts | 13 ++--- src/agents/shared/nodes/verify-youtube.ts | 16 +++--- src/agents/utils.ts | 29 +++++++++++ .../nodes/get-reddit-content.ts | 13 ++--- .../verify-tweet/nodes/validate-tweet.ts | 16 +++--- src/tests/graph.int.test.ts | 6 ++- yarn.lock | 50 ++++++++++++++----- 24 files changed, 193 insertions(+), 96 deletions(-) diff --git a/package.json b/package.json index a37663b..bcbb631 100644 --- a/package.json +++ b/package.json @@ -36,9 +36,9 @@ "dependencies": { "@arcadeai/arcadejs": "^0.1.2", "@googleapis/youtube": "^20.0.0", - "@langchain/anthropic": "^0.3.9", + "@langchain/anthropic": "^0.3.11", "@langchain/community": "^0.3.22", - "@langchain/core": "^0.3.22", + "@langchain/core": "^0.3.27", "@langchain/google-vertexai-web": "^0.1.2", "@langchain/langgraph": "^0.2.31", "@langchain/langgraph-sdk": "^0.0.31", @@ -54,6 +54,7 @@ "express-session": "^1.18.1", "file-type": "^19.6.0", "google-auth-library": "^9.15.0", + "langchain": "^0.3.10", "langsmith": "0.2.15-rc.2", "moment": "^2.30.1", "passport": "^0.7.0", diff --git a/scripts/generate-demo-post.ts b/scripts/generate-demo-post.ts index a4aff8a..665acd3 100644 --- a/scripts/generate-demo-post.ts +++ b/scripts/generate-demo-post.ts @@ -1,9 +1,5 @@ import "dotenv/config"; import { Client } from "@langchain/langgraph-sdk"; -// import { -// LINKEDIN_USER_ID, -// TWITTER_USER_ID, -// } from "../src/agents/generate-post/constants.js"; /** * Generate a post based on the Open Canvas project. @@ -22,13 +18,6 @@ async function invokeGraph() { input: { links: [link], }, - config: { - configurable: { - // By default, the graph will read these values from the environment - // [TWITTER_USER_ID]: process.env.TWITTER_USER_ID, - // [LINKEDIN_USER_ID]: process.env.LINKEDIN_USER_ID, - }, - }, }); } diff --git a/src/agents/generate-post/constants.ts b/src/agents/generate-post/constants.ts index 4889316..018ffb0 100644 --- a/src/agents/generate-post/constants.ts +++ b/src/agents/generate-post/constants.ts @@ -95,3 +95,5 @@ export const TWITTER_USER_ID = "twitterUserId"; export const TWITTER_TOKEN = "twitterToken"; export const TWITTER_TOKEN_SECRET = "twitterTokenSecret"; export const INGEST_TWITTER_USERNAME = "ingestTwitterUsername"; + +export const LLM_MODEL_NAME = "llmModel"; diff --git a/src/agents/generate-post/generate-post-state.ts b/src/agents/generate-post/generate-post-state.ts index f15e1fc..c90b0d0 100644 --- a/src/agents/generate-post/generate-post-state.ts +++ b/src/agents/generate-post/generate-post-state.ts @@ -1,6 +1,6 @@ import { Annotation, END } from "@langchain/langgraph"; import { IngestDataAnnotation } from "../ingest-data/ingest-data-state.js"; -import { POST_TO_LINKEDIN_ORGANIZATION } from "./constants.js"; +import { LLM_MODEL_NAME, POST_TO_LINKEDIN_ORGANIZATION } from "./constants.js"; import { DateType } from "../types.js"; export type LangChainProduct = "langchain" | "langgraph" | "langsmith"; @@ -102,4 +102,14 @@ export const GeneratePostConfigurableAnnotation = Annotation.Root({ * If true, [LINKEDIN_ORGANIZATION_ID] is required. */ [POST_TO_LINKEDIN_ORGANIZATION]: Annotation, + /** + * The name of the LLM to use for generations + * @default "gemini-2.0-flash-exp" + */ + [LLM_MODEL_NAME]: Annotation< + "gemini-2.0-flash-exp" | "claude-3-5-sonnet-latest" | undefined + >({ + reducer: (_state, update) => update, + default: () => "gemini-2.0-flash-exp", + }), }); diff --git a/src/agents/generate-post/nodes/condense-post.ts b/src/agents/generate-post/nodes/condense-post.ts index 3f51281..b0f1ac0 100644 --- a/src/agents/generate-post/nodes/condense-post.ts +++ b/src/agents/generate-post/nodes/condense-post.ts @@ -1,8 +1,8 @@ -import { ChatAnthropic } from "@langchain/anthropic"; import { GeneratePostAnnotation } from "../generate-post-state.js"; import { STRUCTURE_INSTRUCTIONS, RULES } from "./geterate-post/prompts.js"; import { parseGeneration } from "./geterate-post/utils.js"; -import { removeUrls } from "../../utils.js"; +import { getModelFromConfig, removeUrls } from "../../utils.js"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; const CONDENSE_POST_PROMPT = `You're a highly skilled marketer at LangChain, working on crafting thoughtful and engaging content for LangChain's LinkedIn and Twitter pages. You wrote a post for the LangChain LinkedIn and Twitter pages, however it's a bit too long for Twitter, and thus needs to be condensed. @@ -52,6 +52,7 @@ Follow all rules and instructions outlined above. The user message below will pr */ export async function condensePost( state: typeof GeneratePostAnnotation.State, + config: LangGraphRunnableConfig, ): Promise> { if (!state.post) { throw new Error("No post found"); @@ -72,8 +73,7 @@ export async function condensePost( .replace("{link}", state.relevantLinks[0]) .replace("{originalPostLength}", originalPostLength); - const condensePostModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", + const condensePostModel = await getModelFromConfig(config, { temperature: 0.5, }); diff --git a/src/agents/generate-post/nodes/generate-content-report.ts b/src/agents/generate-post/nodes/generate-content-report.ts index 4a5edae..4e54862 100644 --- a/src/agents/generate-post/nodes/generate-content-report.ts +++ b/src/agents/generate-post/nodes/generate-content-report.ts @@ -1,7 +1,7 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../generate-post-state.js"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../prompts.js"; -import { ChatAnthropic } from "@langchain/anthropic"; +import { getModelFromConfig } from "../../utils.js"; const GENERATE_REPORT_PROMPT = `You are a highly regarded marketing employee at LangChain. You have been tasked with writing a marketing report on content submitted to you from a third party which uses LangChain's products. @@ -85,10 +85,9 @@ ${pageContents.map((content, index) => `\n${conten export async function generateContentReport( state: typeof GeneratePostAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise> { - const reportModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", + const reportModel = await getModelFromConfig(config, { temperature: 0, }); diff --git a/src/agents/generate-post/nodes/geterate-post/index.ts b/src/agents/generate-post/nodes/geterate-post/index.ts index b54e9e4..bab2d45 100644 --- a/src/agents/generate-post/nodes/geterate-post/index.ts +++ b/src/agents/generate-post/nodes/geterate-post/index.ts @@ -1,11 +1,11 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post-state.js"; -import { ChatAnthropic } from "@langchain/anthropic"; import { GENERATE_POST_PROMPT, REFLECTIONS_PROMPT } from "./prompts.js"; import { formatPrompt, parseGeneration } from "./utils.js"; import { ALLOWED_TIMES } from "../../constants.js"; import { getReflections, RULESET_KEY } from "../../../../utils/reflections.js"; import { getNextSaturdayDate } from "../../../../utils/date.js"; +import { getModelFromConfig } from "../../../utils.js"; export async function generatePost( state: typeof GeneratePostAnnotation.State, @@ -17,8 +17,7 @@ export async function generatePost( if (state.relevantLinks.length === 0) { throw new Error("No relevant links found"); } - const postModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", + const postModel = await getModelFromConfig(config, { temperature: 0.5, }); diff --git a/src/agents/generate-post/nodes/human-node/index.ts b/src/agents/generate-post/nodes/human-node/index.ts index 5cd1a27..4ad88d7 100644 --- a/src/agents/generate-post/nodes/human-node/index.ts +++ b/src/agents/generate-post/nodes/human-node/index.ts @@ -78,7 +78,7 @@ Here is the report that was generated for the posts:\n${report} export async function humanNode( state: typeof GeneratePostAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise> { if (!state.post) { throw new Error("No post found"); @@ -153,6 +153,7 @@ export async function humanNode( post: state.post, dateOrPriority: defaultDateString, userResponse: response.args, + config, }); if (route === "rewrite_post") { diff --git a/src/agents/generate-post/nodes/human-node/route-response.ts b/src/agents/generate-post/nodes/human-node/route-response.ts index dba3a2d..e969eb6 100644 --- a/src/agents/generate-post/nodes/human-node/route-response.ts +++ b/src/agents/generate-post/nodes/human-node/route-response.ts @@ -1,5 +1,6 @@ -import { ChatAnthropic } from "@langchain/anthropic"; import { z } from "zod"; +import { getModelFromConfig } from "../../../utils.js"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; const ROUTE_RESPONSE_PROMPT = `You are an AI assistant tasked with routing a user's response to one of two possible routes based on their intention. The two possible routes are: @@ -55,15 +56,16 @@ interface RouteResponseArgs { post: string; dateOrPriority: string; userResponse: string; + config: LangGraphRunnableConfig; } export async function routeResponse({ post, dateOrPriority, userResponse, + config, }: RouteResponseArgs) { - const model = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", + const model = await getModelFromConfig(config, { temperature: 0, }); diff --git a/src/agents/generate-post/nodes/rewrite-post.ts b/src/agents/generate-post/nodes/rewrite-post.ts index 09da4d7..b4f3aea 100644 --- a/src/agents/generate-post/nodes/rewrite-post.ts +++ b/src/agents/generate-post/nodes/rewrite-post.ts @@ -1,7 +1,7 @@ import { Client } from "@langchain/langgraph-sdk"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../generate-post-state.js"; -import { ChatAnthropic } from "@langchain/anthropic"; +import { getModelFromConfig } from "../../utils.js"; const REWRITE_POST_PROMPT = `You're a highly regarded marketing employee at LangChain, working on crafting thoughtful and engaging content for LangChain's LinkedIn and Twitter pages. You wrote a post for the LangChain LinkedIn and Twitter pages, however your boss has asked for some changes to be made before it can be published. @@ -44,7 +44,7 @@ async function runReflections({ export async function rewritePost( state: typeof GeneratePostAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise> { if (!state.post) { throw new Error("No post found"); @@ -53,8 +53,7 @@ export async function rewritePost( throw new Error("No user response found"); } - const rewritePostModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", + const rewritePostModel = await getModelFromConfig(config, { temperature: 0.5, }); diff --git a/src/agents/generate-post/nodes/schedule-post/index.ts b/src/agents/generate-post/nodes/schedule-post/index.ts index c9bb3f1..299dd5a 100644 --- a/src/agents/generate-post/nodes/schedule-post/index.ts +++ b/src/agents/generate-post/nodes/schedule-post/index.ts @@ -1,7 +1,10 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post-state.js"; import { Client } from "@langchain/langgraph-sdk"; -import { POST_TO_LINKEDIN_ORGANIZATION } from "../../constants.js"; +import { + LLM_MODEL_NAME, + POST_TO_LINKEDIN_ORGANIZATION, +} from "../../constants.js"; import { getScheduledDateSeconds } from "./find-date.js"; import { SlackClient } from "../../../../clients/slack.js"; import { getFutureDate } from "./get-future-date.js"; @@ -40,6 +43,8 @@ export async function schedulePost( [POST_TO_LINKEDIN_ORGANIZATION]: config.configurable?.[POST_TO_LINKEDIN_ORGANIZATION] || process.env.POST_TO_LINKEDIN_ORGANIZATION, + [LLM_MODEL_NAME]: + config.configurable?.[LLM_MODEL_NAME] || "gemini-2.0-flash-exp", }, }, afterSeconds, diff --git a/src/agents/generate-post/nodes/update-scheduled-date.ts b/src/agents/generate-post/nodes/update-scheduled-date.ts index 31ac360..706cffd 100644 --- a/src/agents/generate-post/nodes/update-scheduled-date.ts +++ b/src/agents/generate-post/nodes/update-scheduled-date.ts @@ -1,9 +1,10 @@ import { z } from "zod"; import { GeneratePostAnnotation } from "../generate-post-state.js"; -import { ChatAnthropic } from "@langchain/anthropic"; import { toZonedTime } from "date-fns-tz"; import { DateType } from "../../types.js"; import { timezoneToUtc } from "../../../utils/date.js"; +import { getModelFromConfig } from "../../utils.js"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; const SCHEDULE_POST_DATE_PROMPT = `You're an intelligent AI assistant tasked with extracting the date to schedule a social media post from the user's message. @@ -40,14 +41,16 @@ const scheduleDateSchema = z.object({ export async function updateScheduledDate( state: typeof GeneratePostAnnotation.State, + config: LangGraphRunnableConfig, ): Promise> { if (!state.userResponse) { throw new Error("No user response found"); } - const model = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0.5, - }).withStructuredOutput(scheduleDateSchema, { + const model = ( + await getModelFromConfig(config, { + temperature: 0.5, + }) + ).withStructuredOutput(scheduleDateSchema, { name: "scheduleDate", }); const pstDate = toZonedTime(new Date(), "America/Los_Angeles"); diff --git a/src/agents/ingest-data/ingest-data-graph.ts b/src/agents/ingest-data/ingest-data-graph.ts index 2cd2e1e..5746011 100644 --- a/src/agents/ingest-data/ingest-data-graph.ts +++ b/src/agents/ingest-data/ingest-data-graph.ts @@ -10,7 +10,10 @@ import { } from "./ingest-data-state.js"; import { ingestSlackData } from "./nodes/ingest-slack.js"; import { Client } from "@langchain/langgraph-sdk"; -import { POST_TO_LINKEDIN_ORGANIZATION } from "../generate-post/constants.js"; +import { + LLM_MODEL_NAME, + POST_TO_LINKEDIN_ORGANIZATION, +} from "../generate-post/constants.js"; import { getUrlType } from "../utils.js"; /** @@ -69,6 +72,8 @@ async function generatePostFromMessages( config: { configurable: { [POST_TO_LINKEDIN_ORGANIZATION]: postToLinkedInOrg, + [LLM_MODEL_NAME]: + config.configurable?.[LLM_MODEL_NAME] || "gemini-2.0-flash-exp", }, }, afterSeconds, diff --git a/src/agents/ingest-data/ingest-data-state.ts b/src/agents/ingest-data/ingest-data-state.ts index 484c047..19dcbb2 100644 --- a/src/agents/ingest-data/ingest-data-state.ts +++ b/src/agents/ingest-data/ingest-data-state.ts @@ -1,6 +1,9 @@ import { Annotation } from "@langchain/langgraph"; import { SimpleSlackMessage } from "../../clients/slack.js"; -import { POST_TO_LINKEDIN_ORGANIZATION } from "../generate-post/constants.js"; +import { + LLM_MODEL_NAME, + POST_TO_LINKEDIN_ORGANIZATION, +} from "../generate-post/constants.js"; export type LangChainProduct = "langchain" | "langgraph" | "langsmith"; export type SimpleSlackMessageWithLinks = SimpleSlackMessage & { @@ -50,4 +53,14 @@ export const IngestDataConfigurableAnnotation = Annotation.Root({ * If true, [LINKEDIN_ORGANIZATION_ID] is required. */ [POST_TO_LINKEDIN_ORGANIZATION]: Annotation, + /** + * The name of the LLM to use for generations + * @default "gemini-2.0-flash-exp" + */ + [LLM_MODEL_NAME]: Annotation< + "gemini-2.0-flash-exp" | "claude-3-5-sonnet-latest" | undefined + >({ + reducer: (_state, update) => update, + default: () => "gemini-2.0-flash-exp", + }), }); diff --git a/src/agents/reflection/index.ts b/src/agents/reflection/index.ts index 8c7adf0..c8118ba 100644 --- a/src/agents/reflection/index.ts +++ b/src/agents/reflection/index.ts @@ -6,13 +6,13 @@ import { StateGraph, } from "@langchain/langgraph"; import { z } from "zod"; -import { ChatAnthropic } from "@langchain/anthropic"; import { getReflections, putReflections, RULESET_KEY, } from "../../utils/reflections.js"; import { REFLECTION_PROMPT, UPDATE_RULES_PROMPT } from "./prompts.js"; +import { getModelFromConfig } from "../utils.js"; const newRuleSchema = z.object({ newRule: z.string().describe("The new rule to create."), @@ -47,8 +47,7 @@ async function reflection( if (!config.store) { throw new Error("No store provided"); } - const model = new ChatAnthropic({ - model: "claude-3-5-sonnet-latest", + const model = await getModelFromConfig(config, { temperature: 0, }); diff --git a/src/agents/shared/nodes/verify-general.ts b/src/agents/shared/nodes/verify-general.ts index 74750bd..9518307 100644 --- a/src/agents/shared/nodes/verify-general.ts +++ b/src/agents/shared/nodes/verify-general.ts @@ -1,12 +1,11 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js"; import { z } from "zod"; -import { ChatAnthropic } from "@langchain/anthropic"; import { FireCrawlLoader } from "@langchain/community/document_loaders/web/firecrawl"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; import { VerifyContentAnnotation } from "../shared-state.js"; import { RunnableLambda } from "@langchain/core/runnables"; -import { getPageText } from "../../utils.js"; +import { getModelFromConfig, getPageText } from "../../utils.js"; type VerifyGeneralContentReturn = { relevantLinks: (typeof GeneratePostAnnotation.State)["relevantLinks"]; @@ -87,11 +86,13 @@ export async function getUrlContents(url: string): Promise { export async function verifyGeneralContentIsRelevant( content: string, + config: LangGraphRunnableConfig, ): Promise { - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); @@ -125,14 +126,17 @@ export async function verifyGeneralContentIsRelevant( */ export async function verifyGeneralContent( state: typeof VerifyContentAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise { const urlContents = await new RunnableLambda({ func: getUrlContents, }) .withConfig({ runName: "get-url-contents" }) .invoke(state.link); - const relevant = await verifyGeneralContentIsRelevant(urlContents.content); + const relevant = await verifyGeneralContentIsRelevant( + urlContents.content, + config, + ); if (relevant) { return { diff --git a/src/agents/shared/nodes/verify-github.ts b/src/agents/shared/nodes/verify-github.ts index d82f2a7..8a1aadb 100644 --- a/src/agents/shared/nodes/verify-github.ts +++ b/src/agents/shared/nodes/verify-github.ts @@ -1,6 +1,5 @@ import { z } from "zod"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; -import { ChatAnthropic } from "@langchain/anthropic"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; import { VerifyContentAnnotation } from "../shared-state.js"; import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js"; @@ -9,6 +8,7 @@ import { getRepoContents, getFileContents, } from "../../../utils/github-repo-contents.js"; +import { getModelFromConfig } from "../../utils.js"; type VerifyGitHubContentReturn = { relevantLinks: (typeof GeneratePostAnnotation.State)["relevantLinks"]; @@ -132,10 +132,11 @@ export async function verifyGitHubContentIsRelevant({ dependenciesFileName, config, }: VerifyGitHubContentParams): Promise { - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); diff --git a/src/agents/shared/nodes/verify-reddit.ts b/src/agents/shared/nodes/verify-reddit.ts index 75f6121..e5cbf49 100644 --- a/src/agents/shared/nodes/verify-reddit.ts +++ b/src/agents/shared/nodes/verify-reddit.ts @@ -2,8 +2,8 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js"; import { VerifyContentAnnotation } from "../shared-state.js"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; -import { ChatAnthropic } from "@langchain/anthropic"; import { z } from "zod"; +import { getModelFromConfig } from "../../utils.js"; /** * TODO: Refactor into a subgraph @@ -81,7 +81,7 @@ You should provide reasoning as to why or why not the content implements LangCha export async function getRedditPostContent( state: typeof VerifyContentAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise { const url = new URL(state.link); const jsonUrl = `${url.origin}${url.pathname.endsWith("/") ? url.pathname.slice(0, -1) : url.pathname}.json`; @@ -97,10 +97,11 @@ export async function getRedditPostContent( ) .join("\n")}\n`; - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); diff --git a/src/agents/shared/nodes/verify-youtube.ts b/src/agents/shared/nodes/verify-youtube.ts index b7f2577..e56e3df 100644 --- a/src/agents/shared/nodes/verify-youtube.ts +++ b/src/agents/shared/nodes/verify-youtube.ts @@ -2,7 +2,6 @@ import { z } from "zod"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js"; import { ChatVertexAI } from "@langchain/google-vertexai-web"; -import { ChatAnthropic } from "@langchain/anthropic"; import { HumanMessage } from "@langchain/core/messages"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; import { VerifyContentAnnotation } from "../shared-state.js"; @@ -10,6 +9,7 @@ import { getVideoThumbnailUrl, getYouTubeVideoDuration, } from "./youtube.utils.js"; +import { getModelFromConfig } from "../../utils.js"; type VerifyYouTubeContentReturn = { relevantLinks: (typeof GeneratePostAnnotation.State)["relevantLinks"]; @@ -90,11 +90,13 @@ export async function generateVideoSummary(url: string): Promise { export async function verifyYouTubeContentIsRelevant( summary: string, + config: LangGraphRunnableConfig, ): Promise { - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); @@ -120,7 +122,7 @@ export async function verifyYouTubeContentIsRelevant( */ export async function verifyYouTubeContent( state: typeof VerifyContentAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise { const [videoDurationS, videoThumbnail] = await Promise.all([ getYouTubeVideoDuration(state.link), @@ -140,7 +142,7 @@ export async function verifyYouTubeContent( } const videoSummary = await generateVideoSummary(state.link); - const relevant = await verifyYouTubeContentIsRelevant(videoSummary); + const relevant = await verifyYouTubeContentIsRelevant(videoSummary, config); if (relevant) { return { diff --git a/src/agents/utils.ts b/src/agents/utils.ts index 9ac914e..e21cfbd 100644 --- a/src/agents/utils.ts +++ b/src/agents/utils.ts @@ -1,4 +1,9 @@ +import { initChatModel } from "langchain/chat_models/universal"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; import * as cheerio from "cheerio"; +import { LLM_MODEL_NAME } from "./generate-post/constants.js"; +import { ChatAnthropic } from "@langchain/anthropic"; +import { ChatVertexAI } from "@langchain/google-vertexai-web"; /** * Extracts URLs from Slack-style message text containing links in the format: @@ -436,3 +441,27 @@ export function removeQueryParams(url: string): string { return url; } } + +export async function getModelFromConfig( + config: LangGraphRunnableConfig, + modelArgs?: Record, +): Promise { + const model = config.configurable?.[LLM_MODEL_NAME] || "gemini-2.0-flash-exp"; + + if (model.startsWith("gemini-")) { + return initChatModel(model, { + modelProvider: "google-vertexai-web", + apiKey: process.env.GOOGLE_VERTEX_AI_WEB_CREDENTIALS, + ...modelArgs, + }) as unknown as ChatVertexAI; + } + if (model.startsWith("claude-")) { + return initChatModel(model, { + modelProvider: "anthropic", + apiKey: process.env.ANTHROPIC_API_KEY, + ...modelArgs, + }) as unknown as ChatAnthropic; + } + + throw new Error(`Unknown model: ${model}`); +} diff --git a/src/agents/verify-reddit-post/nodes/get-reddit-content.ts b/src/agents/verify-reddit-post/nodes/get-reddit-content.ts index e899289..766f2e8 100644 --- a/src/agents/verify-reddit-post/nodes/get-reddit-content.ts +++ b/src/agents/verify-reddit-post/nodes/get-reddit-content.ts @@ -1,9 +1,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { GeneratePostAnnotation } from "../../generate-post/generate-post-state.js"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; -import { ChatAnthropic } from "@langchain/anthropic"; import { z } from "zod"; import { VerifyRedditPostAnnotation } from "../verify-reddit-post-state.js"; +import { getModelFromConfig } from "../../utils.js"; /** * TODO: Support handling links in the main content of the reddit post @@ -94,7 +94,7 @@ You should provide reasoning as to why or why not the content implements LangCha export async function getRedditPostContent( state: typeof VerifyRedditPostAnnotation.State, - _config: LangGraphRunnableConfig, + config: LangGraphRunnableConfig, ): Promise { const url = new URL(state.link); const jsonUrl = `${url.origin}${url.pathname.endsWith("/") ? url.pathname.slice(0, -1) : url.pathname}.json`; @@ -110,10 +110,11 @@ export async function getRedditPostContent( ) .join("\n")}\n`; - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); diff --git a/src/agents/verify-tweet/nodes/validate-tweet.ts b/src/agents/verify-tweet/nodes/validate-tweet.ts index ff1797d..e078730 100644 --- a/src/agents/verify-tweet/nodes/validate-tweet.ts +++ b/src/agents/verify-tweet/nodes/validate-tweet.ts @@ -1,7 +1,8 @@ import { z } from "zod"; import { LANGCHAIN_PRODUCTS_CONTEXT } from "../../generate-post/prompts.js"; import { VerifyTweetAnnotation } from "../verify-tweet-state.js"; -import { ChatAnthropic } from "@langchain/anthropic"; +import { getModelFromConfig } from "../../utils.js"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; const RELEVANCY_SCHEMA = z .object({ @@ -31,11 +32,13 @@ You should provide reasoning as to why or why not the content implements LangCha async function verifyGeneralContentIsRelevant( content: string, + config: LangGraphRunnableConfig, ): Promise { - const relevancyModel = new ChatAnthropic({ - model: "claude-3-5-sonnet-20241022", - temperature: 0, - }).withStructuredOutput(RELEVANCY_SCHEMA, { + const relevancyModel = ( + await getModelFromConfig(config, { + temperature: 0, + }) + ).withStructuredOutput(RELEVANCY_SCHEMA, { name: "relevancy", }); @@ -81,13 +84,14 @@ ${pageContents.map((content, index) => `\n${cont */ export async function validateTweetContent( state: typeof VerifyTweetAnnotation.State, + config: LangGraphRunnableConfig, ): Promise> { const context = constructContext({ tweetContent: state.tweetContent, pageContents: state.pageContents, }); - const relevant = await verifyGeneralContentIsRelevant(context); + const relevant = await verifyGeneralContentIsRelevant(context, config); if (!relevant) { return { diff --git a/src/tests/graph.int.test.ts b/src/tests/graph.int.test.ts index c3c494a..cd4803d 100644 --- a/src/tests/graph.int.test.ts +++ b/src/tests/graph.int.test.ts @@ -14,10 +14,14 @@ import { getGitHubContentsAndTypeFromUrl } from "../agents/shared/nodes/verify-g import { verifyYouTubeContent } from "../agents/shared/nodes/verify-youtube.js"; import { Command, MemorySaver } from "@langchain/langgraph"; import { verifyTweetGraph } from "../agents/verify-tweet/verify-tweet-graph.js"; -import { POST_TO_LINKEDIN_ORGANIZATION } from "../agents/generate-post/constants.js"; +import { + LLM_MODEL_NAME, + POST_TO_LINKEDIN_ORGANIZATION, +} from "../agents/generate-post/constants.js"; const BASE_CONFIG = { [POST_TO_LINKEDIN_ORGANIZATION]: undefined, + [LLM_MODEL_NAME]: undefined, }; describe("GeneratePostGraph", () => { diff --git a/yarn.lock b/yarn.lock index 5122adc..05e5a36 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10,10 +10,10 @@ "@jridgewell/gen-mapping" "^0.3.5" "@jridgewell/trace-mapping" "^0.3.24" -"@anthropic-ai/sdk@^0.27.3": - version "0.27.3" - resolved "https://registry.yarnpkg.com/@anthropic-ai/sdk/-/sdk-0.27.3.tgz#592cdd873c85ffab9589ae6f2e250cbf150e1475" - integrity sha512-IjLt0gd3L4jlOfilxVXTifn42FnVffMgDC04RJK1KDZpmkBWLv0XC92MVVmkxrFZNS/7l3xWgP/I3nqtX1sQHw== +"@anthropic-ai/sdk@^0.32.1": + version "0.32.1" + resolved "https://registry.yarnpkg.com/@anthropic-ai/sdk/-/sdk-0.32.1.tgz#d22c8ebae2adccc59d78fb416e89de337ff09014" + integrity sha512-U9JwTrDvdQ9iWuABVsMLj8nJVwAyQz6QXvgLsVhryhCEPkLsbcP/MXxm+jYcAwLoV8ESbaTTjnD4kuAFa+Hyjg== dependencies: "@types/node" "^18.11.18" "@types/node-fetch" "^2.6.4" @@ -299,6 +299,11 @@ resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== +"@cfworker/json-schema@^4.0.2": + version "4.1.0" + resolved "https://registry.yarnpkg.com/@cfworker/json-schema/-/json-schema-4.1.0.tgz#cc114da98c23b12f4cd4673ce8a076be24e0233c" + integrity sha512-/vYKi/qMxwNsuIJ9WGWwM2rflY40ZenK3Kh4uR5vB9/Nz12Y7IUN/Xf4wDA7vzPfw0VNh3b/jz4+MjcVgARKJg== + "@esbuild/aix-ppc64@0.23.1": version "0.23.1" resolved "https://registry.yarnpkg.com/@esbuild/aix-ppc64/-/aix-ppc64-0.23.1.tgz#51299374de171dbd80bb7d838e1cfce9af36f353" @@ -737,12 +742,12 @@ "@jridgewell/resolve-uri" "^3.1.0" "@jridgewell/sourcemap-codec" "^1.4.14" -"@langchain/anthropic@^0.3.9": - version "0.3.9" - resolved "https://registry.yarnpkg.com/@langchain/anthropic/-/anthropic-0.3.9.tgz#e95fc6a8d7de2d4d1b4f7dda434ec81b02e8b1d3" - integrity sha512-BZK6EIlYYoGKwmgiZiJZMeL68xv+ooWk/ynDv/QVN8MGGQYW2waBfLaKEFHtPQr/HDznc15FEYZe8C76/EtoYg== +"@langchain/anthropic@^0.3.11": + version "0.3.11" + resolved "https://registry.yarnpkg.com/@langchain/anthropic/-/anthropic-0.3.11.tgz#57277bb4bd7c624eb9039dc10cc36de9ecfeb91c" + integrity sha512-rYjDZjMwVQ+cYeJd9IoSESdkkG8fc0m3siGRYKNy6qgYMnqCz8sUPKBanXwbZAs6wvspPCGgNK9WONfaCeX97A== dependencies: - "@anthropic-ai/sdk" "^0.27.3" + "@anthropic-ai/sdk" "^0.32.1" fast-xml-parser "^4.4.1" zod "^3.22.4" zod-to-json-schema "^3.22.4" @@ -763,11 +768,12 @@ zod "^3.22.3" zod-to-json-schema "^3.22.5" -"@langchain/core@^0.3.22": - version "0.3.22" - resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.3.22.tgz#22064eca45a1f506e554c30537de62dc742c67f8" - integrity sha512-9rwEbxJi3Fgs8XuealNYxB6s0FCOnvXLnpiV5/oKgmEJtCRS91IqgJCWA8d59s4YkaEply/EsZVc2azNPK6Wjw== +"@langchain/core@^0.3.27": + version "0.3.27" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.3.27.tgz#b8e75bd4f122b18a423f3905df0e9872205a1a44" + integrity sha512-jtJKbJWB1NPU1YvtrExOB2rumvUFgkJwlWGxyjSIV9A6zcLVmUbcZGV8fCSuXgl5bbzOIQLJ1xcLYQmbW9TkTg== dependencies: + "@cfworker/json-schema" "^4.0.2" ansi-styles "^5.0.0" camelcase "6" decamelize "1.2.0" @@ -4012,6 +4018,24 @@ kleur@^3.0.3: zod "^3.22.4" zod-to-json-schema "^3.22.3" +langchain@^0.3.10: + version "0.3.10" + resolved "https://registry.yarnpkg.com/langchain/-/langchain-0.3.10.tgz#1bb43c0bfb50104e6d3b8ae284e56d543527dbff" + integrity sha512-dZXdhs81NU/PS2WfECCLJszx4to3ELK7qTMbumD0rAKx3mQb0sqr8M9MiVCcPgTZ1J1pzoDr5yCSdsmm9UsNXA== + dependencies: + "@langchain/openai" ">=0.1.0 <0.4.0" + "@langchain/textsplitters" ">=0.0.0 <0.2.0" + js-tiktoken "^1.0.12" + js-yaml "^4.1.0" + jsonpointer "^5.0.1" + langsmith "^0.2.8" + openapi-types "^12.1.3" + p-retry "4" + uuid "^10.0.0" + yaml "^2.2.1" + zod "^3.22.4" + zod-to-json-schema "^3.22.3" + langsmith@0.2.15-rc.2: version "0.2.15-rc.2" resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.2.15-rc.2.tgz#e884ad0b019a04268d607937f3322027786784ca" From c92d4d6540aa81559e72138bd366fda65c9da02c Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 6 Jan 2025 17:05:16 -0800 Subject: [PATCH 2/2] cr --- src/agents/shared/nodes/verify-github.ts | 3 ++- src/agents/utils.ts | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/agents/shared/nodes/verify-github.ts b/src/agents/shared/nodes/verify-github.ts index 8a1aadb..518c7f1 100644 --- a/src/agents/shared/nodes/verify-github.ts +++ b/src/agents/shared/nodes/verify-github.ts @@ -45,7 +45,8 @@ ${LANGCHAIN_PRODUCTS_CONTEXT} {repoDependenciesPrompt} Given this context, examine the {file_type} closely, and determine if the repository implements LangChain's products. -You should provide reasoning as to why or why not the repository implements LangChain's products, then a simple true or false for whether or not it implements some.`; +You should provide reasoning as to why or why not the repository implements LangChain's products, then a simple true or false for whether or not it implements some. +Always call the 'relevancy' tool to respond.`; const getDependencies = async ( githubUrl: string, diff --git a/src/agents/utils.ts b/src/agents/utils.ts index e21cfbd..c6e40b8 100644 --- a/src/agents/utils.ts +++ b/src/agents/utils.ts @@ -451,7 +451,7 @@ export async function getModelFromConfig( if (model.startsWith("gemini-")) { return initChatModel(model, { modelProvider: "google-vertexai-web", - apiKey: process.env.GOOGLE_VERTEX_AI_WEB_CREDENTIALS, + // apiKey: process.env.GOOGLE_VERTEX_AI_WEB_CREDENTIALS, ...modelArgs, }) as unknown as ChatVertexAI; }