From f389ed796d10bf02152dd9094d27220530a8bbc1 Mon Sep 17 00:00:00 2001 From: Guogang Li Date: Thu, 9 Jan 2025 11:32:50 -0800 Subject: [PATCH] Support blocking socket in Wi-Fi LAN medium PiperOrigin-RevId: 713740381 --- .../implementation/windows/wifi_lan.h | 27 +- .../implementation/windows/wifi_lan_medium.cc | 269 +++++++++----- .../windows/wifi_lan_server_socket.cc | 230 +++++++----- .../implementation/windows/wifi_lan_socket.cc | 345 +++++++++++------- 4 files changed, 559 insertions(+), 312 deletions(-) diff --git a/internal/platform/implementation/windows/wifi_lan.h b/internal/platform/implementation/windows/wifi_lan.h index faa16fbe44..1cec735b89 100644 --- a/internal/platform/implementation/windows/wifi_lan.h +++ b/internal/platform/implementation/windows/wifi_lan.h @@ -44,7 +44,10 @@ #include "internal/platform/exception.h" #include "internal/platform/implementation/cancelable.h" #include "internal/platform/implementation/wifi_lan.h" +#include "internal/platform/implementation/windows/nearby_client_socket.h" +#include "internal/platform/implementation/windows/nearby_server_socket.h" #include "internal/platform/implementation/windows/scheduled_executor.h" +#include "internal/platform/implementation/windows/wifi_lan_mdns.h" #include "internal/platform/input_stream.h" #include "internal/platform/mutex.h" #include "internal/platform/nsd_service_info.h" @@ -98,7 +101,9 @@ using winrt::Windows::Storage::Streams::IOutputStream; // remote WiFi LAN service, also will return a WifiLanSocket to caller. class WifiLanSocket : public api::WifiLanSocket { public: + WifiLanSocket(); explicit WifiLanSocket(StreamSocket socket); + explicit WifiLanSocket(std::unique_ptr socket); WifiLanSocket(WifiLanSocket&) = default; WifiLanSocket(WifiLanSocket&&) = default; ~WifiLanSocket() override; @@ -122,11 +127,14 @@ class WifiLanSocket : public api::WifiLanSocket { // Returns Exception::kIo on error, Exception::kSuccess otherwise. Exception Close() override; + bool Connect(const std::string& ip_address, int port); + private: // A simple wrapper to handle input stream of socket class SocketInputStream : public InputStream { public: - SocketInputStream(IInputStream input_stream); + explicit SocketInputStream(IInputStream input_stream); + explicit SocketInputStream(NearbyClientSocket* client_socket); ~SocketInputStream() = default; ExceptionOr Read(std::int64_t size) override; @@ -134,14 +142,17 @@ class WifiLanSocket : public api::WifiLanSocket { Exception Close() override; private: + bool enable_blocking_socket_ = false; IInputStream input_stream_{nullptr}; Buffer read_buffer_{nullptr}; + NearbyClientSocket* client_socket_{nullptr}; }; // A simple wrapper to handle output stream of socket class SocketOutputStream : public OutputStream { public: - SocketOutputStream(IOutputStream output_stream); + explicit SocketOutputStream(IOutputStream output_stream); + explicit SocketOutputStream(NearbyClientSocket* client_socket); ~SocketOutputStream() = default; Exception Write(const ByteArray& data) override; @@ -149,13 +160,18 @@ class WifiLanSocket : public api::WifiLanSocket { Exception Close() override; private: + bool enable_blocking_socket_ = false; IOutputStream output_stream_{nullptr}; + NearbyClientSocket* client_socket_{nullptr}; }; // Internal properties StreamSocket stream_soket_{nullptr}; SocketInputStream input_stream_{nullptr}; SocketOutputStream output_stream_{nullptr}; + + bool enable_blocking_socket_ = false; + std::unique_ptr client_socket_; }; // WifiLanServerSocket provides the support to server socket, this server socket @@ -222,6 +238,10 @@ class WifiLanServerSocket : public api::WifiLanServerSocket { // Cache socket not be picked by upper layer int port_ = 0; bool closed_ = false; + + // Flag to enable blocking socket. + bool enable_blocking_socket_ = false; + NearbyServerSocket server_socket_; }; // Container of operations that can be performed over the WifiLan medium. @@ -344,6 +364,9 @@ class WifiLanMedium : public api::WifiLanMedium { // Used to keep the service name is advertising. std::string service_name_; + // mDNS service + WifiLanMdns wifi_lan_mdns_; + // Keep the server sockets listener pointer absl::flat_hash_map diff --git a/internal/platform/implementation/windows/wifi_lan_medium.cc b/internal/platform/implementation/windows/wifi_lan_medium.cc index 9f672eceec..b6af02dd4f 100644 --- a/internal/platform/implementation/windows/wifi_lan_medium.cc +++ b/internal/platform/implementation/windows/wifi_lan_medium.cc @@ -37,10 +37,13 @@ #include "absl/time/time.h" // Nearby connections headers +#include "absl/container/flat_hash_map.h" +#include "internal/flags/nearby_flags.h" #include "internal/platform/cancellation_flag.h" #include "internal/platform/cancellation_flag_listener.h" #include "internal/platform/exception.h" #include "internal/platform/feature_flags.h" +#include "internal/platform/flags/nearby_platform_feature_flags.h" #include "internal/platform/implementation/windows/string_utils.h" #include "internal/platform/implementation/windows/utils.h" #include "internal/platform/logging.h" @@ -56,7 +59,7 @@ constexpr absl::string_view kDeviceIpv4 = "IPv4"; // mDNS information for advertising and discovery constexpr std::wstring_view kMdnsHostName = L"Windows.local"; -constexpr absl::string_view kMdnsInstanceNameFormat = "%s.%slocal"; +const char kMdnsInstanceNameFormat[] = "%s.%slocal"; constexpr absl::string_view kMdnsDeviceSelectorFormat = "System.Devices.AepService.ProtocolId:=\"{4526e8c1-8aac-4153-9b16-" "55e86ada0e54}\" " @@ -114,28 +117,8 @@ bool WifiLanMedium::StartAdvertising(const NsdServiceInfo& nsd_service_info) { service_name_ = nsd_service_info.GetServiceName(); - std::string instance_name = - absl::StrFormat(kMdnsInstanceNameFormat.data(), service_name_, - nsd_service_info.GetServiceType()); - - LOG(INFO) << "mDNS instance name is " << instance_name; - - dnssd_service_instance_ = DnssdServiceInstance{ - string_utils::StringToWideString(instance_name), - nullptr, // let windows use default computer's local name - (uint16)nsd_service_info.GetPort()}; - - // Add TextRecords from NsdServiceInfo - auto text_attributes = dnssd_service_instance_.TextAttributes(); - absl::flat_hash_map text_records = nsd_service_info.GetTxtRecords(); - auto it = text_records.begin(); - while (it != text_records.end()) { - text_attributes.Insert(string_utils::StringToWideString(it->first), - string_utils::StringToWideString(it->second)); - it++; - } // Add IPv4 address in text attributes. std::vector ipv4_addresses = GetIpv4Addresses(); @@ -143,34 +126,78 @@ bool WifiLanMedium::StartAdvertising(const NsdServiceInfo& nsd_service_info) { if (ipv4_addresses.size() > 1) { LOG(WARNING) << "The device has multiple IPv4 addresses."; } - text_attributes.Insert(winrt::to_hstring(std::string(kDeviceIpv4)), - winrt::to_hstring(ipv4_addresses[0])); + text_records.insert_or_assign(std::string(kDeviceIpv4), ipv4_addresses[0]); } - dnssd_regirstraion_result_ = dnssd_service_instance_ - .RegisterStreamSocketListenerAsync( - server_socket_ptr->GetSocketListener()) - .get(); + if (NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket)) { + bool result = wifi_lan_mdns_.StartMdnsService( + service_name_, nsd_service_info.GetServiceType(), + nsd_service_info.GetPort(), text_records); + + if (result) { + LOG(INFO) << "started to mDNS advertising."; + medium_status_ |= kMediumStatusAdvertising; + return true; + } - if (dnssd_regirstraion_result_.HasInstanceNameChanged()) { - LOG(WARNING) << "advertising instance name was changed due to have " - "same name instance was running."; - // stop the service and return false - StopAdvertising(nsd_service_info); + LOG(ERROR) << "failed to start mDNS advertising."; return false; - } - if (dnssd_regirstraion_result_.Status() == DnssdRegistrationStatus::Success) { - LOG(INFO) << "started to advertising."; - medium_status_ |= kMediumStatusAdvertising; - return true; - } + } else { + std::string instance_name = + absl::StrFormat(kMdnsInstanceNameFormat, service_name_, + nsd_service_info.GetServiceType()); + + LOG(INFO) << "mDNS instance name is " << instance_name; + + dnssd_service_instance_ = DnssdServiceInstance{ + string_utils::StringToWideString(instance_name), + nullptr, // let windows use default computer's local name + (uint16_t)nsd_service_info.GetPort()}; + + // Add TextRecords from NsdServiceInfo + auto text_attributes = dnssd_service_instance_.TextAttributes(); + + auto it = text_records.begin(); + while (it != text_records.end()) { + text_attributes.Insert(string_utils::StringToWideString(it->first), + string_utils::StringToWideString(it->second)); + it++; + } - // Clean up - LOG(ERROR) << "failed to start advertising due to registration failure."; - dnssd_service_instance_ = nullptr; - dnssd_regirstraion_result_ = nullptr; - return false; + if (server_socket_ptr == nullptr) { + LOG(ERROR) << "server socket is null."; + return false; + } + + dnssd_regirstraion_result_ = dnssd_service_instance_ + .RegisterStreamSocketListenerAsync( + server_socket_ptr->GetSocketListener()) + .get(); + + if (dnssd_regirstraion_result_.HasInstanceNameChanged()) { + LOG(WARNING) << "advertising instance name was changed due to have " + "same name instance was running."; + // stop the service and return false + StopAdvertising(nsd_service_info); + return false; + } + + if (dnssd_regirstraion_result_.Status() == + DnssdRegistrationStatus::Success) { + LOG(INFO) << "started to advertising."; + medium_status_ |= kMediumStatusAdvertising; + return true; + } + + // Clean up + LOG(ERROR) << "failed to start advertising due to registration failure."; + dnssd_service_instance_ = nullptr; + dnssd_regirstraion_result_ = nullptr; + return false; + } } bool WifiLanMedium::StopAdvertising(const NsdServiceInfo& nsd_service_info) { @@ -181,12 +208,27 @@ bool WifiLanMedium::StopAdvertising(const NsdServiceInfo& nsd_service_info) { return false; } - dnssd_service_instance_ = nullptr; + if (NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket)) { + bool result = wifi_lan_mdns_.StopMdnsService(); - LOG(INFO) << "succeeded to stop mDNS advertising for service type =" - << nsd_service_info.GetServiceType(); - medium_status_ &= (~kMediumStatusAdvertising); - return true; + if (result) { + LOG(INFO) << "succeeded to stop mDNS advertising."; + medium_status_ &= (~kMediumStatusAdvertising); + return true; + } + + LOG(ERROR) << "failed to stop mDNS advertising."; + return false; + } else { + dnssd_service_instance_ = nullptr; + + LOG(INFO) << "succeeded to stop mDNS advertising for service type =" + << nsd_service_info.GetServiceType(); + medium_status_ &= (~kMediumStatusAdvertising); + return true; + } } // Returns true once the WifiLan discovery has been initiated. @@ -289,70 +331,105 @@ std::unique_ptr WifiLanMedium::ConnectToService( return nullptr; } - std::unique_ptr connection_cancellation_listener = - nullptr; + if (NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket)) { + auto wifi_lan_socket = std::make_unique(); - HostName host_name{ - string_utils::StringToWideString(std::string(ipv4_address))}; - winrt::hstring service_name{winrt::to_hstring(port)}; + // setup cancel listener + std::unique_ptr connection_cancellation_listener = + nullptr; + if (cancellation_flag != nullptr) { + if (cancellation_flag->Cancelled()) { + LOG(WARNING) << "connect has been cancelled to service " << ipv4_address + << ":" << port; + return nullptr; + } - StreamSocket socket{}; + connection_cancellation_listener = + std::make_unique( + cancellation_flag, [socket = wifi_lan_socket.get()]() { + LOG(WARNING) << "connect is closed due to it is cancelled."; + socket->Close(); + }); + } - // setup cancel listener - if (cancellation_flag != nullptr) { - if (cancellation_flag->Cancelled()) { - LOG(INFO) << "connect has been cancelled to service " << ipv4_address - << ":" << port; + bool result = wifi_lan_socket->Connect(ipv4_address, port); + if (!result) { + LOG(ERROR) << "failed to connect to service " << ipv4_address << ":" + << port; return nullptr; } - connection_cancellation_listener = - std::make_unique( - cancellation_flag, [socket]() { - LOG(WARNING) << "connect is closed due to it is cancelled."; - socket.Close(); - }); - } + LOG(INFO) << "connected to remote service " << ipv4_address << ":" << port; - // connection to the service - try { - if (FeatureFlags::GetInstance().GetFlags().enable_connection_timeout) { - connection_timeout_ = scheduled_executor_.Schedule( - [socket]() { - LOG(WARNING) << "connect is closed due to timeout."; - socket.Close(); - }, - kConnectServiceTimeout); + return wifi_lan_socket; + + } else { + std::unique_ptr connection_cancellation_listener = + nullptr; + + HostName host_name{ + string_utils::StringToWideString(std::string(ipv4_address))}; + winrt::hstring service_name{winrt::to_hstring(port)}; + + StreamSocket socket{}; + + // setup cancel listener + if (cancellation_flag != nullptr) { + if (cancellation_flag->Cancelled()) { + LOG(INFO) << "connect has been cancelled to service " << ipv4_address + << ":" << port; + return nullptr; + } + + connection_cancellation_listener = + std::make_unique( + cancellation_flag, [socket]() { + LOG(WARNING) << "connect is closed due to it is cancelled."; + socket.Close(); + }); } - socket.ConnectAsync(host_name, service_name).get(); + // connection to the service + try { + if (FeatureFlags::GetInstance().GetFlags().enable_connection_timeout) { + connection_timeout_ = scheduled_executor_.Schedule( + [socket]() { + LOG(WARNING) << "connect is closed due to timeout."; + socket.Close(); + }, + kConnectServiceTimeout); + } + + socket.ConnectAsync(host_name, service_name).get(); + + if (connection_timeout_ != nullptr) { + connection_timeout_->Cancel(); + connection_timeout_ = nullptr; + } + + auto wifi_lan_socket = std::make_unique(socket); + + std::string local_address = + winrt::to_string(socket.Information().LocalAddress().DisplayName()); + std::string local_port = + winrt::to_string(socket.Information().LocalPort()); + LOG(INFO) << "connected to remote service " << ipv4_address << ":" << port + << " with local address " << local_address << ":" << local_port; + return wifi_lan_socket; + } catch (...) { + LOG(ERROR) << "failed to connect remote service " << ipv4_address << ":" + << port; + } if (connection_timeout_ != nullptr) { connection_timeout_->Cancel(); connection_timeout_ = nullptr; } - std::unique_ptr wifi_lan_socket = - std::make_unique(socket); - - std::string local_address = - winrt::to_string(socket.Information().LocalAddress().DisplayName()); - std::string local_port = winrt::to_string(socket.Information().LocalPort()); - - LOG(INFO) << "connected to remote service " << ipv4_address << ":" << port - << " with local address " << local_address << ":" << local_port; - return wifi_lan_socket; - } catch (...) { - LOG(ERROR) << "failed to connect remote service " << ipv4_address << ":" - << port; - } - - if (connection_timeout_ != nullptr) { - connection_timeout_->Cancel(); - connection_timeout_ = nullptr; + return nullptr; } - - return nullptr; } std::unique_ptr WifiLanMedium::ListenForService( diff --git a/internal/platform/implementation/windows/wifi_lan_server_socket.cc b/internal/platform/implementation/windows/wifi_lan_server_socket.cc index 8f4b7990a7..c95be5c4c9 100644 --- a/internal/platform/implementation/windows/wifi_lan_server_socket.cc +++ b/internal/platform/implementation/windows/wifi_lan_server_socket.cc @@ -17,13 +17,17 @@ #include #include #include +#include #include #include "absl/functional/any_invocable.h" #include "absl/synchronization/mutex.h" +#include "internal/flags/nearby_flags.h" #include "internal/platform/exception.h" +#include "internal/platform/flags/nearby_platform_feature_flags.h" #include "internal/platform/implementation/wifi_lan.h" #include "internal/platform/implementation/windows/generated/winrt/Windows.Networking.Sockets.h" +#include "internal/platform/implementation/windows/nearby_server_socket.h" #include "internal/platform/implementation/windows/utils.h" #include "internal/platform/implementation/windows/wifi_lan.h" #include "internal/platform/logging.h" @@ -36,32 +40,44 @@ using ::winrt::Windows::Networking::Sockets::SocketQualityOfService; } -WifiLanServerSocket::WifiLanServerSocket(int port) : port_(port) {} +WifiLanServerSocket::WifiLanServerSocket(int port) : port_(port) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); +} WifiLanServerSocket::~WifiLanServerSocket() { Close(); } // Returns the first IP address. std::string WifiLanServerSocket::GetIPAddress() const { - if (stream_socket_listener_ == nullptr) { - LOG(ERROR) << "Failed to get IP address due to no server socket."; - return ""; - } + if (enable_blocking_socket_) { + return ipaddr_dotdecimal_to_4bytes_string(server_socket_.GetIPAddress()); + } else { + if (stream_socket_listener_ == nullptr) { + LOG(ERROR) << "Failed to get IP address due to no server socket."; + return ""; + } - if (ip_addresses_.empty()) { - LOG(ERROR) << "Failed to get IP address due to no avaible IP addresses."; - return ""; - } + if (ip_addresses_.empty()) { + LOG(ERROR) << "Failed to get IP address due to no avaible IP addresses."; + return ""; + } - return ip_addresses_.front(); + return ip_addresses_.front(); + } } // Returns socket port. int WifiLanServerSocket::GetPort() const { - if (stream_socket_listener_ == nullptr) { - return 0; - } + if (enable_blocking_socket_) { + return server_socket_.GetPort(); + } else { + if (stream_socket_listener_ == nullptr) { + return 0; + } - return std::stoi(stream_socket_listener_.Information().LocalPort().c_str()); + return std::stoi(stream_socket_listener_.Information().LocalPort().c_str()); + } } // Blocks until either: @@ -71,19 +87,29 @@ int WifiLanServerSocket::GetPort() const { // Returns nullptr on error. // Once error is reported, it is permanent, and ServerSocket has to be closed. std::unique_ptr WifiLanServerSocket::Accept() { - absl::MutexLock lock(&mutex_); - LOG(INFO) << __func__ << ": Accept is called."; + if (enable_blocking_socket_) { + auto client_socket = server_socket_.Accept(); + if (client_socket == nullptr) { + return nullptr; + } - while (!closed_ && pending_sockets_.empty()) { - cond_.Wait(&mutex_); - } - if (closed_) return {}; + LOG(INFO) << __func__ << ": Accepted a remote connection."; - StreamSocket wifi_lan_socket = pending_sockets_.front(); - pending_sockets_.pop_front(); + return std::make_unique(std::move(client_socket)); + } else { + absl::MutexLock lock(&mutex_); + LOG(INFO) << __func__ << ": Accept is called."; + while (!closed_ && pending_sockets_.empty()) { + cond_.Wait(&mutex_); + } + if (closed_) return {}; - LOG(INFO) << __func__ << ": Accepted a remote connection."; - return std::make_unique(wifi_lan_socket); + StreamSocket wifi_lan_socket = pending_sockets_.front(); + pending_sockets_.pop_front(); + + LOG(INFO) << __func__ << ": Accepted a remote connection."; + return std::make_unique(wifi_lan_socket); + } } void WifiLanServerSocket::SetCloseNotifier( @@ -96,27 +122,41 @@ Exception WifiLanServerSocket::Close() { try { absl::MutexLock lock(&mutex_); LOG(INFO) << __func__ << ": Close is called."; + if (enable_blocking_socket_) { + if (closed_) { + return {Exception::kSuccess}; + } - if (closed_) { - return {Exception::kSuccess}; - } - if (stream_socket_listener_ != nullptr) { - stream_socket_listener_.ConnectionReceived(listener_event_token_); - stream_socket_listener_.Close(); - stream_socket_listener_ = nullptr; + LOG(INFO) << __func__ << ": closing blocking socket."; - for (const auto& pending_socket : pending_sockets_) { - pending_socket.Close(); + server_socket_.Close(); + closed_ = true; + + if (close_notifier_ != nullptr) { + close_notifier_(); + } + } else { + if (closed_) { + return {Exception::kSuccess}; } + if (stream_socket_listener_ != nullptr) { + stream_socket_listener_.ConnectionReceived(listener_event_token_); + stream_socket_listener_.Close(); + stream_socket_listener_ = nullptr; - pending_sockets_ = {}; - } + for (const auto& pending_socket : pending_sockets_) { + pending_socket.Close(); + } - closed_ = true; - cond_.SignalAll(); + pending_sockets_ = {}; + } - if (close_notifier_ != nullptr) { - close_notifier_(); + closed_ = true; + cond_.SignalAll(); + + if (close_notifier_ != nullptr) { + close_notifier_(); + } } LOG(INFO) << __func__ << ": Close completed succesfully."; @@ -141,68 +181,78 @@ Exception WifiLanServerSocket::Close() { } bool WifiLanServerSocket::listen() { - // Get current IP addresses of the device. - ip_addresses_ = Get4BytesIpv4Addresses(); + if (enable_blocking_socket_) { + ip_addresses_ = GetIpv4Addresses(); - if (ip_addresses_.empty()) { - LOG(WARNING) << "failed to start accepting connection without IP " - "addresses configured on computer."; - return false; - } + if (!server_socket_.Listen(ip_addresses_.front(), port_)) { + LOG(ERROR) << "Failed to listen socket at " << ip_addresses_.front() + << ":" << port_; + return false; + } - // Setup stream socket listener. - stream_socket_listener_ = StreamSocketListener(); + return true; + } else { + // Get current IP addresses of the device. + ip_addresses_ = Get4BytesIpv4Addresses(); + + if (ip_addresses_.empty()) { + LOG(WARNING) << "failed to start accepting connection without IP " + "addresses configured on computer."; + return false; + } + // Setup stream socket listener. + stream_socket_listener_ = StreamSocketListener(); - stream_socket_listener_.Control().QualityOfService( - SocketQualityOfService::LowLatency); + stream_socket_listener_.Control().QualityOfService( + SocketQualityOfService::LowLatency); - stream_socket_listener_.Control().KeepAlive(true); + stream_socket_listener_.Control().KeepAlive(true); - // Setup socket event of ConnectionReceived. - listener_event_token_ = stream_socket_listener_.ConnectionReceived( - {this, &WifiLanServerSocket::Listener_ConnectionReceived}); + // Setup socket event of ConnectionReceived. + listener_event_token_ = stream_socket_listener_.ConnectionReceived( + {this, &WifiLanServerSocket::Listener_ConnectionReceived}); - try { - stream_socket_listener_.BindServiceNameAsync(winrt::to_hstring(port_)) - .get(); - if (port_ == 0) { - port_ = - std::stoi(stream_socket_listener_.Information().LocalPort().c_str()); + try { + stream_socket_listener_.BindServiceNameAsync(winrt::to_hstring(port_)) + .get(); + if (port_ == 0) { + port_ = std::stoi( + stream_socket_listener_.Information().LocalPort().c_str()); + } + + return true; + } catch (std::exception exception) { + LOG(ERROR) << __func__ + << ": Cannot accept connection on preferred port. Exception: " + << exception.what(); + } catch (const winrt::hresult_error& error) { + LOG(ERROR) + << __func__ + << ": Cannot accept connection on preferred port. WinRT exception: " + << error.code() << ": " << winrt::to_string(error.message()); + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; } - return true; - } catch (std::exception exception) { - LOG(ERROR) << __func__ - << ": Cannot accept connection on preferred port. Exception: " - << exception.what(); - } catch (const winrt::hresult_error& error) { - LOG(ERROR) - << __func__ - << ": Cannot accept connection on preferred port. WinRT exception: " - << error.code() << ": " << winrt::to_string(error.message()); - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - } + try { + stream_socket_listener_.BindServiceNameAsync({}).get(); - try { - stream_socket_listener_.BindServiceNameAsync({}).get(); + // Need to save the port information. + port_ = + std::stoi(stream_socket_listener_.Information().LocalPort().c_str()); + return true; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Cannot bind to any port. Exception: " + << exception.what(); + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": Cannot bind to any port. WinRT exception: " + << error.code() << ": " << winrt::to_string(error.message()); + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + } - // Need to save the port information. - port_ = - std::stoi(stream_socket_listener_.Information().LocalPort().c_str()); - return true; - } catch (std::exception exception) { - LOG(ERROR) << __func__ - << ": Cannot bind to any port. Exception: " << exception.what(); - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ - << ": Cannot bind to any port. WinRT exception: " << error.code() - << ": " << winrt::to_string(error.message()); - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; + return false; } - - return false; } fire_and_forget WifiLanServerSocket::Listener_ConnectionReceived( diff --git a/internal/platform/implementation/windows/wifi_lan_socket.cc b/internal/platform/implementation/windows/wifi_lan_socket.cc index 286ca44ba6..5843831ed0 100644 --- a/internal/platform/implementation/windows/wifi_lan_socket.cc +++ b/internal/platform/implementation/windows/wifi_lan_socket.cc @@ -15,10 +15,15 @@ #include #include #include +#include +#include #include +#include "internal/flags/nearby_flags.h" #include "internal/platform/byte_array.h" #include "internal/platform/exception.h" +#include "internal/platform/flags/nearby_platform_feature_flags.h" +#include "internal/platform/implementation/windows/nearby_client_socket.h" #include "internal/platform/implementation/windows/wifi_lan.h" #include "internal/platform/input_stream.h" #include "internal/platform/logging.h" @@ -27,7 +32,19 @@ namespace nearby { namespace windows { +WifiLanSocket::WifiLanSocket() { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); + client_socket_ = std::make_unique(); + input_stream_ = SocketInputStream(client_socket_.get()); + output_stream_ = SocketOutputStream(client_socket_.get()); +} + WifiLanSocket::WifiLanSocket(StreamSocket socket) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); stream_soket_ = socket; LOG(INFO) << "Socket send buffer size: " << socket.Control().OutboundBufferSizeInBytes(); @@ -38,182 +55,262 @@ WifiLanSocket::WifiLanSocket(StreamSocket socket) { output_stream_ = SocketOutputStream(socket.OutputStream()); } -WifiLanSocket::~WifiLanSocket() { - try { - if (stream_soket_ != nullptr) { - Close(); - } - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - } +WifiLanSocket::WifiLanSocket(std::unique_ptr socket) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); + client_socket_ = std::move(socket); + input_stream_ = SocketInputStream(client_socket_.get()); + output_stream_ = SocketOutputStream(client_socket_.get()); } +WifiLanSocket::~WifiLanSocket() { Close(); } + InputStream& WifiLanSocket::GetInputStream() { return input_stream_; } OutputStream& WifiLanSocket::GetOutputStream() { return output_stream_; } Exception WifiLanSocket::Close() { - try { - if (stream_soket_ != nullptr) { - stream_soket_.Close(); + if (enable_blocking_socket_) { + if (client_socket_ != nullptr) { + return client_socket_->Close(); } + return {Exception::kSuccess}; - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + } else { + try { + if (stream_soket_ != nullptr) { + stream_soket_.Close(); + } + return {Exception::kSuccess}; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } +bool WifiLanSocket::Connect(const std::string& ip_address, int port) { + return client_socket_->Connect(ip_address, port); +} + // SocketInputStream WifiLanSocket::SocketInputStream::SocketInputStream(IInputStream input_stream) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); input_stream_ = input_stream; } +WifiLanSocket::SocketInputStream::SocketInputStream( + NearbyClientSocket* client_socket) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); + client_socket_ = client_socket; +} + ExceptionOr WifiLanSocket::SocketInputStream::Read( std::int64_t size) { - try { - if (read_buffer_ == nullptr || read_buffer_.Capacity() < size) { - read_buffer_ = Buffer(size); + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + LOG(ERROR) << "Failed to read data due to no client socket."; + return {Exception::kIo}; } - // Reset the buffer length to 0. - read_buffer_.Length(0); + return client_socket_->Read(size); + } else { + try { + if (read_buffer_ == nullptr || read_buffer_.Capacity() < size) { + read_buffer_ = Buffer(size); + } - auto ibuffer = - input_stream_.ReadAsync(read_buffer_, size, InputStreamOptions::None) - .get(); + // Reset the buffer length to 0. + read_buffer_.Length(0); - if (ibuffer.Length() != size) { - LOG(WARNING) << "Only read partial of data: [" << ibuffer.Length() << "/" - << size << "]."; - } + auto ibuffer = + input_stream_.ReadAsync(read_buffer_, size, InputStreamOptions::None) + .get(); - ByteArray data((char*)ibuffer.data(), ibuffer.Length()); - return ExceptionOr(std::move(data)); - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + if (ibuffer.Length() != size) { + LOG(WARNING) << "Only read partial of data: [" << ibuffer.Length() + << "/" << size << "]."; + } + + ByteArray data((char*)ibuffer.data(), ibuffer.Length()); + return ExceptionOr(std::move(data)); + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } ExceptionOr WifiLanSocket::SocketInputStream::Skip(size_t offset) { - try { - Buffer buffer = Buffer(offset); - - auto ibuffer = - input_stream_.ReadAsync(buffer, offset, InputStreamOptions::None).get(); - return ExceptionOr((size_t)ibuffer.Length()); - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + return {Exception::kIo}; + } + + return client_socket_->Skip(offset); + } else { + try { + Buffer buffer = Buffer(offset); + + auto ibuffer = + input_stream_.ReadAsync(buffer, offset, InputStreamOptions::None) + .get(); + return ExceptionOr((size_t)ibuffer.Length()); + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } Exception WifiLanSocket::SocketInputStream::Close() { - try { - input_stream_.Close(); - return {Exception::kSuccess}; - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + return {Exception::kIo}; + } + + return client_socket_->Close(); + } else { + try { + input_stream_.Close(); + return {Exception::kSuccess}; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } // SocketOutputStream WifiLanSocket::SocketOutputStream::SocketOutputStream( IOutputStream output_stream) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); output_stream_ = output_stream; } +WifiLanSocket::SocketOutputStream::SocketOutputStream( + NearbyClientSocket* client_socket) { + enable_blocking_socket_ = NearbyFlags::GetInstance().GetBoolFlag( + nearby::platform::config_package_nearby::nearby_platform_feature:: + kEnableBlockingSocket); + client_socket_ = client_socket; +} + Exception WifiLanSocket::SocketOutputStream::Write(const ByteArray& data) { - try { - Buffer buffer = Buffer(data.size()); - std::memcpy(buffer.data(), data.data(), data.size()); - buffer.Length(data.size()); - uint32_t wrote_bytes = output_stream_.WriteAsync(buffer).get(); - if (wrote_bytes != data.size()) { - LOG(WARNING) << "Only wrote partial of data:[" << wrote_bytes << "/" - << data.size() << "]."; + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + return {Exception::kIo}; } - return {Exception::kSuccess}; - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + return client_socket_->Write(data); + } else { + try { + Buffer buffer = Buffer(data.size()); + std::memcpy(buffer.data(), data.data(), data.size()); + buffer.Length(data.size()); + uint32_t wrote_bytes = output_stream_.WriteAsync(buffer).get(); + if (wrote_bytes != data.size()) { + LOG(WARNING) << "Only wrote partial of data:[" << wrote_bytes << "/" + << data.size() << "]."; + } + + return {Exception::kSuccess}; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } Exception WifiLanSocket::SocketOutputStream::Flush() { - try { - output_stream_.FlushAsync().get(); - return {Exception::kSuccess}; - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + return {Exception::kIo}; + } + + return client_socket_->Flush(); + } else { + try { + output_stream_.FlushAsync().get(); + return {Exception::kSuccess}; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } } Exception WifiLanSocket::SocketOutputStream::Close() { - try { - output_stream_.Close(); - return {Exception::kSuccess}; - } catch (std::exception exception) { - LOG(ERROR) << __func__ << ": Exception: " << exception.what(); - return {Exception::kIo}; - } catch (const winrt::hresult_error& error) { - LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " - << winrt::to_string(error.message()); - return {Exception::kIo}; - } catch (...) { - LOG(ERROR) << __func__ << ": Unknown exeption."; - return {Exception::kIo}; + if (enable_blocking_socket_) { + if (client_socket_ == nullptr) { + return {Exception::kIo}; + } + + return client_socket_->Close(); + } else { + try { + output_stream_.Close(); + return {Exception::kSuccess}; + } catch (std::exception exception) { + LOG(ERROR) << __func__ << ": Exception: " << exception.what(); + return {Exception::kIo}; + } catch (const winrt::hresult_error& error) { + LOG(ERROR) << __func__ << ": WinRT exception: " << error.code() << ": " + << winrt::to_string(error.message()); + return {Exception::kIo}; + } catch (...) { + LOG(ERROR) << __func__ << ": Unknown exception."; + return {Exception::kIo}; + } } }