Skip to content

Commit

Permalink
Don't block the gui thread for tool calls (#3435)
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Treat <[email protected]>
  • Loading branch information
manyoso authored Jan 29, 2025
1 parent adafa17 commit 22b8278
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 86 deletions.
1 change: 1 addition & 0 deletions gpt4all-chat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix `codesign --verify` failure on macOS ([#3413](https://github.com/nomic-ai/gpt4all/pull/3413))
- Code Interpreter: Fix console.log not accepting a single string after v3.7.0 ([#3426](https://github.com/nomic-ai/gpt4all/pull/3426))
- Fix Phi 3.1 Mini 128K Instruct template (by [@ThiloteE](https://github.com/ThiloteE) in [#3412](https://github.com/nomic-ai/gpt4all/pull/3412))
- Don't block the gui thread for tool calls ([#3435](https://github.com/nomic-ai/gpt4all/pull/3435))

## [3.7.0] - 2025-01-21

Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/qml/ChatItemView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ GridLayout {
case Chat.PromptProcessing: return qsTr("processing ...")
case Chat.ResponseGeneration: return qsTr("generating response ...");
case Chat.GeneratingQuestions: return qsTr("generating questions ...");
case Chat.ToolCallGeneration: return qsTr("generating toolcall ...");
default: return ""; // handle unexpected values
}
}
Expand Down
102 changes: 61 additions & 41 deletions gpt4all-chat/src/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ QVariant Chat::popPrompt(int index)

void Chat::stopGenerating()
{
// In future if we have more than one tool we'll have to keep track of which tools are possibly
// running, but for now we only have one
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
toolInstance->interrupt();
m_llmodel->stopGenerating();
}

Expand Down Expand Up @@ -242,56 +247,71 @@ void Chat::responseStopped(qint64 promptResponseMs)

const QString possibleToolcall = m_chatModel->possibleToolcall();

Network::globalInstance()->trackChatEvent("response_stopped", {
{"first", m_firstResponse},
{"message_count", chatModel()->count()},
{"$duration", promptResponseMs / 1000.},
});

ToolCallParser parser;
parser.update(possibleToolcall);
if (parser.state() == ToolEnums::ParseState::Complete)
processToolCall(parser.toolCall());
else
responseComplete();
}

if (parser.state() == ToolEnums::ParseState::Complete) {
const QString toolCall = parser.toolCall();

// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();

// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);

// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/);
const ToolEnums::Error error = toolInstance->error();
const QString errorString = toolInstance->errorString();

// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall({
ToolCallConstants::CodeInterpreterFunction,
{ param },
result,
error,
errorString
});

++m_consecutiveToolCalls;

// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}
void Chat::processToolCall(const QString &toolCall)
{
m_responseState = Chat::ToolCallGeneration;
emit responseStateChanged();
// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();

// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
connect(toolInstance, &Tool::runComplete, this, &Chat::toolCallComplete, Qt::SingleShotConnection);

// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
m_responseInProgress = true;
emit responseInProgressChanged();
toolInstance->run({param});
}

void Chat::toolCallComplete(const ToolCallInfo &info)
{
// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall(info);

++m_consecutiveToolCalls;

m_responseInProgress = false;
emit responseInProgressChanged();

// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || info.error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}

responseComplete();
}

/// Finalizes the current response once no further tool call will be processed.
///
/// Requests a generated chat name if none exists yet, moves the chat to the
/// ResponseStopped state and resets the per-response bookkeeping.
void Chat::responseComplete()
{
    if (m_generatedName.isEmpty())
        emit generateNameRequested();

    m_responseState = Chat::ResponseStopped;
    emit responseStateChanged();

    // No telemetry is sent from here: the response duration is only available as the
    // promptResponseMs parameter of responseStopped(), which is not in scope in this function.
    m_consecutiveToolCalls = 0;
    m_firstResponse = false;
}

Expand Down
6 changes: 5 additions & 1 deletion gpt4all-chat/src/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class Chat : public QObject
LocalDocsProcessing,
PromptProcessing,
GeneratingQuestions,
ResponseGeneration
ResponseGeneration,
ToolCallGeneration
};
Q_ENUM(ResponseState)

Expand Down Expand Up @@ -166,6 +167,9 @@ private Q_SLOTS:
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void processToolCall(const QString &toolCall);
void toolCallComplete(const ToolCallInfo &info);
void responseComplete();
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
void handleModelLoadingError(const QString &error);
Expand Down
94 changes: 58 additions & 36 deletions gpt4all-chat/src/codeinterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@

using namespace Qt::Literals::StringLiterals;

/// Creates the interpreter tool together with its worker and wires this
/// object's request signal to the worker's request slot.
CodeInterpreter::CodeInterpreter()
    : Tool()
    , m_error(ToolEnums::Error::NoError)
    , m_worker(new CodeInterpreterWorker)
{
    // Queued delivery: the slot is invoked via the receiver's event loop instead of directly by emit
    connect(this, &CodeInterpreter::request,
            m_worker, &CodeInterpreterWorker::request,
            Qt::QueuedConnection);
}

QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
void CodeInterpreter::run(const QList<ToolParam> &params)
{
m_error = ToolEnums::Error::NoError;
m_errorString = QString();
Expand All @@ -18,27 +25,24 @@ QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
&& params.first().type == ToolEnums::ParamType::String);

const QString code = params.first().value.toString();

QThread workerThread;
CodeInterpreterWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(&workerThread, &QThread::started, [&worker, code]() {
worker.request(code);
connect(m_worker, &CodeInterpreterWorker::finished, [this, params] {
m_error = m_worker->error();
m_errorString = m_worker->errorString();
emit runComplete({
ToolCallConstants::CodeInterpreterFunction,
params,
m_worker->response(),
m_error,
m_errorString
});
});
workerThread.start();
bool timedOut = !workerThread.wait(timeout);
if (timedOut) {
worker.interrupt(timeout); // thread safe
m_error = ToolEnums::Error::TimeoutError;
}
workerThread.quit();
workerThread.wait();
if (!timedOut) {
m_error = worker.error();
m_errorString = worker.errorString();
}
return worker.response();

emit request(code);
}

/// Forwards an interrupt request to the worker and returns the worker's answer.
bool CodeInterpreter::interrupt()
{
    const bool accepted = m_worker->interrupt();
    return accepted;
}

QList<ToolParamInfo> CodeInterpreter::parameters() const
Expand Down Expand Up @@ -89,17 +93,15 @@ QString CodeInterpreter::exampleReply() const

CodeInterpreterWorker::CodeInterpreterWorker()
: QObject(nullptr)
, m_engine(new QJSEngine(this))
{
}
moveToThread(&m_thread);

void CodeInterpreterWorker::request(const QString &code)
{
JavaScriptConsoleCapture consoleCapture;
QJSValue consoleInternalObject = m_engine.newQObject(&consoleCapture);
m_engine.globalObject().setProperty("console_internal", consoleInternalObject);
QJSValue consoleInternalObject = m_engine->newQObject(&m_consoleCapture);
m_engine->globalObject().setProperty("console_internal", consoleInternalObject);

// preprocess console.log args in JS since Q_INVOKE doesn't support varargs
auto consoleObject = m_engine.evaluate(uR"(
auto consoleObject = m_engine->evaluate(uR"(
class Console {
log(...args) {
if (args.length == 0)
Expand All @@ -116,15 +118,28 @@ void CodeInterpreterWorker::request(const QString &code)
new Console();
)"_s);
m_engine.globalObject().setProperty("console", consoleObject);
m_engine->globalObject().setProperty("console", consoleObject);
m_thread.start();
}

const QJSValue result = m_engine.evaluate(code);
void CodeInterpreterWorker::reset()
{
m_response.clear();
m_error = ToolEnums::Error::NoError;
m_errorString.clear();
m_consoleCapture.output.clear();
m_engine->setInterrupted(false);
}

void CodeInterpreterWorker::request(const QString &code)
{
reset();
const QJSValue result = m_engine->evaluate(code);
QString resultString;

if (m_engine.isInterrupted()) {
resultString = QString("Error: code execution was timed out as it exceeded %1 ms. Code must be written to ensure execution does not timeout.").arg(m_timeout);
} else if (result.isError()) {
if (m_engine->isInterrupted()) {
resultString = QString("Error: code execution was interrupted or timed out.");
} else if (result.isError()) {
// NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since
// we *want* the model to see the response has an error so it can hopefully correct itself. The
// error member variables are intended for tools that have error conditions that cannot be corrected.
Expand All @@ -145,9 +160,16 @@ void CodeInterpreterWorker::request(const QString &code)
}

if (resultString.isEmpty())
resultString = consoleCapture.output;
else if (!consoleCapture.output.isEmpty())
resultString += "\n" + consoleCapture.output;
resultString = m_consoleCapture.output;
else if (!m_consoleCapture.output.isEmpty())
resultString += "\n" + m_consoleCapture.output;
m_response = resultString;
emit finished();
}

// Requests that script evaluation be aborted: records TimeoutError as the worker's error and
// raises the engine's interrupted flag. Always returns true.
// NOTE(review): this is reached from Chat::stopGenerating() via CodeInterpreter::interrupt(),
// presumably on a different thread than the one executing request() — confirm that the
// unsynchronized write to m_error is safe.
bool CodeInterpreterWorker::interrupt()
{
m_error = ToolEnums::Error::TimeoutError;
m_engine->setInterrupted(true);
return true;
}
24 changes: 17 additions & 7 deletions gpt4all-chat/src/codeinterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <QObject>
#include <QString>
#include <QtGlobal>
#include <QThread>

class JavaScriptConsoleCapture : public QObject
{
Expand Down Expand Up @@ -39,32 +40,37 @@ class CodeInterpreterWorker : public QObject {
CodeInterpreterWorker();
virtual ~CodeInterpreterWorker() {}

void reset();
QString response() const { return m_response; }

void request(const QString &code);
void interrupt(qint64 timeout) { m_timeout = timeout; m_engine.setInterrupted(true); }
ToolEnums::Error error() const { return m_error; }
QString errorString() const { return m_errorString; }
bool interrupt();

public Q_SLOTS:
void request(const QString &code);

Q_SIGNALS:
void finished();

private:
qint64 m_timeout = 0;
QJSEngine m_engine;
QString m_response;
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
QThread m_thread;
JavaScriptConsoleCapture m_consoleCapture;
QJSEngine *m_engine = nullptr;
};

class CodeInterpreter : public Tool
{
Q_OBJECT
public:
explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {}
explicit CodeInterpreter();
virtual ~CodeInterpreter() {}

QString run(const QList<ToolParam> &params, qint64 timeout = 2000) override;
void run(const QList<ToolParam> &params) override;
bool interrupt() override;

ToolEnums::Error error() const override { return m_error; }
QString errorString() const override { return m_errorString; }

Expand All @@ -77,9 +83,13 @@ class CodeInterpreter : public Tool
QString exampleCall() const override;
QString exampleReply() const override;

Q_SIGNALS:
void request(const QString &code);

private:
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
CodeInterpreterWorker *m_worker;
};

#endif // CODEINTERPRETER_H
6 changes: 5 additions & 1 deletion gpt4all-chat/src/tool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class Tool : public QObject
Tool() : QObject(nullptr) {}
virtual ~Tool() {}

virtual QString run(const QList<ToolParam> &params, qint64 timeout = 2000) = 0;
virtual void run(const QList<ToolParam> &params) = 0;
virtual bool interrupt() = 0;

// Tools should set these if they encounter errors. For instance, a tool depending upon the network
// might set these error variables if the network is not available.
Expand Down Expand Up @@ -122,6 +123,9 @@ class Tool : public QObject
bool operator==(const Tool &other) const { return function() == other.function(); }

jinja2::Value jinjaValue() const;

Q_SIGNALS:
void runComplete(const ToolCallInfo &info);
};

#endif // TOOL_H

0 comments on commit 22b8278

Please sign in to comment.