Skip to content

Commit

Permalink
do not allow continuing a failed conversation
Browse files Browse the repository at this point in the history
Now that we use the ChatModel as the basis for all new messages, it is a
real problem that we cannot tell the difference between errors and model
output. This change allows us to tell the difference and react
accordingly.

Signed-off-by: Jared Van Bortel <[email protected]>
  • Loading branch information
cebtenzzre committed Nov 6, 2024
1 parent 1e0a6d7 commit edafa9e
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 16 deletions.
17 changes: 9 additions & 8 deletions gpt4all-chat/qml/ChatView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -1197,14 +1197,14 @@ Rectangle {
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
Accessible.description: qsTr("Send messages/prompts to the model")
Keys.onReturnPressed: (event)=> {
if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier)
event.accepted = false;
else {
editingFinished();
sendMessage()
}
}
// Return sends the message; Ctrl+Return / Shift+Return fall through so the
// TextArea inserts a newline instead. Sending is suppressed while the chat
// model is in an error state (chatModel.hasError) — a failed conversation
// cannot be continued.
Keys.onReturnPressed: event => {
if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) {
// Let the default handler run (inserts a line break).
event.accepted = false;
} else if (!chatModel.hasError) {
editingFinished();
sendMessage();
}
}
function sendMessage() {
if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress)
return
Expand Down Expand Up @@ -1321,6 +1321,7 @@ Rectangle {
imageWidth: theme.fontSizeLargest
imageHeight: theme.fontSizeLargest
visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0
enabled: !chatModel.hasError
source: "qrc:/gpt4all/icons/send_message.svg"
Accessible.name: qsTr("Send message")
Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model")
Expand Down
10 changes: 10 additions & 0 deletions gpt4all-chat/src/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void Chat::connectLLM()
// Should be in different threads
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
Expand Down Expand Up @@ -187,6 +188,15 @@ void Chat::handleResponseChanged(const QString &response)
emit responseChanged();
}

void Chat::handleResponseFailed(const QString &error)
{
m_response = error;
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, this->response());
m_chatModel->setError();
responseStopped(0);
}

void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
{
if (m_shouldDeleteLater)
Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/src/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ public Q_SLOTS:

private Q_SLOTS:
void handleResponseChanged(const QString &response);
void handleResponseFailed(const QString &error);
void handleModelLoadingPercentageChanged(float);
void promptProcessing();
void generatingQuestions();
Expand Down
2 changes: 1 addition & 1 deletion gpt4all-chat/src/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ void ChatLLM::prompt(const QStringList &enabledCollections)
promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo));
} catch (const std::exception &e) {
// FIXME(jared): this is neither translated nor serialized
emit responseChanged(u"Error: %1"_s.arg(QString::fromUtf8(e.what())));
emit responseFailed(u"Error: %1"_s.arg(QString::fromUtf8(e.what())));
emit responseStopped(0);
}
}
Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/src/chatllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ public Q_SLOTS:
void modelLoadingError(const QString &error);
void modelLoadingWarning(const QString &warning);
void responseChanged(const QString &response);
void responseFailed(const QString &error);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
Expand Down
78 changes: 71 additions & 7 deletions gpt4all-chat/src/chatmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct ChatItem
Q_PROPERTY(QList<ResultInfo> consolidatedSources MEMBER consolidatedSources)
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments)
Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments)
Q_PROPERTY(bool isError MEMBER isError)
// DataLake properties
Q_PROPERTY(QString newResponse MEMBER newResponse)
Q_PROPERTY(bool stopped MEMBER stopped)
Expand Down Expand Up @@ -143,6 +144,7 @@ struct ChatItem
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
bool isError = false; // used by assistant messages
};
Q_DECLARE_METATYPE(ChatItem)

Expand All @@ -163,6 +165,7 @@ class ChatModel : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(bool hasError READ hasError NOTIFY hasErrorChanged)

public:
explicit ChatModel(QObject *parent = nullptr)
Expand All @@ -180,7 +183,8 @@ class ChatModel : public QAbstractListModel
ThumbsDownStateRole,
SourcesRole,
ConsolidatedSourcesRole,
PromptAttachmentsRole
PromptAttachmentsRole,
IsErrorRole
};

int rowCount(const QModelIndex &parent = QModelIndex()) const override
Expand Down Expand Up @@ -220,6 +224,8 @@ class ChatModel : public QAbstractListModel
return QVariant::fromValue(item.consolidatedSources);
case PromptAttachmentsRole:
return QVariant::fromValue(item.promptAttachments);
case IsErrorRole:
return item.type() == ChatItem::Type::Response && item.isError;
}

