diff --git a/.github/workflows/test-gemini.yml b/.github/workflows/test-gemini.yml
index e859601..d151e05 100644
--- a/.github/workflows/test-gemini.yml
+++ b/.github/workflows/test-gemini.yml
@@ -30,6 +30,7 @@ jobs:
           LLM_API_BASE_URL: 'https://generativelanguage.googleapis.com/v1beta'
           LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           LLM_CHAT_MODEL: 'gemini-1.5-flash-8b'
+          LLM_JSON_SCHEMA: 1
       - run: cat output.txt
       - run: grep -i jupiter output.txt
 
@@ -46,6 +47,7 @@ jobs:
           LLM_API_BASE_URL: 'https://generativelanguage.googleapis.com/v1beta'
           LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           LLM_CHAT_MODEL: 'gemini-1.5-flash-8b'
+          LLM_JSON_SCHEMA: 1
 
   high-school-stem:
     needs: chain-of-thought
@@ -59,6 +61,7 @@ jobs:
           LLM_API_BASE_URL: 'https://generativelanguage.googleapis.com/v1beta'
           LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           LLM_CHAT_MODEL: 'gemini-1.5-flash-8b'
+          LLM_JSON_SCHEMA: 1
 
   general-knowledge:
     needs: chain-of-thought
@@ -72,3 +75,4 @@ jobs:
           LLM_API_BASE_URL: 'https://generativelanguage.googleapis.com/v1beta'
           LLM_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           LLM_CHAT_MODEL: 'gemini-1.5-flash-8b'
+          LLM_JSON_SCHEMA: 1
diff --git a/.github/workflows/test-gpt.yml b/.github/workflows/test-gpt.yml
index 939d723..e4cfce3 100644
--- a/.github/workflows/test-gpt.yml
+++ b/.github/workflows/test-gpt.yml
@@ -30,6 +30,8 @@ jobs:
           LLM_API_BASE_URL: 'https://api.openai.com/v1'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           LLM_CHAT_MODEL: 'gpt-4o-mini'
+          LLM_JSON_SCHEMA: 1
+          LLM_DEBUG_CHAT: 1
       - run: cat output.txt
       - run: grep -i jupiter output.txt
 
@@ -46,6 +48,8 @@ jobs:
           LLM_API_BASE_URL: 'https://api.openai.com/v1'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           LLM_CHAT_MODEL: 'gpt-4o-mini'
+          LLM_JSON_SCHEMA: 1
+          LLM_DEBUG_CHAT: 1
 
   high-school-stem:
     needs: chain-of-thought
@@ -59,6 +63,8 @@ jobs:
           LLM_API_BASE_URL: 'https://api.openai.com/v1'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           LLM_CHAT_MODEL: 'gpt-4o-mini'
+          LLM_JSON_SCHEMA: 1
+          LLM_DEBUG_CHAT: 1
 
   general-knowledge:
     needs: chain-of-thought
@@ -72,3 +78,4 @@ jobs:
           LLM_API_BASE_URL: 'https://api.openai.com/v1'
           LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           LLM_CHAT_MODEL: 'gpt-4o-mini'
+          LLM_JSON_SCHEMA: 1
diff --git a/.github/workflows/test-small-llm.yml b/.github/workflows/test-small-llm.yml
index f3654e1..2bc4ea6 100644
--- a/.github/workflows/test-small-llm.yml
+++ b/.github/workflows/test-small-llm.yml
@@ -42,6 +42,7 @@ jobs:
       - run: echo 'Which planet in our solar system is the largest?' | ./query-llm.js | tee output.txt
         env:
           LLM_API_BASE_URL: 'http://127.0.0.1:8080/v1'
+          LLM_JSON_SCHEMA: 1
       - run: cat output.txt
       - run: grep -i jupiter output.txt
 
@@ -64,6 +65,7 @@ jobs:
       - run: ./query-llm.js tests/canary-multi-turn.txt
         env:
           LLM_API_BASE_URL: 'http://127.0.0.1:8080/v1'
+          LLM_JSON_SCHEMA: 1
 
   high-school-stem:
     needs: chain-of-thought
@@ -83,3 +85,4 @@ jobs:
       - run: ./query-llm.js tests/high-school-stem.txt
         env:
           LLM_API_BASE_URL: 'http://127.0.0.1:8080/v1'
+          LLM_JSON_SCHEMA: 1
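
Note: the new toggle is read straight from the environment, so any of the jobs above can be reproduced outside CI with the same variable, e.g. LLM_JSON_SCHEMA=1 LLM_API_BASE_URL=http://127.0.0.1:8080/v1 ./query-llm.js (assuming a local llama.cpp-style server on port 8080, as in test-small-llm.yml).
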
diff --git a/query-llm.js b/query-llm.js
index 176c39e..01fbac7 100755
--- a/query-llm.js
+++ b/query-llm.js
@@ -7,6 +7,7 @@ const LLM_API_BASE_URL = process.env.LLM_API_BASE_URL || 'https://api.openai.com
 const LLM_API_KEY = process.env.LLM_API_KEY || process.env.OPENAI_API_KEY;
 const LLM_CHAT_MODEL = process.env.LLM_CHAT_MODEL;
 const LLM_STREAMING = process.env.LLM_STREAMING !== 'no';
+const LLM_JSON_SCHEMA = process.env.LLM_JSON_SCHEMA;
 const LLM_ZERO_SHOT = process.env.LLM_ZERO_SHOT;
 const LLM_DEBUG_CHAT = process.env.LLM_DEBUG_CHAT;
 
@@ -33,6 +34,29 @@ const CROSS = '✘';
  */
 const pipe = (...fns) => arg => fns.reduce((d, fn) => d.then(fn), Promise.resolve(arg));
 
+/**
+ * Tries to parse a string as JSON, but if that fails, tries adding a
+ * closing curly brace or double quote to fix the JSON.
+ *
+ * @param {string} text
+ * @returns {Object}
+ */
+const unJSON = (text) => {
+    try {
+        return JSON.parse(text);
+    } catch (e) {
+        try {
+            return JSON.parse(text + '}');
+        } catch (e) {
+            try {
+                return JSON.parse(text + '"}');
+            } catch (e) {
+                return {};
+            }
+        }
+    }
+};
+
 /**
  * Represents a chat message.
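
Note: a quick sketch (not part of the patch) of how unJSON degrades gracefully on the truncated JSON that shows up mid-stream; the comments show the value each call returns:

    unJSON('{"answer": "Jupiter"}');  // { answer: 'Jupiter' } -- parses as-is
    unJSON('{"answer": "Jupiter"');   // { answer: 'Jupiter' } -- repaired by appending '}'
    unJSON('{"answer": "Jupi');       // { answer: 'Jupi' }    -- repaired by appending '"}'
    unJSON('{"answer');               // {}                    -- unrepairable, falls back to an empty object
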
@@ -58,7 +82,7 @@
  * @returns {Promise} The completion generated by the LLM.
  */
 
-const chat = async (messages, handler) => {
+const chat = async (messages, schema, handler) => {
     const gemini = LLM_API_BASE_URL.indexOf('generativelanguage.google') > 0;
     const stream = LLM_STREAMING && typeof handler === 'function';
     const model = LLM_CHAT_MODEL || 'gpt-4o-mini';
@@ -69,16 +93,34 @@
     const max_tokens = 200;
     const temperature = 0;
 
+    const response_format = schema ? {
+        type: 'json_schema',
+        json_schema: {
+            schema,
+            name: 'response',
+            strict: true
+        }
+    } : undefined;
+
+    const geminify = schema => ({ ...schema, additionalProperties: undefined });
+    const response_schema = response_format ? geminify(schema) : undefined;
+    const response_mime_type = response_schema ? 'application/json' : 'text/plain';
+
     const bundles = messages.map(({ role, content }) => {
         return { role, parts: [{ text: content }] };
     });
     const contents = bundles.filter(({ role }) => role === 'user');
     const system_instruction = bundles.filter(({ role }) => role === 'system').shift();
-    const generationConfig = { temperature, maxOutputTokens: max_tokens, responseMimeType: 'text/plain' };
+    const generationConfig = { temperature, response_mime_type, response_schema, maxOutputTokens: max_tokens };
     const body = gemini ?
         { system_instruction, contents, generationConfig } :
-        { messages, model, stop, max_tokens, temperature, stream }
+        { messages, response_format, model, stop, max_tokens, temperature, stream }
+
+    LLM_DEBUG_CHAT &&
+        messages.forEach(({ role, content }) => {
+            console.log(`${MAGENTA}${role}:${NORMAL} ${content}`);
+        });
 
     const response = await fetch(url, {
         method: 'POST',
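
Note: for one schema, this hunk yields two request shapes (a sketch, values illustrative). OpenAI-compatible endpoints receive the schema wrapped in response_format, while Gemini carries it inside generationConfig with additionalProperties stripped by geminify, presumably because Gemini's structured-output endpoint rejects that keyword:

    // OpenAI-compatible body (sketch):
    // { messages, model, stop, max_tokens, temperature, stream,
    //   response_format: { type: 'json_schema',
    //       json_schema: { schema, name: 'response', strict: true } } }

    // Gemini body (sketch), with the schema moved into generationConfig:
    // { system_instruction, contents,
    //   generationConfig: { temperature, response_mime_type: 'application/json',
    //       response_schema, maxOutputTokens: max_tokens } }
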
@@ -101,15 +143,19 @@
         return '';
     }
 
-    LLM_DEBUG_CHAT &&
-        messages.forEach(({ role, content }) => {
-            console.log(`${MAGENTA}${role}:${NORMAL} ${content}`);
-        });
-
     if (!stream) {
         const data = await response.json();
         const answer = extract(data).trim();
-        LLM_DEBUG_CHAT && console.log(`${YELLOW}${answer}${NORMAL}`);
+        if (LLM_DEBUG_CHAT) {
+            if (LLM_JSON_SCHEMA) {
+                const parsed = unJSON(answer);
+                const empty = Object.keys(parsed).length === 0;
+                const formatted = empty ? answer : JSON.stringify(parsed, null, 2);
+                console.log(`${YELLOW}${formatted}${NORMAL}`);
+            } else {
+                console.log(`${YELLOW}${answer}${NORMAL}`);
+            }
+        }
         (answer.length > 0) && handler && handler(answer);
         return answer;
     }
@@ -196,6 +242,7 @@ const reply = async (context) => {
     const { enter, leave, stream } = delegates;
     enter && enter('Reply');
 
+    const schema = null;
     const messages = [];
     messages.push({ role: 'system', content: REPLY_PROMPT });
     const relevant = history.slice(-5);
@@ -206,13 +253,13 @@
     });
     messages.push({ role: 'user', content: inquiry });
 
-    const answer = await chat(messages, stream);
+    const answer = await chat(messages, schema, stream);
     leave && leave('Reply', { inquiry, answer });
 
     return { answer, ...context };
 }
 
-const PREDEFINED_KEYS = ['INQUIRY', 'TOOL', 'THOUGHT', 'KEYPHRASES', 'OBSERVATION', 'TOPIC'];
+const PREDEFINED_KEYS = ['inquiry', 'tool', 'thought', 'keyphrases', 'observation', 'answer', 'topic'];
 
 /**
  * Break downs a multi-line text based on a number of predefined keys.
@@ -227,7 +274,7 @@
     const anchor = markers.slice().pop();
     const start = text.lastIndexOf(anchor + ':');
     if (start >= 0) {
-        parts[anchor.toLowerCase()] = text.substring(start).replace(anchor + ':', '').trim();
+        parts[anchor] = text.substring(start).replace(anchor + ':', '').trim();
         let str = text.substring(0, start);
         for (let i = 0; i < keys.length; ++i) {
             const marker = keys[i];
@@ -236,8 +283,7 @@
                 const substr = str.substr(pos + marker.length + 1).trim();
                 const value = substr.split('\n').shift();
                 str = str.slice(0, pos);
-                const key = marker.toLowerCase();
-                parts[key] = value;
+                parts[marker] = value;
             }
         }
     }
@@ -251,15 +297,60 @@
  * @return {text}
  */
 const construct = (kv) => {
-    return PREDEFINED_KEYS.filter(key => kv[key.toLowerCase()]).map(key => {
-        const value = kv[key.toLowerCase()];
+    if (LLM_JSON_SCHEMA) {
+        return JSON.stringify(kv, null, 2);
+    }
+    return PREDEFINED_KEYS.filter(key => kv[key]).map(key => {
+        const value = kv[key];
         if (value && value.length > 0) {
-            return `${key.toUpperCase()}: ${value}`;
+            return `${key}: ${value}`;
         }
         return null;
     }).join('\n');
 }
+
+/**
+ * Breaks down the completion into a dictionary containing the thought process,
+ * important keyphrases, observation, and topic.
+ *
+ * @param {string} hint - The hint or example given to the LLM.
+ * @param {string} completion - The completion generated by the LLM.
+ * @returns {Object} Breakdown of the thought process into a dictionary.
+ */
+const breakdown = (hint, completion) => {
+    const text = hint + completion;
+    if (text.startsWith('{') && text.endsWith('}')) {
+        try {
+            return unJSON(text);
+        } catch (error) {
+            LLM_DEBUG_CHAT && console.error(`Failed to parse JSON: ${text.replaceAll('\n', '')}`);
+        }
+    }
+    let result = deconstruct(text);
+    const { topic } = result;
+    if (!topic || topic.length === 0) {
+        result = deconstruct(text + '\n' + 'topic: general knowledge.');
+    }
+    return result;
+}
+
+/**
+ * Returns a formatted string based on the given object.
+ *
+ * @param {string} [prefix] - An optional prefix string
+ * @param {Object} object - The object to format
+ * @returns {string} The formatted string
+ */
+const structure = (prefix, object) => {
+    if (LLM_JSON_SCHEMA) {
+        const format = prefix ? prefix + ' (JSON with this schema)' : '';
+        return format + '\n' + JSON.stringify(object, null, 2) + '\n';
+    }
+
+    return (prefix || '') + '\n\n' + construct(object) + '\n';
+}
+
 /**
  * Represents the record of an atomic processing.
  *
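
Note: a sketch of the two renderings construct() now produces for the same dictionary, depending on the flag:

    const kv = { tool: 'Google', topic: 'geography' };
    construct(kv);
    // with LLM_JSON_SCHEMA set: '{\n  "tool": "Google",\n  "topic": "geography"\n}'
    // without it:               'tool: Google\ntopic: geography'
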
@@ -290,46 +381,77 @@
  * @returns {Context} Updated pipeline context.
  */
 
-const REASON_PROMPT = `Use Google to search for the answer.
-Think step by step. Always output your thought in the following format:
-
-TOOL: the search engine to use (must be Google).
-THOUGHT: describe your thoughts about the inquiry.
-KEYPHRASES: the important key phrases to search for.
-OBSERVATION: the concise result of the search tool.
-TOPIC: the specific topic covering the inquiry.`;
-
-const REASON_EXAMPLE = `
-
-# Example
-
-Given an inquiry "What is Pitch Lake in Trinidad famous for?", you will output:
-
-TOOL: Google.
-THOUGHT: This is about geography, I will use Google search.
-KEYPHRASES: Pitch Lake in Trinidad fame.
-OBSERVATION: Pitch Lake in Trinidad is the largest natural deposit of asphalt.
-TOPIC: geography.`;
-
-const breakdown = (hint, completion) => {
-    const text = hint + completion;
-    let result = deconstruct(text);
-    const { topic } = result;
-    if (!topic || topic.length === 0) {
-        result = deconstruct(text + '\n' + 'TOPIC: general knowledge.');
-    }
-    return result;
-}
+const REASON_PROMPT = `Use Google to search for the answer. Think step by step.
+Always output your thought in the following format`;
+
+const REASON_GUIDELINE = {
+    tool: 'the search engine to use (must be Google)',
+    thought: 'describe your thoughts about the inquiry',
+    keyphrases: 'the important key phrases to search for',
+    observation: 'the concise result of the search tool',
+    topic: 'the specific topic covering the inquiry'
+};
+
+const REASON_EXAMPLE_INQUIRY = `
+Example:
+
+Given an inquiry "What is Pitch Lake in Trinidad famous for?", you will output:`;
+
+const REASON_EXAMPLE_OUTPUT = {
+    tool: 'Google',
+    thought: 'This is about geography, I will use Google search',
+    keyphrases: 'Pitch Lake in Trinidad fame',
+    observation: 'Pitch Lake in Trinidad is the largest natural deposit of asphalt',
+    topic: 'geography'
+};
+
+const REASON_SCHEMA = {
+    type: 'object',
+    additionalProperties: false,
+    properties: {
+        tool: {
+            type: 'string'
+        },
+        thought: {
+            type: 'string'
+        },
+        keyphrases: {
+            type: 'string'
+        },
+        observation: {
+            type: 'string'
+        },
+        topic: {
+            type: 'string'
+        }
+    },
+    required: [
+        'tool',
+        'thought',
+        'keyphrases',
+        'observation',
+        'topic'
+    ]
+};
 
+/**
+ * Performs a basic step-by-step reasoning, in the style of Chain of Thought.
+ * The updated context will contain new information such as `keyphrases` and `observation`.
+ * If the generated keyphrases are empty, the pipeline retries the reasoning.
+ *
+ * @param {Context} context - Current pipeline context.
+ * @returns {Context} Updated pipeline context.
+ */
 const reason = async (context) => {
     const { history, delegates } = context;
     const { enter, leave } = delegates;
     enter && enter('Reason');
 
+    const schema = LLM_JSON_SCHEMA ? REASON_SCHEMA : null;
+    let prompt = structure(REASON_PROMPT, REASON_GUIDELINE);
     const relevant = history.slice(-3);
-    let prompt = REASON_PROMPT;
     if (relevant.length === 0) {
-        prompt += REASON_EXAMPLE;
+        prompt += structure(REASON_EXAMPLE_INQUIRY, REASON_EXAMPLE_OUTPUT);
     }
 
     const messages = [];
@@ -338,22 +460,23 @@
         const { inquiry, topic, thought, keyphrases, answer } = msg;
         const observation = answer;
         messages.push({ role: 'user', content: inquiry });
-        const assistant = construct({ tool: 'Google.', thought, keyphrases, observation, topic });
+        const assistant = construct({ tool: 'Google', thought, keyphrases, observation, topic });
         messages.push({ role: 'assistant', content: assistant });
     });
 
     const { inquiry } = context;
+
     messages.push({ role: 'user', content: inquiry });
-    const hint = ['TOOL: Google.', 'THOUGHT: '].join('\n');
-    messages.push({ role: 'assistant', content: hint });
-    const completion = await chat(messages);
+    const hint = schema ? '' : ['tool: Google', 'thought: '].join('\n');
+    (!schema) && messages.push({ role: 'assistant', content: hint });
+    const completion = await chat(messages, schema);
     let result = breakdown(hint, completion);
-    if (!result.keyphrases || result.keyphrases.length === 0) {
+    if (!schema && (!result.keyphrases || result.keyphrases.length === 0)) {
         LLM_DEBUG_CHAT && console.log(`-->${RED}Invalid keyphrases. Trying again...`);
-        const hint = ['TOOL: Google.', 'THOUGHT: ' + result.thought, 'KEYPHRASES: '].join('\n');
+        const hint = ['tool: Google', 'thought: ' + result.thought, 'keyphrases: '].join('\n');
         messages.pop();
         messages.push({ role: 'assistant', content: hint });
-        const completion = await chat(messages);
+        const completion = await chat(messages, schema);
         result = breakdown(hint, completion);
     }
     const { topic, thought, keyphrases, observation } = result;
@@ -379,12 +502,31 @@
 Do not make any apology or other commentary.
 Do not use other sources of information, including your memory.
 Do not make up new names or come up with new facts.`;
 
+const RESPOND_GUIDELINE = `
+Always answer in JSON with the following format:
+
+{
+    "answer": // accurate and polite answer
+}`;
+
+const RESPOND_SCHEMA = {
+    type: 'object',
+    additionalProperties: false,
+    properties: {
+        answer: {
+            type: 'string'
+        }
+    },
+    required: ['answer']
+};
+
 const respond = async (context) => {
     const { history, delegates } = context;
     const { enter, leave, stream } = delegates;
     enter && enter('Respond');
 
-    let prompt = RESPOND_PROMPT;
+    const schema = LLM_JSON_SCHEMA ? RESPOND_SCHEMA : null;
+    let prompt = schema ?
+        RESPOND_PROMPT + RESPOND_GUIDELINE : RESPOND_PROMPT;
     const relevant = history.slice(-2);
     if (relevant.length > 0) {
         prompt += '\n';
@@ -398,10 +540,11 @@ const respond = async (context) => {
 
     const messages = [];
     messages.push({ role: 'system', content: prompt });
-    const { inquiry, thought, observation } = context;
+    const { inquiry, observation } = context;
     messages.push({ role: 'user', content: construct({ inquiry, observation }) });
-    messages.push({ role: 'assistant', content: 'ANSWER: ' });
-    const answer = await chat(messages, stream);
+    (!schema) && messages.push({ role: 'assistant', content: 'Answer: ' });
+    const completion = await chat(messages, schema, stream);
+    const answer = schema ? breakdown('', completion).answer : completion;
 
     leave && leave('Respond', { inquiry, observation, answer });
     return { answer, ...context };
@@ -658,7 +801,6 @@
 
 const interact = async () => {
     const history = [];
-    const stream = (text) => process.stdout.write(text);
 
     let loop = true;
     const io = readline.createInterface({ input: process.stdin, output: process.stdout });
@@ -678,6 +820,21 @@
             }
 
         } else {
+            let input = '';
+            let output = '';
+            const stream = (text) => {
+                if (LLM_JSON_SCHEMA) {
+                    input += text;
+                    const { answer } = unJSON(input);
+                    if (answer && answer.length > 0) {
+                        process.stdout.write(answer.substring(output.length));
+                        output = answer;
+                    }
+                } else {
+                    process.stdout.write(text);
+                }
+            }
+
             const stages = [];
             const update = (stage, fields) => {
                 if (stage === 'Reason') {
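
Note: a sketch of the new interactive streaming path (assuming LLM_JSON_SCHEMA is set): each raw delta is buffered and re-parsed with unJSON, so only the growing answer field ever reaches the terminal:

    ['{"an', 'swer": "Jupi', 'ter"}'].forEach(stream);
    // chunk 1: buffer '{"an'             -> unrepairable, prints nothing
    // chunk 2: buffer '{"answer": "Jupi' -> repaired, prints 'Jupi'
    // chunk 3: buffer parses in full     -> prints the remaining 'ter'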