Skip to content

Commit

Permalink
Refactor the brave search and introduce an abstraction for tool calls.
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Treat <[email protected]>
  • Loading branch information
manyoso committed Jul 31, 2024
1 parent 3f8ee0e commit 8de4954
Show file tree
Hide file tree
Showing 10 changed files with 366 additions and 140 deletions.
3 changes: 2 additions & 1 deletion gpt4all-chat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ qt_add_executable(chat
modellist.h modellist.cpp
mysettings.h mysettings.cpp
network.h network.cpp
sourceexcerpt.h
sourceexcerpt.h sourceexcerpt.cpp
server.h server.cpp
logger.h logger.cpp
tool.h tool.cpp
${APP_ICON_RESOURCE}
${CHAT_EXE_RESOURCES}
)
Expand Down
175 changes: 51 additions & 124 deletions gpt4all-chat/bravesearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@

using namespace Qt::Literals::StringLiterals;

QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout)
QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
{
const QString apiKey = parameters["apiKey"].toString();
const QString query = parameters["query"].toString();
const int count = parameters["count"].toInt();
QThread workerThread;
BraveAPIWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection);
connect(&workerThread, &QThread::started, [&worker, apiKey, query, count]() {
worker.request(apiKey, query, count);
});
workerThread.start();
emit request(apiKey, query, topK);
workerThread.wait(timeout);
workerThread.quit();
workerThread.wait();
Expand All @@ -34,174 +38,97 @@ QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey,
// Starts an asynchronous Brave web-search request.
//
// apiKey - Brave subscription token, sent in the "X-Subscription-Token" header.
// query  - search text, sent as the "q" query item.
// topK   - number of results to ask for (the "count" query item); also remembered
//          in m_topK.
//
// The call returns as soon as the GET has been issued. Completion and failure are
// reported through the reply's signals, which are routed to handleFinished() and
// handleErrorOccurred() respectively.
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
{
    m_topK = topK;

    // Documentation on the brave web search:
    // https://api.search.brave.com/app/documentation/web-search/get-started
    QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");

    // Documentation on the query options:
    // https://api.search.brave.com/app/documentation/web-search/query
    QUrlQuery urlQuery;
    urlQuery.addQueryItem("q", query);
    urlQuery.addQueryItem("count", QString::number(topK));
    urlQuery.addQueryItem("result_filter", "web");
    urlQuery.addQueryItem("extra_snippets", "true");
    jsonUrl.setQuery(urlQuery);

    QNetworkRequest request(jsonUrl);

    // The secret API key is sent in a header of this request, so the server's
    // certificate has to be verified: with verification disabled any
    // man-in-the-middle could capture the key and forge search results.
    QSslConfiguration conf = request.sslConfiguration();
    conf.setPeerVerifyMode(QSslSocket::VerifyPeer);
    request.setSslConfiguration(conf);

    request.setRawHeader("X-Subscription-Token", apiKey.toUtf8());
    // request.setRawHeader("Accept-Encoding", "gzip");
    request.setRawHeader("Accept", "application/json");

    // The manager is parented to this worker, so it is released together with it.
    m_networkManager = new QNetworkAccessManager(this);
    QNetworkReply *reply = m_networkManager->get(request);

    // Abort the in-flight request if the application starts shutting down.
    connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
    connect(reply, &QNetworkReply::finished, this, &BraveAPIWorker::handleFinished);
    connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
}

static QPair<QString, QList<SourceExcerpt>> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
{
// This parses the response from Brave and formats it as JSON that conforms to the de facto
// standard expected by SourceExcerpt::fromJson(...)
QJsonParseError err;
QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err);
if (err.error != QJsonParseError::NoError) {
qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString();
return QPair<QString, QList<SourceExcerpt>>();
qWarning() << "ERROR: Couldn't parse brave response: " << jsonResponse << err.errorString();
return QString();
}

QString query;
QJsonObject searchResponse = document.object();
QJsonObject cleanResponse;
QString query;
QJsonArray cleanArray;

QList<SourceExcerpt> infos;

if (searchResponse.contains("query")) {
QJsonObject queryObj = searchResponse["query"].toObject();
if (queryObj.contains("original")) {
if (queryObj.contains("original"))
query = queryObj["original"].toString();
}
}

