Skip to content

Commit

Permalink
Send connack on exception when we still can
Browse files Browse the repository at this point in the history
This allows us to be more clear especially in the case of MQTT3, where
certain errors can only be communicated in the connack.

As an example, we can throw exceptions about X509 client verification
and the client will get a connack about it.
  • Loading branch information
halfgaar committed Nov 12, 2023
1 parent 99dfd87 commit 311f796
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 22 deletions.
10 changes: 10 additions & 0 deletions client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ ProtocolVersion Client::getProtocolVersion() const
return protocolVersion;
}

void Client::setProtocolVersion(ProtocolVersion version)
{
this->protocolVersion = version;
}

void Client::connectToBridgeTarget(FMQSockaddr_in6 addr)
{
this->lastActivity = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -958,6 +963,11 @@ void Client::clearWill()
session->clearWill();
}

void Client::setClientId(const std::string &id)
{
this->clientid = id;
}

std::string &Client::getMutableUsername()
{
return this->username;
Expand Down
3 changes: 3 additions & 0 deletions client.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class Client
bool getSslReadWantsWrite() const;
bool getSslWriteWantsRead() const;
ProtocolVersion getProtocolVersion() const;
void setProtocolVersion(ProtocolVersion version);
void connectToBridgeTarget(FMQSockaddr_in6 addr);

void startOrContinueSslHandshake();
Expand All @@ -149,7 +150,9 @@ class Client
void setAuthenticated(bool value) { authenticated = value;}
bool getAuthenticated() { return authenticated; }
bool hasConnectPacketSeen() { return connectPacketSeen; }
void setHasConnectPacketSeen() { connectPacketSeen = true; }
std::string &getClientId() { return this->clientid; }
void setClientId(const std::string &id);
const std::string &getUsername() const { return this->username; }
std::string &getMutableUsername();
std::shared_ptr<WillPublish> &getWill() { return this->willPublish; }
Expand Down
28 changes: 16 additions & 12 deletions mqttpacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ void MqttPacket::handleConnect()
{
if (sender->hasConnectPacketSeen())
throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError);
sender->setHasConnectPacketSeen();

std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore();

Expand All @@ -803,18 +804,7 @@ void MqttPacket::handleConnect()
threadData->mqttConnectCounter.inc();

ConnectData connectData = parseConnectData();

std::string username = connectData.username ? connectData.username.value() : "";

if (sender->getX509ClientVerification() > X509ClientVerification::None)
{
std::optional<std::string> certificateUsername = sender->getUsernameFromPeerCertificate();

if (!certificateUsername || certificateUsername.value().empty())
throw ProtocolError("Client certificate did not provider username", ReasonCodes::BadUserNameOrPassword);

username = certificateUsername.value();
}
sender->setProtocolVersion(this->protocolVersion);

sender->setBridge(connectData.bridge);

Expand Down Expand Up @@ -885,6 +875,20 @@ void MqttPacket::handleConnect()
clientIdGenerated = true;
}

sender->setClientId(connectData.client_id);

std::string username = connectData.username ? connectData.username.value() : "";

if (sender->getX509ClientVerification() > X509ClientVerification::None)
{
std::optional<std::string> certificateUsername = sender->getUsernameFromPeerCertificate();

if (!certificateUsername || certificateUsername.value().empty())
throw ProtocolError("Client certificate did not provider username", ReasonCodes::BadUserNameOrPassword);

username = certificateUsername.value();
}

sender->setClientProperties(protocolVersion, connectData.client_id, username, true, connectData.keep_alive,
connectData.max_outgoing_packet_size, connectData.max_outgoing_topic_aliases);

Expand Down
43 changes: 33 additions & 10 deletions threadloop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,41 @@ void do_thread_work(ThreadData *threadData)
catch (ProtocolError &ex)
{
client->setDisconnectReason(ex.what());
if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen())
bool clientRemoved = true;

try
{
Disconnect d(client->getProtocolVersion(), ex.reasonCode);
MqttPacket p(d);
client->writeMqttPacket(p);
client->setReadyForDisconnect();

// When a client's TCP buffers are full (when the client is gone, for instance), EPOLLOUT will never be
// reported. In those cases, the client is not removed; not until the keep-alive mechanism anyway. Is
// that a problem?
if (!client->hasConnectPacketSeen())
{
logger->logf(LOG_ERR, "Protocol error before MQTT traffic: %s. Removing client.", ex.what());
threadData->removeClient(client);
}
else if (!client->getAuthenticated())
{
ConnAck connAck(client->getProtocolVersion(), ex.reasonCode);
MqttPacket p(connAck);
client->writeMqttPacket(p);
client->setReadyForDisconnect();
}
else if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5)
{
Disconnect d(client->getProtocolVersion(), ex.reasonCode);
MqttPacket p(d);
client->writeMqttPacket(p);
client->setReadyForDisconnect();

// When a client's TCP buffers are full (when the client is gone, for instance), EPOLLOUT will never be
// reported. In those cases, the client is not removed; not until the keep-alive mechanism anyway. Is
// that a problem?
}
}
else
catch (std::exception &inner_ex)
{
clientRemoved = false;
logger->log(LOG_ERR) << "Exception when notyfing client about ProtocolError: " << inner_ex.what();
}

if (!clientRemoved)
{
logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what());
threadData->removeClient(client);
Expand Down

0 comments on commit 311f796

Please sign in to comment.