Commit 75dbf9d

Handle the forced usage of tool calls outside of the recursive prompt method.

Signed-off-by: Adam Treat <[email protected]>
manyoso committed Aug 14, 2024
1 parent f118720 commit 75dbf9d
Showing 2 changed files with 52 additions and 33 deletions.
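
For orientation before the diff: the validation added in chatllm.cpp accepts only a JSON object with a "name" and a "parameters" member, and for now only the web_search tool with a "query" parameter; anything else falls through to handleFailedToolCall(). A minimal standalone sketch of that check, assuming Qt's JSON classes (the helper name is hypothetical and not part of the commit):

#include <QJsonDocument>
#include <QJsonObject>
#include <QString>

// Hypothetical helper mirroring the validation in the diff below. A tool call
// such as {"name": "web_search", "parameters": {"query": "..."}} passes; a
// null/invalid document or a missing member is rejected.
static bool isSupportedToolCall(const QString &toolCall)
{
    QJsonParseError err;
    const QJsonDocument doc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
    if (doc.isNull() || err.error != QJsonParseError::NoError || !doc.isObject())
        return false; // null or invalid json
    const QJsonObject root = doc.object();
    if (!root.contains("name") || !root.contains("parameters"))
        return false; // missing required name/parameters objects
    // Only one tool is wired up so far: the brave-backed web search.
    const QJsonObject args = root["parameters"].toObject();
    return root["name"].toString() == "web_search" && args.contains("query");
}
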
79 changes: 47 additions & 32 deletions gpt4all-chat/chatllm.cpp
@@ -759,16 +759,13 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
}

bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens, bool isToolCallResponse)
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
if (!isModelLoaded())
return false;

// FIXME: This should be made agnostic to localdocs and rely upon the force usage mode
// and also we have to honor the ask before running mode.
// FIXME: The only localdocs specific thing here should be the injection of the parameters
// FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here.
QList<SourceExcerpt> localDocsExcerpts;
if (!collectionList.isEmpty() && !isToolCallResponse) {
if (!collectionList.isEmpty()) {
LocalDocsSearch localdocs;
QJsonObject parameters;
parameters.insert("text", prompt);
@@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
docsContext = u"### Context:\n%1\n\n"_s.arg(json);
}

qint64 totalTime = 0;
bool producedSourceExcerpts = false;
bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p,
min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts);

SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts)))
generateQuestions(totalTime);
else
emit responseStopped(totalTime);

return success;
}

bool ChatLLM::promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp,
int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall)
{
if (!isModelLoaded())
return false;

int n_threads = MySettings::globalInstance()->threadCount();

m_stopGenerating = false;
@@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
printf("%s", qPrintable(prompt));
fflush(stdout);
#endif
QElapsedTimer totalTime;
totalTime.start();

QElapsedTimer elapsedTimer;
elapsedTimer.start();
m_timer->start();
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,

// The list of possible additional contexts that come from previous usage of tool calls
for (const QString &context : toolContexts) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response
m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}

// We can't handle recursive tool calls right now otherwise we always try to check if we have a
// tool call
m_checkToolCall = !isToolCallResponse;
// We can't handle recursive tool calls right now due to the possibility of the model causing
// infinite recursion through repeated tool calls
m_checkToolCall = !isRecursiveCall;

m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
@@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
fflush(stdout);
#endif
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
totalTime = elapsedTimer.elapsed();
std::string trimmed = trim_whitespace(m_response);

// If we found a tool call, then deal with it
@@ -852,21 +873,21 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
if (toolTemplate.isEmpty()) {
qWarning() << "ERROR: No valid tool template for this model" << toolCall;
return handleFailedToolCall(trimmed, elapsed);
return handleFailedToolCall(trimmed, totalTime);
}

QJsonParseError err;
const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);

if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
return handleFailedToolCall(trimmed, elapsed);
return handleFailedToolCall(trimmed, totalTime);
}

QJsonObject rootObject = toolCallDoc.object();
if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
return handleFailedToolCall(trimmed, elapsed);
return handleFailedToolCall(trimmed, totalTime);
}

const QString tool = toolCallDoc["name"].toString();
@@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
if (tool != "web_search" || !args.contains("query")) {
// FIXME: Need to surface errors to the UI
qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall;
return handleFailedToolCall(trimmed, elapsed);
return handleFailedToolCall(trimmed, totalTime);
}

const QString query = args["query"].toString();
@@ -900,30 +921,24 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
} else if (!sourceExcerpts.isEmpty()) {
producedSourceExcerpts = true;
emit sourceExcerptsChanged(sourceExcerpts);
}

m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = std::string();

// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
// This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive
// tool calls
return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
true /*isToolCallResponse*/);

return promptRecursive(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime,
producedSourceExcerpts, true /*isRecursiveCall*/);
} else {
if (trimmed != m_response) {
m_response = trimmed;
emit responseChanged(QString::fromStdString(m_response));
}

SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
generateQuestions(elapsed);
else
emit responseStopped(elapsed);
m_pristineLoadedState = false;
return true;
}
6 changes: 5 additions & 1 deletion gpt4all-chat/chatllm.h
@@ -196,9 +196,10 @@ public Q_SLOTS:
void modelInfoChanged(const ModelInfo &modelInfo);

protected:
// FIXME: This is only available because of server which sucks
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
int32_t repeat_penalty_tokens);
bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
@@ -219,6 +220,9 @@ public Q_SLOTS:
quint32 m_promptResponseTokens;

private:
bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false);
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);

std::string m_response;
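
Taken together, the guard added above bounds tool-call recursion at depth one: the first pass through promptRecursive() sets m_checkToolCall and may dispatch a single tool call, while the recursive pass runs with isRecursiveCall == true and must answer directly. A toy illustration of that pattern (hypothetical types and predicates, not code from this repository):

#include <QString>

// Toy model of the recursion guard: the recursive pass disables tool-call
// checking, so the model cannot chain tool calls indefinitely.
struct ToyLLM {
    bool checkToolCall = false;

    bool promptRecursive(const QString &prompt, bool isRecursiveCall = false) {
        checkToolCall = !isRecursiveCall;       // mirrors m_checkToolCall
        if (checkToolCall && wantsToolCall(prompt))
            return promptRecursive(runTool(prompt), /*isRecursiveCall*/ true);
        return true;                            // plain response, no tool call
    }

    // Hypothetical stand-ins for the model requesting a tool and for running
    // it (a brave web search in the real commit).
    bool wantsToolCall(const QString &p) const { return p.startsWith("search:"); }
    QString runTool(const QString &) const { return QStringLiteral("tool results"); }
};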
