From ad39b942305926adf847eb6386949188542b9055 Mon Sep 17 00:00:00 2001 From: Guogang Li Date: Tue, 4 Feb 2025 15:12:33 -0800 Subject: [PATCH] internal fix PiperOrigin-RevId: 723242250 --- connections/implementation/BUILD | 1 + .../implementation/ble_advertisement.cc | 83 +++++++++--- .../implementation/ble_advertisement_test.cc | 5 + .../implementation/bluetooth_device_name.cc | 117 ++++++++++------- .../implementation/bluetooth_device_name.h | 2 +- .../bluetooth_device_name_test.cc | 2 + .../mediums/ble_v2/ble_advertisement.cc | 84 ++++++++---- .../ble_v2/ble_advertisement_header.cc | 38 +++--- .../mediums/ble_v2/ble_packet.cc | 30 ++++- .../implementation/wifi_lan_service_info.cc | 75 ++++++----- internal/platform/base_input_stream.cc | 121 ++++++++++++++---- internal/platform/base_input_stream.h | 18 ++- internal/platform/byte_utils.cc | 2 +- 13 files changed, 395 insertions(+), 183 deletions(-) diff --git a/connections/implementation/BUILD b/connections/implementation/BUILD index a7edecdc7c..ee3eb054d4 100644 --- a/connections/implementation/BUILD +++ b/connections/implementation/BUILD @@ -289,6 +289,7 @@ cc_test( "//internal/platform/implementation/g3", # build_cleaner: keep "@com_github_protobuf_matchers//protobuf-matchers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) diff --git a/connections/implementation/ble_advertisement.cc b/connections/implementation/ble_advertisement.cc index 83f7fea691..972440d23b 100644 --- a/connections/implementation/ble_advertisement.cc +++ b/connections/implementation/ble_advertisement.cc @@ -14,11 +14,12 @@ #include "connections/implementation/ble_advertisement.h" -#include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "connections/implementation/base_pcp_handler.h" #include "connections/implementation/pcp.h" @@ -114,17 +115,22 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( ByteArray advertisement_bytes{ble_advertisement_bytes}; BaseInputStream base_input_stream{advertisement_bytes}; // The first 1 byte is supposed to be the version and pcp. - auto version_and_pcp_byte = static_cast(base_input_stream.ReadUint8()); + auto version_and_pcp_byte = base_input_stream.ReadUint8(); + if (!version_and_pcp_byte.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: version_and_pcp."); + } + // The upper 3 bits are supposed to be the version. Version version = - static_cast((version_and_pcp_byte & kVersionBitmask) >> 5); + static_cast((*version_and_pcp_byte & kVersionBitmask) >> 5); if (version != Version::kV1) { return absl::InvalidArgumentError(absl::StrCat( "Cannot deserialize BleAdvertisement: unsupported Version: ", version)); } // The lower 5 bits are supposed to be the Pcp. - Pcp pcp = static_cast(version_and_pcp_byte & kPcpBitmask); + Pcp pcp = static_cast(*version_and_pcp_byte & kPcpBitmask); switch (pcp) { case Pcp::kP2pCluster: // Fall through case Pcp::kP2pStar: // Fall through @@ -139,20 +145,43 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( // advertisement. ByteArray service_id_hash; if (!fast_advertisement) { - service_id_hash = base_input_stream.ReadBytes(kServiceIdHashLength); + auto service_id_hash_bytes = + base_input_stream.ReadBytes(kServiceIdHashLength); + if (!service_id_hash_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: service_id_hash."); + } + + service_id_hash = *service_id_hash_bytes; } // The next 4 bytes are supposed to be the endpoint_id. - std::string endpoint_id = - std::string{base_input_stream.ReadBytes(kEndpointIdLength)}; + auto endpoint_id_bytes = base_input_stream.ReadBytes(kEndpointIdLength); + if (!endpoint_id_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: endpoint_id."); + } + + std::string endpoint_id = std::string{*endpoint_id_bytes}; // The next 1 byte is supposed to be the length of the endpoint_info. auto expected_endpoint_info_length = base_input_stream.ReadUint8(); + if (!expected_endpoint_info_length.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: endpoint_info_length."); + } // The next x bytes are the endpoint info. (Max length is 131 bytes or 17 // bytes as fast_advertisement being true). - auto endpoint_info = - base_input_stream.ReadBytes(expected_endpoint_info_length); + auto endpoint_info_bytes = + base_input_stream.ReadBytes(*expected_endpoint_info_length); + if (!endpoint_info_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: endpoint_info."); + } + + ByteArray endpoint_info = *endpoint_info_bytes; + const int max_endpoint_info_length = fast_advertisement ? kMaxFastEndpointInfoLength : kMaxEndpointInfoLength; if (endpoint_info.Empty() || @@ -161,7 +190,7 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( return absl::InvalidArgumentError(absl::StrCat( "Cannot deserialize BleAdvertisement(fast advertisement=", fast_advertisement, "): expected endpointInfo to be ", - expected_endpoint_info_length, " bytes, got ", endpoint_info.size())); + *expected_endpoint_info_length, " bytes, got ", endpoint_info.size())); } // The next 6 bytes are the bluetooth mac address if not fast advertisement. @@ -169,8 +198,12 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( if (!fast_advertisement) { auto bluetooth_mac_address_bytes = base_input_stream.ReadBytes(BluetoothUtils::kBluetoothMacAddressLength); + if (!bluetooth_mac_address_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: bluetooth_mac_address."); + } bluetooth_mac_address = - BluetoothUtils::ToString(bluetooth_mac_address_bytes); + BluetoothUtils::ToString(*bluetooth_mac_address_bytes); } // The next 1 byte is supposed to be the length of the uwb_address. If the @@ -180,24 +213,32 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( BleAdvertisement ble_advertisement; if (base_input_stream.IsAvailable(1)) { auto expected_uwb_address_length = base_input_stream.ReadUint8(); + if (!expected_uwb_address_length.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: uwb_address_length."); + } // If the length of uwb_address is not zero, then retrieve it. if (expected_uwb_address_length != 0) { - uwb_address = base_input_stream.ReadBytes(expected_uwb_address_length); - if (uwb_address.Empty() || - uwb_address.size() != expected_uwb_address_length) { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot deserialize BleAdvertisement: expected uwbAddress size to " - "be ", - expected_uwb_address_length, " bytes, got ", uwb_address.size())); + auto uwb_address_bytes = + base_input_stream.ReadBytes(*expected_uwb_address_length); + + if (!uwb_address_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: uwb_address."); } + uwb_address = *uwb_address_bytes; } // The next 1 byte is extra field. if (!fast_advertisement) { if (base_input_stream.IsAvailable(kExtraFieldLength)) { - auto extra_field = static_cast(base_input_stream.ReadUint8()); + auto extra_field = base_input_stream.ReadUint8(); + if (!extra_field.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: extra_field."); + } ble_advertisement.web_rtc_state_ = - (extra_field & kWebRtcConnectableFlagBitmask) == 1 + (*extra_field & kWebRtcConnectableFlagBitmask) == 1 ? WebRtcState::kConnectable : WebRtcState::kUnconnectable; } diff --git a/connections/implementation/ble_advertisement_test.cc b/connections/implementation/ble_advertisement_test.cc index db6a021c4c..9b02a9f022 100644 --- a/connections/implementation/ble_advertisement_test.cc +++ b/connections/implementation/ble_advertisement_test.cc @@ -14,11 +14,16 @@ #include "connections/implementation/ble_advertisement.h" +#include +#include + #include "gmock/gmock.h" #include "protobuf-matchers/protocol-buffer-matchers.h" #include "gtest/gtest.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "connections/implementation/base_pcp_handler.h" +#include "connections/implementation/pcp.h" #include "internal/platform/byte_array.h" namespace nearby { diff --git a/connections/implementation/bluetooth_device_name.cc b/connections/implementation/bluetooth_device_name.cc index 4823e72e7d..ea2284279a 100644 --- a/connections/implementation/bluetooth_device_name.cc +++ b/connections/implementation/bluetooth_device_name.cc @@ -14,15 +14,18 @@ #include "connections/implementation/bluetooth_device_name.h" -#include - -#include +#include +#include #include #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "connections/implementation/base_pcp_handler.h" +#include "connections/implementation/pcp.h" #include "internal/platform/base64_utils.h" #include "internal/platform/base_input_stream.h" +#include "internal/platform/byte_array.h" #include "internal/platform/logging.h" namespace nearby { @@ -66,48 +69,67 @@ BluetoothDeviceName::BluetoothDeviceName( } if (bluetooth_device_name_bytes.size() < kMinBluetoothDeviceNameLength) { - NEARBY_LOGS(INFO) - << "Cannot deserialize BluetoothDeviceName: expecting min " - << kMinBluetoothDeviceNameLength << " raw bytes, got " - << bluetooth_device_name_bytes.size(); + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: expecting min " + << kMinBluetoothDeviceNameLength << " raw bytes, got " + << bluetooth_device_name_bytes.size(); return; } BaseInputStream base_input_stream{bluetooth_device_name_bytes}; // The first 1 byte is supposed to be the version and pcp. - auto version_and_pcp_byte = static_cast(base_input_stream.ReadUint8()); + auto version_and_pcp_byte = base_input_stream.ReadUint8(); + if (!version_and_pcp_byte.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: version_and_pcp."; + return; + } // The upper 3 bits are supposed to be the version. version_ = - static_cast((version_and_pcp_byte & kVersionBitmask) >> 5); + static_cast((*version_and_pcp_byte & kVersionBitmask) >> 5); if (version_ != Version::kV1) { - NEARBY_LOGS(INFO) - << "Cannot deserialize BluetoothDeviceName: unsupported version=" - << static_cast(version_); + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: unsupported version=" + << static_cast(version_); return; } // The lower 5 bits are supposed to be the Pcp. - pcp_ = static_cast(version_and_pcp_byte & kPcpBitmask); + pcp_ = static_cast(*version_and_pcp_byte & kPcpBitmask); switch (pcp_) { case Pcp::kP2pCluster: // Fall through case Pcp::kP2pStar: // Fall through case Pcp::kP2pPointToPoint: break; default: - NEARBY_LOGS(INFO) - << "Cannot deserialize BluetoothDeviceName: unsupported V1 PCP " - << static_cast(pcp_); + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: unsupported V1 PCP " + << static_cast(pcp_); return; } // The next 4 bytes are supposed to be the endpoint_id. - endpoint_id_ = std::string{base_input_stream.ReadBytes(kEndpointIdLength)}; + auto endpoint_id_bytes = base_input_stream.ReadBytes(kEndpointIdLength); + if (!endpoint_id_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: endpoint_id."; + return; + } + endpoint_id_ = std::string{*endpoint_id_bytes}; // The next 3 bytes are supposed to be the service_id_hash. - service_id_hash_ = base_input_stream.ReadBytes(kServiceIdHashLength); + auto service_id_hash_bytes = + base_input_stream.ReadBytes(kServiceIdHashLength); + if (!service_id_hash_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: service_id_hash."; + endpoint_id_.clear(); + return; + } + + service_id_hash_ = *service_id_hash_bytes; // The next 1 byte is field containing WebRtc state. - auto field_byte = static_cast(base_input_stream.ReadUint8()); - web_rtc_state_ = (field_byte & kWebRtcConnectableFlagBitmask) == 1 + auto field_byte = base_input_stream.ReadUint8(); + if (!field_byte.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: extra_field."; + endpoint_id_.clear(); + return; + } + web_rtc_state_ = (*field_byte & kWebRtcConnectableFlagBitmask) == 1 ? WebRtcState::kConnectable : WebRtcState::kUnconnectable; @@ -116,42 +138,48 @@ BluetoothDeviceName::BluetoothDeviceName( base_input_stream.ReadBytes(kReservedLength); // The next 1 byte is supposed to be the length of the endpoint_info. - std::uint32_t expected_endpoint_info_length = base_input_stream.ReadUint8(); + auto expected_endpoint_info_length = base_input_stream.ReadUint8(); + if (!expected_endpoint_info_length.has_value()) { + LOG(INFO) + << "Cannot deserialize BluetoothDeviceName: endpoint_info_length."; + endpoint_id_.clear(); + return; + } // The rest bytes are supposed to be the endpoint_info - endpoint_info_ = base_input_stream.ReadBytes(expected_endpoint_info_length); - if (endpoint_info_.Empty() || - endpoint_info_.size() != expected_endpoint_info_length) { - NEARBY_LOGS(INFO) << "Cannot deserialize BluetoothDeviceName: expected " - "endpoint info to be " - << expected_endpoint_info_length << " bytes, got " - << endpoint_info_.size(); - - // Clear endpoint_id for validity. + auto endpoint_info_bytes = + base_input_stream.ReadBytes(*expected_endpoint_info_length); + if (!endpoint_info_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: endpoint_info."; endpoint_id_.clear(); return; } + endpoint_info_ = *endpoint_info_bytes; // If the input stream has extra bytes, it's for UWB address. The first byte // is the address length. It can be 2-byte short address or 8-byte extended // address. if (base_input_stream.IsAvailable(1)) { // The next 1 byte is supposed to be the length of the uwb_address. - std::uint32_t expected_uwb_address_length = base_input_stream.ReadUint8(); + auto expected_uwb_address_length = base_input_stream.ReadUint8(); + if (!expected_uwb_address_length.has_value()) { + LOG(INFO) + << "Cannot deserialize BluetoothDeviceName: uwb_address_length."; + endpoint_id_.clear(); + return; + } + // If the length of usb_address is not zero, then retrieve it. if (expected_uwb_address_length != 0) { - uwb_address_ = base_input_stream.ReadBytes(expected_uwb_address_length); - if (uwb_address_.Empty() || - uwb_address_.size() != expected_uwb_address_length) { - NEARBY_LOGS(INFO) << "Cannot deserialize BluetoothDeviceName: expected " - "uwbAddress size to be " - << expected_uwb_address_length << " bytes, got " - << uwb_address_.size(); - - // Clear endpoint_id for validity. + auto uwb_address_bytes = + base_input_stream.ReadBytes(*expected_uwb_address_length); + if (!uwb_address_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BluetoothDeviceName: uwb_address."; endpoint_id_.clear(); return; } + + uwb_address_ = *uwb_address_bytes; } } } @@ -178,11 +206,10 @@ BluetoothDeviceName::operator std::string() const { ByteArray usable_endpoint_info(endpoint_info_); if (endpoint_info_.size() > kMaxEndpointInfoLength) { - NEARBY_LOGS(INFO) - << "While serializing Advertisement, truncating Endpoint Name " - << absl::BytesToHexString(endpoint_info_.data()) << " (" - << endpoint_info_.size() << " bytes) down to " << kMaxEndpointInfoLength - << " bytes"; + LOG(INFO) << "While serializing Advertisement, truncating Endpoint Name " + << absl::BytesToHexString(endpoint_info_.data()) << " (" + << endpoint_info_.size() << " bytes) down to " + << kMaxEndpointInfoLength << " bytes"; usable_endpoint_info.SetData(endpoint_info_.data(), kMaxEndpointInfoLength); } diff --git a/connections/implementation/bluetooth_device_name.h b/connections/implementation/bluetooth_device_name.h index c55a1031b7..8a1399cc5a 100644 --- a/connections/implementation/bluetooth_device_name.h +++ b/connections/implementation/bluetooth_device_name.h @@ -15,7 +15,7 @@ #ifndef CORE_INTERNAL_BLUETOOTH_DEVICE_NAME_H_ #define CORE_INTERNAL_BLUETOOTH_DEVICE_NAME_H_ -#include +#include #include "absl/strings/string_view.h" #include "connections/implementation/base_pcp_handler.h" diff --git a/connections/implementation/bluetooth_device_name_test.cc b/connections/implementation/bluetooth_device_name_test.cc index f339d170bb..5693507820 100644 --- a/connections/implementation/bluetooth_device_name_test.cc +++ b/connections/implementation/bluetooth_device_name_test.cc @@ -16,9 +16,11 @@ #include #include +#include #include "gtest/gtest.h" #include "internal/platform/base64_utils.h" +#include "internal/platform/byte_array.h" namespace nearby { namespace connections { diff --git a/connections/implementation/mediums/ble_v2/ble_advertisement.cc b/connections/implementation/mediums/ble_v2/ble_advertisement.cc index d4b81b76b5..609582580f 100644 --- a/connections/implementation/mediums/ble_v2/ble_advertisement.cc +++ b/connections/implementation/mediums/ble_v2/ble_advertisement.cc @@ -14,9 +14,9 @@ #include "connections/implementation/mediums/ble_v2/ble_advertisement.h" -#include - +#include #include +#include #include #include @@ -101,16 +101,21 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( BaseInputStream base_input_stream(advertisement_bytes); // The first 1 byte is supposed to be the version, socket version and the fast // advertisement flag. - auto version_byte = static_cast(base_input_stream.ReadUint8()); + auto version_byte = base_input_stream.ReadUint8(); + if (!version_byte.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: version."); + } - Version version = static_cast((version_byte & kVersionBitmask) >> 5); + Version version = + static_cast((*version_byte & kVersionBitmask) >> 5); if (!IsSupportedVersion(version)) { return absl::InvalidArgumentError(absl::StrCat( "Cannot deserialize BleAdvertisement: unsupported Version ", version)); } SocketVersion socket_version = - static_cast((version_byte & kSocketVersionBitmask) >> 2); + static_cast((*version_byte & kSocketVersionBitmask) >> 2); if (!IsSupportedSocketVersion(socket_version)) { return absl::InvalidArgumentError(absl::StrCat( "Cannot deserialize BleAdvertisement: unsupported SocketVersion ", @@ -118,34 +123,50 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( } bool fast_advertisement = - static_cast((version_byte & kFastAdvertisementFlagBitmask) >> 1); + static_cast((*version_byte & kFastAdvertisementFlagBitmask) >> 1); // The next 3 bytes are supposed to be the service_id_hash if not fast // advertisement. ByteArray service_id_hash; if (!fast_advertisement) { - service_id_hash = base_input_stream.ReadBytes(kServiceIdHashLength); + auto service_id_hash_bytes = + base_input_stream.ReadBytes(kServiceIdHashLength); + if (!service_id_hash_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: service_id_hash."); + } + service_id_hash = *service_id_hash_bytes; } // Data length. - int expected_data_size = - fast_advertisement - ? static_cast( - base_input_stream.ReadBytes(kFastDataSizeLength).data()[0]) - : static_cast(base_input_stream.ReadUint32()); - if (expected_data_size < 0) { - return absl::InvalidArgumentError( - absl::StrCat("Cannot deserialize BleAdvertisement: negative data size ", - expected_data_size)); + uint32_t expected_data_size; + if (fast_advertisement) { + auto fast_data_size_bytes = + base_input_stream.ReadBytes(kFastDataSizeLength); + if (!fast_data_size_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: fast_data_size."); + } + expected_data_size = static_cast(fast_data_size_bytes->data()[0]); + } else { + auto data_size_bytes = base_input_stream.ReadUint32(); + if (!data_size_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: data_size."); + } + expected_data_size = *data_size_bytes; } // Data. // Check that the stated data size is the same as what we received. - auto data = base_input_stream.ReadBytes(expected_data_size); - if (data.size() != expected_data_size) { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot deserialize BleAdvertisement: expected data to be ", - expected_data_size, " bytes, got ", data.size())); + ByteArray data; + if (expected_data_size > 0) { + auto data_bytes = base_input_stream.ReadBytes(expected_data_size); + if (!data_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: data."); + } + data = *data_bytes; } BleAdvertisement ble_advertisement; @@ -158,8 +179,12 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( // Device token. If the number of remaining bytes are valid for device token, // then read it. if (base_input_stream.IsAvailable(kDeviceTokenLength)) { - ble_advertisement.device_token_ = - base_input_stream.ReadBytes(kDeviceTokenLength); + auto device_token_bytes = base_input_stream.ReadBytes(kDeviceTokenLength); + if (!device_token_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: device_token."); + } + ble_advertisement.device_token_ = *device_token_bytes; } else { // No device token no more optional field. return ble_advertisement; @@ -172,8 +197,13 @@ absl::StatusOr BleAdvertisement::CreateBleAdvertisement( int extra_fields_byte_number = kExtraFieldsMaskLength + BleAdvertisementHeader::kPsmValueByteLength; if (base_input_stream.IsAvailable(extra_fields_byte_number)) { - BleExtraFields extra_fields{ - base_input_stream.ReadBytes(extra_fields_byte_number)}; + auto extra_fields_bytes = + base_input_stream.ReadBytes(extra_fields_byte_number); + if (!extra_fields_bytes.has_value()) { + return absl::InvalidArgumentError( + "Cannot deserialize BleAdvertisement: extra_field."); + } + BleExtraFields extra_fields{*extra_fields_bytes}; ble_advertisement.psm_ = extra_fields.GetPsm(); } return ble_advertisement; @@ -276,7 +306,7 @@ BleAdvertisement::BleExtraFields::BleExtraFields( ByteArray mutated_extra_fields_bytes = {ble_extra_fields_bytes}; BaseInputStream base_input_stream{mutated_extra_fields_bytes}; // The first 1 byte is field mask. - auto mask_byte = static_cast(base_input_stream.ReadUint8()); + auto mask_byte = base_input_stream.ReadUint8().value_or(0); if (!mask_byte) { return; } @@ -285,7 +315,7 @@ BleAdvertisement::BleExtraFields::BleExtraFields( if (HasField(mask_byte, kPsmBitmask) && base_input_stream.IsAvailable( BleAdvertisementHeader::kPsmValueByteLength)) { - psm_ = static_cast(base_input_stream.ReadUint16()); + psm_ = base_input_stream.ReadInt16().value_or(0); } } diff --git a/connections/implementation/mediums/ble_v2/ble_advertisement_header.cc b/connections/implementation/mediums/ble_v2/ble_advertisement_header.cc index c04b62eee5..44d72df475 100644 --- a/connections/implementation/mediums/ble_v2/ble_advertisement_header.cc +++ b/connections/implementation/mediums/ble_v2/ble_advertisement_header.cc @@ -14,8 +14,6 @@ #include "connections/implementation/mediums/ble_v2/ble_advertisement_header.h" -#include - #include #include @@ -75,38 +73,40 @@ BleAdvertisementHeader::BleAdvertisementHeader( return; } } else { - NEARBY_LOGS(ERROR) << "Cannot deserialize BLEAdvertisementHeader: failed " - "Base64 decoding"; + LOG(INFO) << "Cannot deserialize BLEAdvertisementHeader: failed " + "Base64 decoding"; return; } } if (advertisement_header_bytes.size() < kMinAdvertisementHeaderLength) { - NEARBY_LOGS(ERROR) - << "Cannot deserialize BleAdvertisementHeader: expecting min " - << kMinAdvertisementHeaderLength << "raw bytes, got " - << advertisement_header_bytes.size(); + LOG(INFO) << "Cannot deserialize BleAdvertisementHeader: expecting min " + << kMinAdvertisementHeaderLength << "raw bytes, got " + << advertisement_header_bytes.size(); return; } BaseInputStream base_input_stream(advertisement_header_bytes); // The first 1 byte is supposed to be the version and number of slots. - auto version_and_num_slots_byte = - static_cast(base_input_stream.ReadUint8()); + auto version_and_num_slots_byte = base_input_stream.ReadUint8(); + if (!version_and_num_slots_byte.has_value()) { + LOG(INFO) << "Cannot deserialize BleAdvertisementHeader: version_and_num."; + return; + } // The upper 3 bits are supposed to be the version. - version_ = - static_cast((version_and_num_slots_byte & kVersionBitmask) >> 5); + version_ = static_cast( + (*version_and_num_slots_byte & kVersionBitmask) >> 5); if (version_ != Version::kV2) { - NEARBY_LOGS(ERROR) + LOG(INFO) << "Cannot deserialize BleAdvertisementHeader: unsupported Version " << static_cast(version_); return; } // The next 1 bit is supposed to be the extended advertisement flag. support_extended_advertisement_ = - ((version_and_num_slots_byte & kExtendedAdvertismentBitMask) >> 4) == 1; + ((*version_and_num_slots_byte & kExtendedAdvertismentBitMask) >> 4) == 1; // The lower 4 bits are supposed to be the number of slots. - num_slots_ = static_cast(version_and_num_slots_byte & kNumSlotsBitmask); + num_slots_ = static_cast(*version_and_num_slots_byte & kNumSlotsBitmask); if (num_slots_ < 0) { version_ = Version::kUndefined; return; @@ -114,15 +114,17 @@ BleAdvertisementHeader::BleAdvertisementHeader( // The next 10 bytes are supposed to be the service_id_bloom_filter. service_id_bloom_filter_ = - base_input_stream.ReadBytes(kServiceIdBloomFilterByteLength); + base_input_stream.ReadBytes(kServiceIdBloomFilterByteLength) + .value_or(ByteArray()); // The next 4 bytes are supposed to be the advertisement_hash. advertisement_hash_ = - base_input_stream.ReadBytes(kAdvertisementHashByteLength); + base_input_stream.ReadBytes(kAdvertisementHashByteLength) + .value_or(ByteArray()); // The next 2 bytes are PSM value. if (base_input_stream.IsAvailable(kPsmValueByteLength)) { - psm_ = static_cast(base_input_stream.ReadUint16()); + psm_ = base_input_stream.ReadInt16().value_or(0); } } diff --git a/connections/implementation/mediums/ble_v2/ble_packet.cc b/connections/implementation/mediums/ble_v2/ble_packet.cc index b9501c1f7f..5bb6e884cb 100644 --- a/connections/implementation/mediums/ble_v2/ble_packet.cc +++ b/connections/implementation/mediums/ble_v2/ble_packet.cc @@ -14,12 +14,15 @@ #include "connections/implementation/mediums/ble_v2/ble_packet.h" +#include #include #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "internal/platform/base_input_stream.h" +#include "internal/platform/byte_array.h" #include "internal/platform/logging.h" #include "proto/mediums/ble_frames.pb.h" @@ -122,21 +125,28 @@ absl::StatusOr BlePacket::CreateDataPacket( BlePacket::BlePacket(const ByteArray& ble_packet_bytes) { if (ble_packet_bytes.Empty()) { - NEARBY_LOGS(ERROR) << "Cannot deserialize BlePacket: null bytes passed in"; + LOG(INFO) << "Cannot deserialize BlePacket: null bytes passed in"; return; } if (ble_packet_bytes.size() < kServiceIdHashLength) { - NEARBY_LOGS(INFO) << "Cannot deserialize BlePacket: expecting min " - << kServiceIdHashLength << " raw bytes, got " - << ble_packet_bytes.size(); + LOG(INFO) << "Cannot deserialize BlePacket: expecting min " + << kServiceIdHashLength << " raw bytes, got " + << ble_packet_bytes.size(); return; } ByteArray packet_bytes(ble_packet_bytes); BaseInputStream base_input_stream{packet_bytes}; // The first 3 bytes are supposed to be the service_id_hash. - service_id_hash_ = base_input_stream.ReadBytes(kServiceIdHashLength); + auto service_id_hash_bytes = + base_input_stream.ReadBytes(kServiceIdHashLength); + if (!service_id_hash_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BlePacket: service_id_hash."; + return; + } + + service_id_hash_ = *service_id_hash_bytes; if (service_id_hash_ == ByteArray(kControlPacketServiceIdHash, kServiceIdHashLength)) { packet_type_ = BlePacketType::kControl; @@ -145,8 +155,14 @@ BlePacket::BlePacket(const ByteArray& ble_packet_bytes) { } // The rest bytes are supposed to be the data. - data_ = base_input_stream.ReadBytes(ble_packet_bytes.size() - - kServiceIdHashLength); + auto data_bytes = base_input_stream.ReadBytes(ble_packet_bytes.size() - + kServiceIdHashLength); + if (!data_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize BlePacket: data."; + return; + } + + data_ = *data_bytes; } BlePacket::operator ByteArray() const { diff --git a/connections/implementation/wifi_lan_service_info.cc b/connections/implementation/wifi_lan_service_info.cc index 50986aba6f..7857cb0aca 100644 --- a/connections/implementation/wifi_lan_service_info.cc +++ b/connections/implementation/wifi_lan_service_info.cc @@ -14,15 +14,19 @@ #include "connections/implementation/wifi_lan_service_info.h" -#include - -#include +#include +#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "connections/implementation/base_pcp_handler.h" +#include "connections/implementation/pcp.h" #include "internal/platform/base64_utils.h" #include "internal/platform/base_input_stream.h" +#include "internal/platform/byte_array.h" #include "internal/platform/logging.h" +#include "internal/platform/nsd_service_info.h" namespace nearby { namespace connections { @@ -68,7 +72,7 @@ WifiLanServiceInfo::WifiLanServiceInfo(const NsdServiceInfo& nsd_service_info) { if (!txt_endpoint_info_name.empty()) { endpoint_info_ = Base64Utils::Decode(txt_endpoint_info_name); if (endpoint_info_.size() > kMaxEndpointInfoLength) { - NEARBY_LOGS(INFO) + LOG(INFO) << "Cannot deserialize EndpointInfo: expecting endpoint info max " << kMaxEndpointInfoLength << " raw bytes, got " << endpoint_info_.size(); @@ -79,75 +83,86 @@ WifiLanServiceInfo::WifiLanServiceInfo(const NsdServiceInfo& nsd_service_info) { std::string service_info_name = nsd_service_info.GetServiceName(); ByteArray service_info_bytes = Base64Utils::Decode(service_info_name); if (service_info_bytes.Empty()) { - NEARBY_LOGS(INFO) + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: failed Base64 decoding of " << service_info_name; return; } if (service_info_bytes.size() < kMinLanServiceNameLength) { - NEARBY_LOGS(INFO) << "Cannot deserialize WifiLanServiceInfo: expecting min " - << kMinLanServiceNameLength - << " raw bytes, got " - << service_info_bytes.size(); + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: expecting min " + << kMinLanServiceNameLength << " raw bytes, got " + << service_info_bytes.size(); return; } BaseInputStream base_input_stream{service_info_bytes}; // The first 1 byte is supposed to be the version and pcp. - auto version_and_pcp_byte = static_cast(base_input_stream.ReadUint8()); + auto version_and_pcp_byte = base_input_stream.ReadUint8(); + if (!version_and_pcp_byte.has_value()) { + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: version_and_pcp."; + return; + } // The upper 3 bits are supposed to be the version. version_ = - static_cast((version_and_pcp_byte & kVersionBitmask) >> 5); + static_cast((*version_and_pcp_byte & kVersionBitmask) >> 5); if (version_ != Version::kV1) { - NEARBY_LOGS(INFO) - << "Cannot deserialize WifiLanServiceInfo: unsupported Version " - << static_cast(version_); + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: unsupported Version " + << static_cast(version_); return; } // The lower 5 bits are supposed to be the Pcp. - pcp_ = static_cast(version_and_pcp_byte & kPcpBitmask); + pcp_ = static_cast(*version_and_pcp_byte & kPcpBitmask); switch (pcp_) { case Pcp::kP2pCluster: // Fall through case Pcp::kP2pStar: // Fall through case Pcp::kP2pPointToPoint: break; default: - NEARBY_LOGS(INFO) - << "Cannot deserialize WifiLanServiceInfo: unsupported V1 PCP " - << static_cast(pcp_); + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: unsupported V1 PCP " + << static_cast(pcp_); } // The next 4 bytes are supposed to be the endpoint_id. - endpoint_id_ = std::string{base_input_stream.ReadBytes(kEndpointIdLength)}; + auto endpoint_id_bytes = base_input_stream.ReadBytes(kEndpointIdLength); + if (!endpoint_id_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: endpoint_id."; + return; + } + endpoint_id_ = std::string{*endpoint_id_bytes}; // The next 3 bytes are supposed to be the service_id_hash. - service_id_hash_ = base_input_stream.ReadBytes(kServiceIdHashLength); + auto service_id_hash_bytes = + base_input_stream.ReadBytes(kServiceIdHashLength); + if (!service_id_hash_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: service_id_hash."; + endpoint_id_.clear(); + return; + } + service_id_hash_ = *service_id_hash_bytes; // The next 1 byte is supposed to be the length of the uwb_address. If // available, continues to deserialize UWB address and extra field of WebRtc // state. if (base_input_stream.IsAvailable(1)) { - std::uint32_t expected_uwb_address_length = base_input_stream.ReadUint8(); + auto expected_uwb_address_length = + base_input_stream.ReadUint8().value_or(0); // If the length of uwb_address is not zero, then retrieve it. if (expected_uwb_address_length != 0) { - uwb_address_ = base_input_stream.ReadBytes(expected_uwb_address_length); - if (uwb_address_.Empty() || - uwb_address_.size() != expected_uwb_address_length) { - NEARBY_LOGS(INFO) << "Cannot deserialize WifiLanServiceInfo: expected " - "uwbAddress size to be " - << expected_uwb_address_length << " bytes, got " - << uwb_address_.size(); - // Clear enpoint_id for validity. + auto uwb_address_bytes = + base_input_stream.ReadBytes(expected_uwb_address_length); + if (!uwb_address_bytes.has_value()) { + LOG(INFO) << "Cannot deserialize WifiLanServiceInfo: uwb_address."; endpoint_id_.clear(); return; } + uwb_address_ = *uwb_address_bytes; } // The next 1 byte is extra field. web_rtc_state_ = WebRtcState::kUndefined; if (base_input_stream.IsAvailable(kExtraFieldLength)) { - auto extra_field = static_cast(base_input_stream.ReadUint8()); + auto extra_field = base_input_stream.ReadUint8().value_or(0); web_rtc_state_ = (extra_field & kWebRtcConnectableFlagBitmask) == 1 ? WebRtcState::kConnectable : WebRtcState::kUnconnectable; diff --git a/internal/platform/base_input_stream.cc b/internal/platform/base_input_stream.cc index d0cbc6b588..3740e89bfd 100644 --- a/internal/platform/base_input_stream.cc +++ b/internal/platform/base_input_stream.cc @@ -14,6 +14,13 @@ #include "internal/platform/base_input_stream.h" +#include +#include +#include + +#include "internal/platform/byte_array.h" +#include "internal/platform/exception.h" + namespace nearby { ExceptionOr BaseInputStream::Read(std::int64_t size) { @@ -31,51 +38,92 @@ ExceptionOr BaseInputStream::Read(std::int64_t size) { } } -std::uint8_t BaseInputStream::ReadUint8() { +std::optional BaseInputStream::ReadUint8() { constexpr int byte_size = sizeof(std::uint8_t); - ByteArray read_bytes = ReadBytes(byte_size); - if (read_bytes.Empty() || read_bytes.size() != byte_size) { - return -1; + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; + } + + const char *data = read_bytes->data(); + return static_cast(data[0]); +} + +std::optional BaseInputStream::ReadInt8() { + constexpr int byte_size = sizeof(std::int8_t); + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; } - return read_bytes.data()[0]; + const char *data = read_bytes->data(); + return static_cast(data[0]); } -std::uint16_t BaseInputStream::ReadUint16() { +std::optional BaseInputStream::ReadUint16() { constexpr int byte_size = sizeof(std::uint16_t); - ByteArray read_bytes = ReadBytes(byte_size); - if (read_bytes.Empty() || read_bytes.size() != byte_size) { - return -1; + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; + } + + // Convert from network order. + const unsigned char *data = + reinterpret_cast(read_bytes->data()); + return static_cast(data[0] << 8 | data[1]); +} + +std::optional BaseInputStream::ReadInt16() { + constexpr int byte_size = sizeof(std::int16_t); + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; } // Convert from network order. - const char *data = read_bytes.data(); - return static_cast(data[0]) << 8 | static_cast(data[1]); + const unsigned char *data = + reinterpret_cast(read_bytes->data()); + return static_cast(data[0] << 8 | data[1]); } -std::uint32_t BaseInputStream::ReadUint32() { +std::optional BaseInputStream::ReadUint32() { constexpr int byte_size = sizeof(std::uint32_t); - ByteArray read_bytes = ReadBytes(byte_size); - if (read_bytes.Empty() || read_bytes.size() != byte_size) { - return -1; + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; } // Convert from network order. - const char *data = read_bytes.data(); - return static_cast(data[0]) << 24 | - static_cast(data[1]) << 16 | - static_cast(data[2]) << 8 | static_cast(data[3]); + const unsigned char *data = + reinterpret_cast(read_bytes->data()); + return static_cast(data[0] << 24 | data[1] << 16 | data[2] << 8 | + data[3]); } -std::uint64_t BaseInputStream::ReadUint64() { +std::optional BaseInputStream::ReadInt32() { + constexpr int byte_size = sizeof(std::uint32_t); + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; + } + + // Convert from network order. + const unsigned char *data = + reinterpret_cast(read_bytes->data()); + return static_cast(data[0] << 24 | data[1] << 16 | data[2] << 8 | + data[3]); +} + +std::optional BaseInputStream::ReadUint64() { constexpr int byte_size = sizeof(std::uint64_t); - ByteArray read_bytes = ReadBytes(byte_size); - if (read_bytes.Empty() || read_bytes.size() != byte_size) { - return -1; + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; } // Convert from network order. - const char *data = read_bytes.data(); + const unsigned char *data = + reinterpret_cast(read_bytes->data()); return static_cast(data[0]) << 56 | static_cast(data[1]) << 48 | static_cast(data[2]) << 40 | @@ -85,13 +133,32 @@ std::uint64_t BaseInputStream::ReadUint64() { static_cast(data[6]) << 8 | static_cast(data[7]); } -ByteArray BaseInputStream::ReadBytes(int size) { +std::optional BaseInputStream::ReadInt64() { + constexpr int byte_size = sizeof(std::int64_t); + std::optional read_bytes = ReadBytes(byte_size); + if (!read_bytes.has_value()) { + return std::nullopt; + } + + // Convert from network order. + const unsigned char *data = + reinterpret_cast(read_bytes->data()); + return static_cast(data[0]) << 56 | + static_cast(data[1]) << 48 | + static_cast(data[2]) << 40 | + static_cast(data[3]) << 32 | + static_cast(data[4]) << 24 | + static_cast(data[5]) << 16 | + static_cast(data[6]) << 8 | static_cast(data[7]); +} + +std::optional BaseInputStream::ReadBytes(int size) { ExceptionOr read_bytes_result = Read(size); if (!read_bytes_result.ok()) { - return ByteArray{}; + return std::nullopt; } - return read_bytes_result.GetResult(); + return read_bytes_result.result(); } } // namespace nearby diff --git a/internal/platform/base_input_stream.h b/internal/platform/base_input_stream.h index 02e05bc343..d547498699 100644 --- a/internal/platform/base_input_stream.h +++ b/internal/platform/base_input_stream.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "internal/platform/byte_array.h" #include "internal/platform/exception.h" @@ -46,18 +47,23 @@ class BaseInputStream : public InputStream { return {Exception::kSuccess}; } - std::uint8_t ReadUint8(); - std::uint16_t ReadUint16(); - std::uint32_t ReadUint32(); - std::uint64_t ReadUint64(); - ByteArray ReadBytes(int size); + std::optional ReadUint8(); + std::optional ReadInt8(); + std::optional ReadUint16(); + std::optional ReadInt16(); + std::optional ReadUint32(); + std::optional ReadInt32(); + std::optional ReadUint64(); + std::optional ReadInt64(); + std::optional ReadBytes(int size); + bool IsAvailable(int size) const { return buffer_.size() - position_ >= size; } private: ByteArray &buffer_; - int position_{0}; + size_t position_{0}; }; } // namespace nearby diff --git a/internal/platform/byte_utils.cc b/internal/platform/byte_utils.cc index 5b6f603467..c1bb461355 100644 --- a/internal/platform/byte_utils.cc +++ b/internal/platform/byte_utils.cc @@ -30,7 +30,7 @@ std::string ByteUtils::ToFourDigitString(ByteArray& bytes) { BaseInputStream base_input_stream{bytes}; while (base_input_stream.IsAvailable(1)) { - auto byte = static_cast(base_input_stream.ReadUint8()); + auto byte = base_input_stream.ReadInt8().value_or(0); hashCode = (hashCode + byte * multiplier) % kHashBasePrime; multiplier = multiplier * kHashBaseMultiplier % kHashBasePrime; }