Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't block the gui thread for tool calls #3435

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gpt4all-chat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix `codesign --verify` failure on macOS ([#3413](https://github.com/nomic-ai/gpt4all/pull/3413))
- Code Interpreter: Fix console.log not accepting a single string after v3.7.0 ([#3426](https://github.com/nomic-ai/gpt4all/pull/3426))
- Fix Phi 3.1 Mini 128K Instruct template (by [@ThiloteE](https://github.com/ThiloteE) in [#3412](https://github.com/nomic-ai/gpt4all/pull/3412))
- Don't block the gui thread for tool calls ([#3435](https://github.com/nomic-ai/gpt4all/pull/3435))

## [3.7.0] - 2025-01-21

Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/qml/ChatItemView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ GridLayout {
case Chat.PromptProcessing: return qsTr("processing ...")
case Chat.ResponseGeneration: return qsTr("generating response ...");
case Chat.GeneratingQuestions: return qsTr("generating questions ...");
case Chat.ToolCallGeneration: return qsTr("generating toolcall ...");
default: return ""; // handle unexpected values
}
}
Expand Down
102 changes: 61 additions & 41 deletions gpt4all-chat/src/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ QVariant Chat::popPrompt(int index)

void Chat::stopGenerating()
{
// In future if we have more than one tool we'll have to keep track of which tools are possibly
// running, but for now we only have one
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
toolInstance->interrupt();
m_llmodel->stopGenerating();
}

Expand Down Expand Up @@ -242,56 +247,71 @@ void Chat::responseStopped(qint64 promptResponseMs)

const QString possibleToolcall = m_chatModel->possibleToolcall();

Network::globalInstance()->trackChatEvent("response_stopped", {
{"first", m_firstResponse},
{"message_count", chatModel()->count()},
{"$duration", promptResponseMs / 1000.},
});

ToolCallParser parser;
parser.update(possibleToolcall);
if (parser.state() == ToolEnums::ParseState::Complete)
processToolCall(parser.toolCall());
else
responseComplete();
}

if (parser.state() == ToolEnums::ParseState::Complete) {
const QString toolCall = parser.toolCall();

// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();

// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);

// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/);
const ToolEnums::Error error = toolInstance->error();
const QString errorString = toolInstance->errorString();

// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall({
ToolCallConstants::CodeInterpreterFunction,
{ param },
result,
error,
errorString
});

++m_consecutiveToolCalls;

// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}
void Chat::processToolCall(const QString &toolCall)
{
m_responseState = Chat::ToolCallGeneration;
emit responseStateChanged();
// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();

// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
connect(toolInstance, &Tool::runComplete, this, &Chat::toolCallComplete, Qt::SingleShotConnection);

// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
m_responseInProgress = true;
emit responseInProgressChanged();
toolInstance->run({param});
}

void Chat::toolCallComplete(const ToolCallInfo &info)
{
// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall(info);

++m_consecutiveToolCalls;

m_responseInProgress = false;
emit responseInProgressChanged();

// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || info.error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}

responseComplete();
}

void Chat::responseComplete()
{
if (m_generatedName.isEmpty())
emit generateNameRequested();

m_responseState = Chat::ResponseStopped;
emit responseStateChanged();

m_consecutiveToolCalls = 0;
Network::globalInstance()->trackChatEvent("response_complete", {
{"first", m_firstResponse},
{"message_count", chatModel()->count()},
{"$duration", promptResponseMs / 1000.},
});
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
m_firstResponse = false;
}

Expand Down
6 changes: 5 additions & 1 deletion gpt4all-chat/src/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class Chat : public QObject
LocalDocsProcessing,
PromptProcessing,
GeneratingQuestions,
ResponseGeneration
ResponseGeneration,
ToolCallGeneration
};
Q_ENUM(ResponseState)

Expand Down Expand Up @@ -166,6 +167,9 @@ private Q_SLOTS:
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void processToolCall(const QString &toolCall);
void toolCallComplete(const ToolCallInfo &info);
void responseComplete();
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
void handleModelLoadingError(const QString &error);
Expand Down
94 changes: 58 additions & 36 deletions gpt4all-chat/src/codeinterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@

using namespace Qt::Literals::StringLiterals;

// Creates the tool together with its worker object and routes the tool's request signal to the
// worker's request slot.
//
// NOTE(review): the worker is allocated with new and nothing visible in this class deletes it
// (the destructor in the header is empty) -- confirm that the tool is meant to live for the
// whole lifetime of the application.
CodeInterpreter::CodeInterpreter()
    : Tool()
    , m_error(ToolEnums::Error::NoError)
    , m_worker(new CodeInterpreterWorker)
{
    // Queued connection: the slot is invoked from the event loop of the thread the worker lives
    // in instead of running directly inside the emit
    connect(this, &CodeInterpreter::request, m_worker, &CodeInterpreterWorker::request, Qt::QueuedConnection);
}

QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
void CodeInterpreter::run(const QList<ToolParam> &params)
{
m_error = ToolEnums::Error::NoError;
m_errorString = QString();
Expand All @@ -18,27 +25,24 @@ QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
&& params.first().type == ToolEnums::ParamType::String);

const QString code = params.first().value.toString();

QThread workerThread;
CodeInterpreterWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(&workerThread, &QThread::started, [&worker, code]() {
worker.request(code);
connect(m_worker, &CodeInterpreterWorker::finished, [this, params] {
m_error = m_worker->error();
m_errorString = m_worker->errorString();
emit runComplete({
ToolCallConstants::CodeInterpreterFunction,
params,
m_worker->response(),
m_error,
m_errorString
});
});
workerThread.start();
bool timedOut = !workerThread.wait(timeout);
if (timedOut) {
worker.interrupt(timeout); // thread safe
m_error = ToolEnums::Error::TimeoutError;
}
workerThread.quit();
workerThread.wait();
if (!timedOut) {
m_error = worker.error();
m_errorString = worker.errorString();
}
return worker.response();

emit request(code);
}

// Forwards the interrupt request to the worker and returns whatever the worker reports.
bool CodeInterpreter::interrupt()
{
    const bool interrupted = m_worker->interrupt();
    return interrupted;
}

QList<ToolParamInfo> CodeInterpreter::parameters() const
Expand Down Expand Up @@ -89,17 +93,15 @@ QString CodeInterpreter::exampleReply() const

CodeInterpreterWorker::CodeInterpreterWorker()
: QObject(nullptr)
, m_engine(new QJSEngine(this))
{
}
moveToThread(&m_thread);

void CodeInterpreterWorker::request(const QString &code)
{
JavaScriptConsoleCapture consoleCapture;
QJSValue consoleInternalObject = m_engine.newQObject(&consoleCapture);
m_engine.globalObject().setProperty("console_internal", consoleInternalObject);
QJSValue consoleInternalObject = m_engine->newQObject(&m_consoleCapture);
m_engine->globalObject().setProperty("console_internal", consoleInternalObject);

// preprocess console.log args in JS since Q_INVOKE doesn't support varargs
auto consoleObject = m_engine.evaluate(uR"(
auto consoleObject = m_engine->evaluate(uR"(
class Console {
log(...args) {
if (args.length == 0)
Expand All @@ -116,15 +118,28 @@ void CodeInterpreterWorker::request(const QString &code)

new Console();
)"_s);
m_engine.globalObject().setProperty("console", consoleObject);
m_engine->globalObject().setProperty("console", consoleObject);
m_thread.start();
}

const QJSValue result = m_engine.evaluate(code);
// Returns the worker to a clean state: drops the previous response, error information and
// captured console output, and clears the engine's interrupted flag.
void CodeInterpreterWorker::reset()
{
    // Results left over from the previous evaluation
    m_consoleCapture.output.clear();
    m_response.clear();

    // Error state
    m_errorString.clear();
    m_error = ToolEnums::Error::NoError;

    // Clear the flag set by interrupt() so the engine is no longer marked as interrupted
    m_engine->setInterrupted(false);
}

void CodeInterpreterWorker::request(const QString &code)
{
reset();
const QJSValue result = m_engine->evaluate(code);
QString resultString;

if (m_engine.isInterrupted()) {
resultString = QString("Error: code execution was timed out as it exceeded %1 ms. Code must be written to ensure execution does not timeout.").arg(m_timeout);
} else if (result.isError()) {
if (m_engine->isInterrupted()) {
resultString = QString("Error: code execution was interrupted or timed out.");
} else if (result.isError()) {
// NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since
// we *want* the model to see the response has an error so it can hopefully correct itself. The
// error member variables are intended for tools that have error conditions that cannot be corrected.
Expand All @@ -145,9 +160,16 @@ void CodeInterpreterWorker::request(const QString &code)
}

if (resultString.isEmpty())
resultString = consoleCapture.output;
else if (!consoleCapture.output.isEmpty())
resultString += "\n" + consoleCapture.output;
resultString = m_consoleCapture.output;
else if (!m_consoleCapture.output.isEmpty())
resultString += "\n" + m_consoleCapture.output;
m_response = resultString;
emit finished();
}

// Marks the worker as timed out and sets the JavaScript engine's interrupted flag.
// Always returns true.
//
// NOTE(review): CodeInterpreter::interrupt() calls this directly, while the worker object was
// moved to its own thread in the constructor. m_error is also written by reset(), and no
// synchronization is visible here -- confirm that this unsynchronized write is acceptable.
bool CodeInterpreterWorker::interrupt()
{
// The error is recorded before the engine is flagged, so it is already in place once the
// engine has been marked as interrupted
m_error = ToolEnums::Error::TimeoutError;
// request() checks isInterrupted() after evaluate() returns and reports the interruption
m_engine->setInterrupted(true);
return true;
}
24 changes: 17 additions & 7 deletions gpt4all-chat/src/codeinterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <QObject>
#include <QString>
#include <QtGlobal>
#include <QThread>

class JavaScriptConsoleCapture : public QObject
{
Expand Down Expand Up @@ -39,32 +40,37 @@ class CodeInterpreterWorker : public QObject {
CodeInterpreterWorker();
virtual ~CodeInterpreterWorker() {}

void reset();
QString response() const { return m_response; }

void request(const QString &code);
void interrupt(qint64 timeout) { m_timeout = timeout; m_engine.setInterrupted(true); }
ToolEnums::Error error() const { return m_error; }
QString errorString() const { return m_errorString; }
bool interrupt();

public Q_SLOTS:
void request(const QString &code);

Q_SIGNALS:
void finished();

private:
qint64 m_timeout = 0;
QJSEngine m_engine;
QString m_response;
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
QThread m_thread;
JavaScriptConsoleCapture m_consoleCapture;
QJSEngine *m_engine = nullptr;
};

class CodeInterpreter : public Tool
{
Q_OBJECT
public:
explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {}
explicit CodeInterpreter();
virtual ~CodeInterpreter() {}

QString run(const QList<ToolParam> &params, qint64 timeout = 2000) override;
void run(const QList<ToolParam> &params) override;
bool interrupt() override;

ToolEnums::Error error() const override { return m_error; }
QString errorString() const override { return m_errorString; }

Expand All @@ -77,9 +83,13 @@ class CodeInterpreter : public Tool
QString exampleCall() const override;
QString exampleReply() const override;

Q_SIGNALS:
void request(const QString &code);

private:
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
CodeInterpreterWorker *m_worker;
};

#endif // CODEINTERPRETER_H
6 changes: 5 additions & 1 deletion gpt4all-chat/src/tool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class Tool : public QObject
Tool() : QObject(nullptr) {}
virtual ~Tool() {}

virtual QString run(const QList<ToolParam> &params, qint64 timeout = 2000) = 0;
virtual void run(const QList<ToolParam> &params) = 0;
virtual bool interrupt() = 0;

// Tools should set these if they encounter errors. For instance, a tool depending upon the network
// might set these error variables if the network is not available.
Expand Down Expand Up @@ -122,6 +123,9 @@ class Tool : public QObject
bool operator==(const Tool &other) const { return function() == other.function(); }

jinja2::Value jinjaValue() const;

Q_SIGNALS:
void runComplete(const ToolCallInfo &info);
};

#endif // TOOL_H