if (searchResponse.contains("mixed")) {
QJsonObject mixedResults = searchResponse["mixed"].toObject();
QJsonArray mainResults = mixedResults["main"].toArray();
QJsonObject resultsObject = searchResponse["web"].toObject();
QJsonArray resultsArray = resultsObject["results"].toArray();

for (int i = 0; i < std::min(mainResults.size(), topK); ++i) {
for (int i = 0; i < std::min(mainResults.size(), resultsArray.size()); ++i) {
QJsonObject m = mainResults[i].toObject();
QString r_type = m["type"].toString();
int idx = m["index"].toInt();
QJsonObject resultsObject = searchResponse[r_type].toObject();
QJsonArray resultsArray = resultsObject["results"].toArray();

QJsonValue cleaned;
SourceExcerpt info;
if (r_type == "web") {
// For web data - add a single output from the search
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}

QStringList textKeys = {"description", "extra_snippets"};
QJsonObject textObj;
for (const auto& key : textKeys) {
if (resultObj.contains(key)) {
textObj.insert(key, resultObj[key]);
}
Q_ASSERT(r_type == "web");
const int idx = m["index"].toInt();

QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description"};
QJsonObject result;
for (const auto& key : selectedKeys)
if (resultObj.contains(key))
result.insert(key, resultObj[key]);

if (resultObj.contains("page_age"))
result.insert("date", resultObj["page_age"]);

QJsonArray excerpts;
if (resultObj.contains("extra_snippets")) {
QJsonArray snippets = resultObj["extra_snippets"].toArray();
for (int i = 0; i < snippets.size(); ++i) {
QString snippet = snippets[i].toString();
QJsonObject excerpt;
excerpt.insert("text", snippet);
excerpts.append(excerpt);
}

QJsonDocument textObjDoc(textObj);
info.date = resultObj["date"].toString();
info.text = textObjDoc.toJson(QJsonDocument::Indented);
info.url = resultObj["url"].toString();
QJsonObject meta_url = resultObj["meta_url"].toObject();
info.favicon = meta_url["favicon"].toString();
info.title = resultObj["title"].toString();

cleaned = cleanedObj;
} else if (r_type == "faq") {
// For faq data - take a list of all the questions & answers
QStringList selectedKeys = {"type", "question", "answer", "title", "url"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "infobox") {
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}
cleaned = cleanedObj;
} else if (r_type == "videos") {
QStringList selectedKeys = {"type", "url", "title", "description", "date"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "locations") {
QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "news") {
QStringList selectedKeys = {"type", "title", "url", "description"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else {
cleaned = QJsonValue();
}

infos.append(info);
cleanArray.append(cleaned);
result.insert("excerpts", excerpts);
cleanArray.append(QJsonValue(result));
}
}

cleanResponse.insert("query", query);
cleanResponse.insert("top_k", cleanArray);
cleanResponse.insert("results", cleanArray);
QJsonDocument cleanedDoc(cleanResponse);

// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
// qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented);

return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos);
return cleanedDoc.toJson(QJsonDocument::Compact);
}

void BraveAPIWorker::handleFinished()
Expand Down
15 changes: 6 additions & 9 deletions gpt4all-chat/bravesearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define BRAVESEARCH_H

#include "sourceexcerpt.h"
#include "tool.h"

#include <QObject>
#include <QString>
Expand All @@ -17,7 +18,7 @@ class BraveAPIWorker : public QObject {
, m_topK(1) {}
virtual ~BraveAPIWorker() {}

QPair<QString, QList<SourceExcerpt>> response() const { return m_response; }
QString response() const { return m_response; }

public Q_SLOTS:
void request(const QString &apiKey, const QString &query, int topK);
Expand All @@ -31,21 +32,17 @@ private Q_SLOTS:

private:
QNetworkAccessManager *m_networkManager;
QPair<QString, QList<SourceExcerpt>> m_response;
QString m_response;
int m_topK;
};

class BraveSearch : public QObject {
class BraveSearch : public Tool {
Q_OBJECT
public:
BraveSearch()
: QObject(nullptr) {}
BraveSearch() : Tool() {}
virtual ~BraveSearch() {}

QPair<QString, QList<SourceExcerpt>> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000);

Q_SIGNALS:
void request(const QString &apiKey, const QString &query, int topK);
QString run(const QJsonObject &parameters, qint64 timeout = 2000) override;
};

#endif // BRAVESEARCH_H
22 changes: 17 additions & 5 deletions gpt4all-chat/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,14 +880,26 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString

const QString query = args["query"].toString();

// FIXME: This has to handle errors of the tool call
emit toolCalled(tr("searching web..."));
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
Q_ASSERT(apiKey != "");
BraveSearch brave;
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
2000 /*msecs to timeout*/);
emit sourceExcerptsChanged(braveResponse.second);

QJsonObject parameters;
parameters.insert("apiKey", apiKey);
parameters.insert("query", query);
parameters.insert("count", 2);

// FIXME: This has to handle errors of the tool call
const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/);

QString parseError;
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError;
} else if (!sourceExcerpts.isEmpty()) {
emit sourceExcerptsChanged(sourceExcerpts);
}

// Erase the context of the tool call
m_ctx.n_past = std::max(0, m_ctx.n_past);
Expand All @@ -898,7 +910,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString

// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
// tool calls
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
true /*isToolCallResponse*/);

Expand Down
9 changes: 8 additions & 1 deletion gpt4all-chat/qml/ChatView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,14 @@ Rectangle {
sourceSize.width: 24
sourceSize.height: 24
mipmap: true
source: consolidatedSources[0].url === "" ? "qrc:/gpt4all/icons/db.svg" : "qrc:/gpt4all/icons/globe.svg"
source: {
if (typeof consolidatedSources === 'undefined'
|| typeof consolidatedSources[0] === 'undefined'
|| consolidatedSources[0].url === "")
return "qrc:/gpt4all/icons/db.svg";
else
return "qrc:/gpt4all/icons/globe.svg";
}
}

ColorOverlay {
Expand Down
Loading

0 comments on commit 8de4954

Please sign in to comment.