diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 3500c858669e..7450ab1fc402 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -759,16 +759,13 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
 }
 
 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
-    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-    int32_t repeat_penalty_tokens, bool isToolCallResponse)
+    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
+    int32_t repeat_penalty_tokens)
 {
-    if (!isModelLoaded())
-        return false;
-
-    // FIXME: This should be made agnostic to localdocs and rely upon the force usage mode
-    // and also we have to honor the ask before running mode.
+    // FIXME: The only localdocs specific thing here should be the injection of the parameters
+    // FIXME: Get the list of tools ... if force usage is set, then we *try* to force usage here.
     QList<SourceExcerpt> localDocsExcerpts;
-    if (!collectionList.isEmpty() && !isToolCallResponse) {
+    if (!collectionList.isEmpty()) {
         LocalDocsSearch localdocs;
         QJsonObject parameters;
         parameters.insert("text", prompt);
@@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         docsContext = u"### Context:\n%1\n\n"_s.arg(json);
     }
 
+    qint64 totalTime = 0;
+    bool producedSourceExcerpts = false;
+    bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p,
+        min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts);
+
+    SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
+    if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts)))
+        generateQuestions(totalTime);
+    else
+        emit responseStopped(totalTime);
+
+    return success;
+}
+
+bool ChatLLM::promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
+    const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp,
+    int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall)
+{
+    if (!isModelLoaded())
+        return false;
+
     int n_threads = MySettings::globalInstance()->threadCount();
 
     m_stopGenerating = false;
@@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     printf("%s", qPrintable(prompt));
     fflush(stdout);
 #endif
-    QElapsedTimer totalTime;
-    totalTime.start();
+
+    QElapsedTimer elapsedTimer;
+    elapsedTimer.start();
     m_timer->start();
-    if (!docsContext.isEmpty()) {
-        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
-        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
+
+    // The list of possible additional contexts that come from previous usage of tool calls
+    for (const QString &context : toolContexts) {
+        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response
+        m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc,
                                     /*allowContextShift*/ true, m_ctx);
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
-    // We can't handle recursive tool calls right now otherwise we always try to check if we have a
-    // tool call
-    m_checkToolCall = !isToolCallResponse;
+    // We can't handle recursive tool calls right now due to the possibility of the model causing
+    // infinite recursion through repeated tool calls
+    m_checkToolCall = !isRecursiveCall;
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
         /*allowContextShift*/ true, m_ctx);
@@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     fflush(stdout);
 #endif
     m_timer->stop();
-    qint64 elapsed = totalTime.elapsed();
+    totalTime = elapsedTimer.elapsed();
     std::string trimmed = trim_whitespace(m_response);
 
     // If we found a tool call, then deal with it
@@ -852,7 +873,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
         if (toolTemplate.isEmpty()) {
             qWarning() << "ERROR: No valid tool template for this model" << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         QJsonParseError err;
@@ -860,13 +881,13 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
 
         if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
             qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         QJsonObject rootObject = toolCallDoc.object();
         if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
             qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         const QString tool = toolCallDoc["name"].toString();
@@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (tool != "web_search" || !args.contains("query")) {
             // FIXME: Need to surface errors to the UI
             qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         const QString query = args["query"].toString();
@@ -900,6 +921,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (!parseError.isEmpty()) {
             qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
         } else if (!sourceExcerpts.isEmpty()) {
+            producedSourceExcerpts = true;
            emit sourceExcerptsChanged(sourceExcerpts);
         }
 
@@ -907,23 +929,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_promptTokens = 0;
         m_response = std::string();
 
-        // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
+        // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive
         // tool calls
-        return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
-            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
-            true /*isToolCallResponse*/);
-
+        return promptRecursive(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime,
+            producedSourceExcerpts, true /*isRecursiveCall*/);
     } else {
         if (trimmed != m_response) {
             m_response = trimmed;
             emit responseChanged(QString::fromStdString(m_response));
         }
-
-        SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
-        if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
-            generateQuestions(elapsed);
-        else
-            emit responseStopped(elapsed);
         m_pristineLoadedState = false;
         return true;
     }
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index feacd744f228..7622c3661446 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -196,9 +196,10 @@ public Q_SLOTS:
     void modelInfoChanged(const ModelInfo &modelInfo);
 
 protected:
+    // FIXME: This is only available because of server which sucks
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-        int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
+        int32_t repeat_penalty_tokens);
     bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
@@ -219,6 +220,9 @@ public Q_SLOTS:
     quint32 m_promptResponseTokens;
 
 private:
+    bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt, const QString &promptTemplate,
+        int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false);
    bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
     std::string m_response;
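
For reviewers unfamiliar with the payload being validated: the hunks above accept a tool call only when the model emits a JSON object with `name` and `parameters` fields, and for now only the `web_search` tool with a `query` parameter. The sketch below is a minimal, self-contained rendering of that validation logic under those assumptions, not the code as it appears in `chatllm.cpp`; the helper name `isValidToolCall` and the `main()` driver are illustrative only. It should build against QtCore alone.

```cpp
// Minimal sketch of the tool-call validation performed in promptRecursive.
// The expected payload shape ({"name": ..., "parameters": ...}) is taken
// from the checks in the diff; everything else is illustrative.
#include <QDebug>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonParseError>
#include <QString>

// Returns true when `toolCall` is a JSON object carrying the fields the
// patch requires before dispatching to a tool.
static bool isValidToolCall(const QString &toolCall)
{
    QJsonParseError err;
    const QJsonDocument doc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
    if (err.error != QJsonParseError::NoError || !doc.isObject()) {
        qWarning() << "tool call is null or invalid json:" << toolCall;
        return false;
    }

    const QJsonObject root = doc.object();
    if (!root.contains("name") || !root.contains("parameters")) {
        qWarning() << "tool call lacks name/parameters:" << toolCall;
        return false;
    }

    // The patch currently special-cases web_search and requires a "query".
    const QString tool = root["name"].toString();
    const QJsonObject args = root["parameters"].toObject();
    if (tool != "web_search" || !args.contains("query")) {
        qWarning() << "unknown tool or missing query:" << toolCall;
        return false;
    }
    return true;
}

int main()
{
    // A payload of the shape the validation above accepts, and one it rejects.
    const QString ok  = R"({"name": "web_search", "parameters": {"query": "gpt4all"}})";
    const QString bad = R"({"name": "web_search"})";
    qInfo() << isValidToolCall(ok);   // true
    qInfo() << isValidToolCall(bad);  // false
    return 0;
}
```

The same guard structure explains the `isRecursiveCall` parameter: the recursive `promptRecursive` call made after a successful `web_search` passes `true`, which clears `m_checkToolCall` and so prevents the model from chaining tool calls indefinitely.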