Skip to content

Commit

Permalink
Send MQTT packet back to client on more errors
Browse files Browse the repository at this point in the history
ProtocolError is now meant for MQTT. We can remove the case for
unauthenticated clients. This means MQTT protocol errors are
communicated back to the client better with MQTT reason codes / connack.

For non-MQTT errors a new exception was introduced, that just removes
the client.
  • Loading branch information
halfgaar committed Nov 18, 2023
1 parent 279bb96 commit 546fceb
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 16 deletions.
8 changes: 7 additions & 1 deletion exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ See LICENSE for license details.
* @brief The ProtocolError class is handled by the error handler in the worker threads and is used to make decisions about if and how
* to inform a client and log the message.
*
* If you don't specify a reason code it becomes UnspecifiedError, and no MQTT packet will be sent back to the client.
* It's mainly meant for errors that can be communicated with MQTT packets.
*/
class ProtocolError : public std::runtime_error
{
Expand All @@ -35,6 +35,12 @@ class ProtocolError : public std::runtime_error
}
};

class BadClientException : public std::runtime_error
{
public:
BadClientException(const std::string &msg) : std::runtime_error(msg) {}
};

class NotImplementedException : public std::runtime_error
{
public:
Expand Down
12 changes: 6 additions & 6 deletions iowrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWra
if (websocketPendingBytes.getSize() * 2 <= 8192)
websocketPendingBytes.doubleSize();
else
throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic.");
throw BadClientException("Trying to exceed websocket buffer. Probably not valid websocket traffic.");
}
else
{
Expand Down Expand Up @@ -729,10 +729,10 @@ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, Io
headerLength += extendedPayloadLengthLength;

//if (!masked)
// throw ProtocolError("Client must send masked websocket bytes.");
// throw BadClientException("Client must send masked websocket bytes.");

if (reserved != 0)
throw ProtocolError("Reserved bytes in header must be 0.");
throw BadClientException("Reserved bytes in header must be 0.");

if (headerLength > websocketPendingBytes.usedBytes())
return nbytesRead;
Expand Down Expand Up @@ -803,7 +803,7 @@ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, Io
// Because these internal websocket frames don't contain bytes for the client, we need to allow them to fit
// fully in websocketPendingBytes, otherwise you can get stuck.
if (incompleteWebsocketRead.frame_bytes_left > (settings->clientMaxWriteBufferSize / 2))
throw ProtocolError("The option 'client_max_write_buffer_size / 2' is lower than the ping frame we're are supposed to pong back. Abusing client?");
throw BadClientException("The option 'client_max_write_buffer_size / 2' is lower than the ping frame we're are supposed to pong back. Abusing client?");

if (incompleteWebsocketRead.frame_bytes_left > websocketPendingBytes.usedBytes())
break;
Expand Down Expand Up @@ -835,7 +835,7 @@ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, Io
// Because these internal websocket frames don't contain bytes for the client, we need to allow them to fit
// fully in websocketPendingBytes, otherwise you can get stuck.
if (incompleteWebsocketRead.frame_bytes_left > (settings->clientMaxWriteBufferSize / 2))
throw ProtocolError("Websocket close frame is too big.");
throw BadClientException("Websocket close frame is too big.");

if (incompleteWebsocketRead.frame_bytes_left > websocketPendingBytes.usedBytes())
break;
Expand Down Expand Up @@ -870,7 +870,7 @@ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, Io
{
// Specs: "MQTT Control Packets MUST be sent in WebSocket binary data frames. If any other type of data frame is
// received the recipient MUST close the Network Connection [MQTT-6.0.0-1]".
throw ProtocolError(formatString("Websocket frames must be 'binary' or 'ping'. Received: %d", incompleteWebsocketRead.opcode));
throw BadClientException(formatString("Websocket frames must be 'binary' or 'ping'. Received: %d", incompleteWebsocketRead.opcode));
}

if (!incompleteWebsocketRead.sillWorkingOnFrame())
Expand Down
3 changes: 3 additions & 0 deletions mqttpacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,9 @@ ConnectData MqttPacket::parseConnectData()
throw ProtocolError("Packet contains invalid MQTT marker.", ReasonCodes::MalformedPacket);
}

// Even though we're still parsing, setting this helps the exception handler to make decisions.
sender->setProtocolVersion(this->protocolVersion);

char flagByte = readByte();
bool reserved = !!(flagByte & 0b00000001);

Expand Down
12 changes: 6 additions & 6 deletions threadloop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,7 @@ void do_thread_work(ThreadData *threadData)

try
{
if (!client->hasConnectPacketSeen() || ex.reasonCode == ReasonCodes::UnspecifiedError)
{
logger->log(LOG_ERR) << "Unspecified or non-MQTT protocol error: " << ex.what() << ". Removing client.";
threadData->removeClient(client);
}
else if (!client->getAuthenticated())
if (!client->getAuthenticated())
{
ConnAck connAck(client->getProtocolVersion(), ex.reasonCode);

Expand Down Expand Up @@ -242,6 +237,11 @@ void do_thread_work(ThreadData *threadData)
threadData->removeClient(client);
}
}
catch(BadClientException &ex)
{
client->setDisconnectReason(ex.what());
threadData->removeClient(client);
}
catch(std::exception &ex)
{
client->setDisconnectReason(ex.what());
Expand Down
3 changes: 2 additions & 1 deletion types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ ConnAck::ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, boo
protocol_version(protVersion),
session_present(session_present)
{

if (this->protocol_version <= ProtocolVersion::Mqtt311)
{
this->supported_reason_code = true;
ConnAckReturnCodes mqtt3_return = ConnAckReturnCodes::Accepted;

switch (return_code)
Expand Down Expand Up @@ -96,6 +96,7 @@ ConnAck::ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, boo
}
else
{
this->supported_reason_code = true;
this->return_code = static_cast<uint8_t>(return_code);

// MQTT-3.2.2-6
Expand Down
2 changes: 1 addition & 1 deletion types.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class ConnAck
const ProtocolVersion protocol_version;
uint8_t return_code;
bool session_present = false;
bool supported_reason_code = true;
bool supported_reason_code = false;
std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;

size_t getLengthWithoutFixedHeader() const;
Expand Down
2 changes: 1 addition & 1 deletion utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ void exceptionOnNonMqtt(const std::vector<char> &data)

if (strContains(line, "HTTP"))
{
throw ProtocolError("This looks like HTTP traffic.", ReasonCodes::MalformedPacket);
throw BadClientException("This looks like HTTP traffic.");
}
}
}
Expand Down

0 comments on commit 546fceb

Please sign in to comment.