return QVariant();
Expand All @@ -239,16 +245,22 @@ class ChatModel : public QAbstractListModel
roles[SourcesRole] = "sources";
roles[ConsolidatedSourcesRole] = "consolidatedSources";
roles[PromptAttachmentsRole] = "promptAttachments";
roles[IsErrorRole] = "isError";
return roles;
}

void appendPrompt(const QString &value, const QList<PromptAttachment> &attachments = {})
{
ChatItem item(ChatItem::prompt_tag, value, attachments);

m_mutex.lock();
const qsizetype count = m_chatItems.count();
m_mutex.unlock();
qsizetype count;
{
QMutexLocker locker(&m_mutex);
if (hasErrorUnlocked())
throw std::logic_error("cannot append to a failed chat");
count = m_chatItems.count();
}

beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
Expand All @@ -260,9 +272,14 @@ class ChatModel : public QAbstractListModel

void appendResponse()
{
m_mutex.lock();
const qsizetype count = m_chatItems.count();
m_mutex.unlock();
qsizetype count;
{
QMutexLocker locker(&m_mutex);
if (hasErrorUnlocked())
throw std::logic_error("cannot append to a failed chat");
count = m_chatItems.count();
}

ChatItem item(ChatItem::response_tag, count);
beginInsertRows(QModelIndex(), count, count);
{
Expand All @@ -286,15 +303,20 @@ class ChatModel : public QAbstractListModel
qsizetype nNewItems = history.size() + 1;
qsizetype endIndex = startIndex + nNewItems;
beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
bool hadError;
{
QMutexLocker locker(&m_mutex);
hadError = hasErrorUnlocked();
m_chatItems.reserve(m_chatItems.size() + nNewItems);
for (auto &item : history)
m_chatItems << item;
m_chatItems.emplace_back(ChatItem::response_tag, /*id*/ 0);
}
endInsertRows();
emit countChanged();
// Server can add messages when there is an error because each call is a new conversation
if (hadError)
emit hasErrorChanged(false);
}

Q_INVOKABLE void clear()
Expand Down Expand Up @@ -453,10 +475,31 @@ class ChatModel : public QAbstractListModel
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}

// Marks (or clears) the error flag on the conversation's final message.
// Preconditions: the chat must end with an assistant response; throws
// std::logic_error otherwise. No-op if the flag already has the requested
// value. Signals are emitted outside the lock to avoid re-entrancy under
// m_mutex.
Q_INVOKABLE void setError(bool value = true)
{
    qsizetype row;
    {
        QMutexLocker locker(&m_mutex);

        if (m_chatItems.isEmpty() || m_chatItems.back().type() != ChatItem::Type::Response)
            throw std::logic_error("can only set error on a chat that ends with a response");

        row = m_chatItems.count() - 1;
        auto &tail = m_chatItems.back();
        if (tail.isError == value)
            return; // already in the requested state
        tail.isError = value;
    }
    emit dataChanged(createIndex(row, 0), createIndex(row, 0), {IsErrorRole});
    emit hasErrorChanged(value);
}

qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }

ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; }

bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); }

bool serialize(QDataStream &stream, int version) const
{
QMutexLocker locker(&m_mutex);
Expand All @@ -470,6 +513,9 @@ class ChatModel : public QAbstractListModel
stream << c.stopped;
stream << c.thumbsUpState;
stream << c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response) {
stream << c.isError;
}
if (version >= 8) {
stream << c.sources.size();
for (const ResultInfo &info : c.sources) {
Expand Down Expand Up @@ -537,6 +583,7 @@ class ChatModel : public QAbstractListModel
{
int size;
stream >> size;
bool hasError = false;
for (int i = 0; i < size; ++i) {
ChatItem c;
stream >> c.id;
Expand All @@ -552,6 +599,9 @@ class ChatModel : public QAbstractListModel
stream >> c.stopped;
stream >> c.thumbsUpState;
stream >> c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response) {
stream >> c.isError;
}
if (version >= 8) {
qsizetype count;
stream >> count;
Expand Down Expand Up @@ -675,16 +725,30 @@ class ChatModel : public QAbstractListModel
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(c);
if (i == size - 1)
hasError = hasErrorUnlocked();
}
endInsertRows();
}
emit countChanged();
if (hasError)
emit hasErrorChanged(true);
return stream.status() == QDataStream::Ok;
}

Q_SIGNALS:
void countChanged();
void valueChanged(int index, const QString &value);
void hasErrorChanged(bool value);

private:
// True iff the last chat item exists, is an assistant response, and is
// flagged as an error. Caller must already hold m_mutex.
bool hasErrorUnlocked() const
{
    return !m_chatItems.isEmpty()
        && m_chatItems.back().type() == ChatItem::Type::Response
        && m_chatItems.back().isError;
}

private:
mutable QMutex m_mutex;
Expand Down

0 comments on commit edafa9e

Please sign in to comment.