Skip to content

Commit

Permalink
add rerank images node
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jan 7, 2025
1 parent 1181255 commit 4be3d7a
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"cSpell.words": ["Supabase"]
"cSpell.words": ["reranked", "Supabase"]
}
114 changes: 112 additions & 2 deletions src/agents/find-images/nodes/re-rank-images.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,115 @@
import { ChatVertexAI } from "@langchain/google-vertexai-web";
import { FindImagesAnnotation } from "../find-images-graph.js";
import { chunkArray } from "../../utils.js";
import { getImageMessageContents } from "../utils.js";

export async function reRankImages(_state: typeof FindImagesAnnotation.State) {
throw new Error("Not implemented");
const RE_RANK_IMAGES_PROMPT = `You're a highly regarded marketing employee at LangChain, working on crafting thoughtful and engaging content for LangChain's LinkedIn and Twitter pages.
You're writing a post, and in doing so you've found a series of images that you think will help make the post more engaging.
Your task is to re-rank these images in order of which you think is the most engaging and best for the post.
Here is the marketing report the post was generated based on:
<report>
{REPORT}
</report>
And here's the actual post:
<post>
{POST}
</post>
Now, given this context, re-rank the images in order of most relevant to least relevant.
Provide your response in the following format:
1. <analysis> tag: Briefly explain your thought process for each image, referencing specific elements from the post and report and why each image is or isn't as relevant as others.
2. <reranked-indices> tag: List the indices of the relevant images in order of most relevant to least relevant, separated by commas.
Example: You're given 5 images, and deem that the relevancy order is [2, 0, 1, 4, 3], then you would respond as follows:
<answer>
<analysis>
- Image 2 is (explanation here)
- Image 0 is (explanation here)
- Image 1 is (explanation here)
- Image 4 is (explanation here)
- Image 3 is (explanation here)
</analysis>
<reranked-indices>
2, 0, 1, 4, 3
</reranked-indices>
</answer>
Ensure you ALWAYS WRAP your analysis and relevant indices inside the <analysis> and <reranked-indices> tags, respectively. Do not only prefix, but ensure they are wrapped completely.
Provide your complete response within <answer> tags.`;

export function parseResult(result: string): number[] {
const match = result.match(
/<reranked-indices>\s*([\d,\s]*?)\s*<\/reranked-indices>/s,
);
if (!match) return [];

return match[1]
.split(",")
.map((s) => s.trim())
.filter((s) => s.length > 0)
.map(Number)
.filter((n) => !isNaN(n));
}

export async function reRankImages(state: typeof FindImagesAnnotation.State) {
if (state.imageOptions.length === 0) {
return {
imageOptions: [],
};
}

const model = new ChatVertexAI({
model: "gemini-2.0-flash-exp",
temperature: 0,
});

// Split images into chunks of 5
const imageChunks = chunkArray(state.imageOptions, 5);
let reRankedIndices: number[] = [];
let baseIndex = 0;

const formattedSystemPrompt = RE_RANK_IMAGES_PROMPT.replace(
"{POST}",
state.post,
).replace("{REPORT}", state.report);

// Process each chunk
for (const imageChunk of imageChunks) {
const imageMessages = await getImageMessageContents(imageChunk, baseIndex);

if (!imageMessages.length) {
continue;
}
const response = await model.invoke([
{
role: "system",
content: formattedSystemPrompt,
},
{
role: "user",
content: imageMessages,
},
]);

const chunkAnalysis = parseResult(response.content as string);
// Convert chunk indices to global indices and add to our list of re-ranked indices
const globalIndices = chunkAnalysis.map((index) => index + baseIndex);
reRankedIndices = [...reRankedIndices, ...globalIndices];

baseIndex += imageChunk.length;
}

const imageOptionsInOrder = reRankedIndices.map(
(index) => state.imageOptions[index],
);

return {
imageOptions: imageOptionsInOrder,
};
}
57 changes: 4 additions & 53 deletions src/agents/find-images/nodes/validate-images.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import { ChatVertexAI } from "@langchain/google-vertexai-web";
import { FindImagesAnnotation } from "../find-images-graph.js";
import {
removeQueryParams,
getMimeTypeFromUrl,
imageUrlToBuffer,
BLACKLISTED_MIME_TYPES,
} from "../../utils.js";
import { chunkArray } from "../../utils.js";
import { getImageMessageContents } from "../utils.js";

const VALIDATE_IMAGES_PROMPT = `You are an advanced AI assistant tasked with validating image options for a social media post.
Your goal is to identify which images from a given set are relevant to the post, based on the content of the post and an associated marketing report.
Expand Down Expand Up @@ -60,12 +56,6 @@ export function parseResult(result: string): number[] {
.filter((n) => !isNaN(n));
}

function chunk<T>(arr: T[], size: number): T[][] {
return Array.from({ length: Math.ceil(arr.length / size) }, (_, i) =>
arr.slice(i * size, i * size + size),
);
}

export async function validateImages(
state: typeof FindImagesAnnotation.State,
): Promise<{
Expand Down Expand Up @@ -93,7 +83,7 @@ export async function validateImages(
}

// Split images into chunks of 10
const imageChunks = chunk(imagesWithoutSupabase, 10);
const imageChunks = chunkArray(imagesWithoutSupabase, 10);
let allIrrelevantIndices: number[] = [];
let baseIndex = 0;

Expand All @@ -104,46 +94,7 @@ export async function validateImages(

// Process each chunk
for (const imageChunk of imageChunks) {
const imageMessagesPromises = imageChunk.flatMap(
async (fileUri, chunkIndex) => {
const cleanedFileUri = removeQueryParams(fileUri);
let mimeType = getMimeTypeFromUrl(fileUri);

if (!mimeType) {
try {
const { contentType } = await imageUrlToBuffer(fileUri);
if (!contentType) {
throw new Error("Failed to fetch content type");
}
mimeType = contentType;
} catch (e) {
console.warn(
"No mime type found, and failed to fetch content type:",
e,
);
}
}
if (
!mimeType ||
BLACKLISTED_MIME_TYPES.find((mt) => mimeType.startsWith(mt))
) {
return [];
}

return [
{
type: "text",
text: `The below image is index ${baseIndex + chunkIndex}`,
},
{
type: "media",
mimeType,
fileUri: cleanedFileUri,
},
];
},
);
const imageMessages = (await Promise.all(imageMessagesPromises)).flat();
const imageMessages = await getImageMessageContents(imageChunk, baseIndex);

if (!imageMessages.length) {
continue;
Expand Down
53 changes: 53 additions & 0 deletions src/agents/find-images/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import {
BLACKLISTED_MIME_TYPES,
getMimeTypeFromUrl,
imageUrlToBuffer,
removeQueryParams,
} from "../utils.js";

export async function getImageMessageContents(
imageChunk: string[],
baseIndex: number,
) {
const imageMessagesPromises = imageChunk.flatMap(
async (fileUri, chunkIndex) => {
const cleanedFileUri = removeQueryParams(fileUri);
let mimeType = getMimeTypeFromUrl(fileUri);

if (!mimeType) {
try {
const { contentType } = await imageUrlToBuffer(fileUri);
if (!contentType) {
throw new Error("Failed to fetch content type");
}
mimeType = contentType;
} catch (e) {
console.warn(
"No mime type found, and failed to fetch content type:",
e,
);
}
}
if (
!mimeType ||
BLACKLISTED_MIME_TYPES.find((mt) => mimeType.startsWith(mt))
) {
return [];
}

return [
{
type: "text",
text: `The below image is index ${baseIndex + chunkIndex}`,
},
{
type: "media",
mimeType,
fileUri: cleanedFileUri,
},
];
},
);
const imageMessages = (await Promise.all(imageMessagesPromises)).flat();
return imageMessages;
}
13 changes: 13 additions & 0 deletions src/agents/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,16 @@ export function removeQueryParams(url: string): string {
return url;
}
}

/**
* Splits an array into smaller chunks of a specified size
* @param {T[]} arr - The array to be chunked
* @param {number} size - The size of each chunk
* @returns {T[][]} An array of chunks, where each chunk is an array of size elements (except possibly the last chunk which may be smaller)
* @template T - The type of elements in the array
*/
export function chunkArray<T>(arr: T[], size: number): T[][] {
return Array.from({ length: Math.ceil(arr.length / size) }, (_, i) =>
arr.slice(i * size, i * size + size),
);
}

0 comments on commit 4be3d7a

Please sign in to comment.