diff --git a/connections/implementation/BUILD b/connections/implementation/BUILD index 58125a8fa7..a7edecdc7c 100644 --- a/connections/implementation/BUILD +++ b/connections/implementation/BUILD @@ -179,6 +179,7 @@ cc_library( deps = [ ":internal", "//connections:core_types", + "//connections/implementation/analytics", "//connections/implementation/flags:connections_flags", "//connections/v3:v3_types", "//internal/flags:nearby_flags", @@ -188,6 +189,7 @@ cc_library( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_for_library_testonly", ], ) diff --git a/connections/implementation/base_endpoint_channel.cc b/connections/implementation/base_endpoint_channel.cc index d810d40ebe..4a58e90698 100644 --- a/connections/implementation/base_endpoint_channel.cc +++ b/connections/implementation/base_endpoint_channel.cc @@ -319,6 +319,11 @@ void BaseEndpointChannel::CloseIo() { } } +uint32_t BaseEndpointChannel::GetNextKeepAliveSeqNo() const { + MutexLock lock(&keep_alive_mutex_); + return next_keep_alive_seq_no_++; +} + void BaseEndpointChannel::SetAnalyticsRecorder( analytics::AnalyticsRecorder* analytics_recorder, const std::string& endpoint_id) { diff --git a/connections/implementation/base_endpoint_channel.h b/connections/implementation/base_endpoint_channel.h index e5e3b71a93..9131baa989 100644 --- a/connections/implementation/base_endpoint_channel.h +++ b/connections/implementation/base_endpoint_channel.h @@ -15,6 +15,7 @@ #ifndef CORE_INTERNAL_BASE_ENDPOINT_CHANNEL_H_ #define CORE_INTERNAL_BASE_ENDPOINT_CHANNEL_H_ +#include #include #include @@ -84,6 +85,7 @@ class BaseEndpointChannel : public EndpointChannel { ABSL_LOCKS_EXCLUDED(last_read_mutex_) override; absl::Time GetLastWriteTimestamp() const ABSL_LOCKS_EXCLUDED(last_write_mutex_) override; + uint32_t GetNextKeepAliveSeqNo() const override; void SetAnalyticsRecorder(analytics::AnalyticsRecorder* analytics_recorder, const std::string& endpoint_id) override; @@ -117,6 +119,10 @@ class BaseEndpointChannel : public EndpointChannel { absl::Time last_write_timestamp_ ABSL_GUARDED_BY(last_write_mutex_) = absl::InfinitePast(); + mutable Mutex keep_alive_mutex_; + mutable uint32_t next_keep_alive_seq_no_ ABSL_GUARDED_BY(keep_alive_mutex_) = + 0; + const std::string service_id_; const std::string channel_name_; diff --git a/connections/implementation/connections_authentication_transport_test.cc b/connections/implementation/connections_authentication_transport_test.cc index c471a80158..a2293018f4 100644 --- a/connections/implementation/connections_authentication_transport_test.cc +++ b/connections/implementation/connections_authentication_transport_test.cc @@ -14,6 +14,7 @@ #include "connections/implementation/connections_authentication_transport.h" +#include #include #include #include @@ -75,6 +76,7 @@ class MockEndpointChannel : public EndpointChannel { MOCK_METHOD(void, Resume, (), (override)); MOCK_METHOD(absl::Time, GetLastReadTimestamp, (), (const, override)); MOCK_METHOD(absl::Time, GetLastWriteTimestamp, (), (const, override)); + MOCK_METHOD(uint32_t, GetNextKeepAliveSeqNo, (), (const, override)); MOCK_METHOD(void, SetAnalyticsRecorder, (analytics::AnalyticsRecorder*, const std::string&), (override)); diff --git a/connections/implementation/encryption_runner_test.cc b/connections/implementation/encryption_runner_test.cc index a691efb9e6..84b2249302 100644 --- a/connections/implementation/encryption_runner_test.cc +++ b/connections/implementation/encryption_runner_test.cc @@ -15,6 +15,7 @@ #include "connections/implementation/encryption_runner.h" #include +#include #include #include "gtest/gtest.h" @@ -103,6 +104,9 @@ class FakeEndpointChannel : public EndpointChannel { void Resume() override {} absl::Time GetLastReadTimestamp() const override { return read_timestamp_; } absl::Time GetLastWriteTimestamp() const override { return write_timestamp_; } + uint32_t GetNextKeepAliveSeqNo() const override { + return next_keep_alive_seq_no_++; + } void SetAnalyticsRecorder(analytics::AnalyticsRecorder* analytics_recorder, const std::string& endpoint_id) override {} @@ -111,6 +115,7 @@ class FakeEndpointChannel : public EndpointChannel { OutputStream* out_ = nullptr; absl::Time read_timestamp_ = absl::InfinitePast(); absl::Time write_timestamp_ = absl::InfinitePast(); + mutable uint32_t next_keep_alive_seq_no_ = 0; }; struct User { diff --git a/connections/implementation/endpoint_channel.h b/connections/implementation/endpoint_channel.h index 2de11abb68..d1b1633b77 100644 --- a/connections/implementation/endpoint_channel.h +++ b/connections/implementation/endpoint_channel.h @@ -15,9 +15,12 @@ #ifndef CORE_INTERNAL_ENDPOINT_CHANNEL_H_ #define CORE_INTERNAL_ENDPOINT_CHANNEL_H_ +#include +#include #include #include "securegcm/d2d_connection_context_v1.h" +#include "absl/time/time.h" #include "connections/implementation/analytics/analytics_recorder.h" #include "connections/implementation/analytics/packet_meta_data.h" #include "internal/platform/byte_array.h" @@ -125,13 +128,16 @@ class EndpointChannel { // writes have occurred. virtual absl::Time GetLastWriteTimestamp() const = 0; + // Returns the next sequence number to be used for a KeepAlive frame. + virtual uint32_t GetNextKeepAliveSeqNo() const = 0; + // Sets the AnalyticsRecorder instance for analytics. virtual void SetAnalyticsRecorder( analytics::AnalyticsRecorder* analytics_recorder, const std::string& endpoint_id) = 0; // Enables the multiplex socket on the EndpointChannel. - virtual bool EnableMultiplexSocket() {return false;} + virtual bool EnableMultiplexSocket() { return false; } }; inline bool operator==(const EndpointChannel& lhs, const EndpointChannel& rhs) { diff --git a/connections/implementation/endpoint_manager.cc b/connections/implementation/endpoint_manager.cc index 26236aaf36..8759968c66 100644 --- a/connections/implementation/endpoint_manager.cc +++ b/connections/implementation/endpoint_manager.cc @@ -31,6 +31,7 @@ #include "connections/implementation/endpoint_channel_manager.h" #include "connections/implementation/offline_frames.h" #include "connections/implementation/proto/offline_wire_formats.pb.h" +#include "connections/implementation/proto/offline_wire_formats.proto.h" #include "connections/implementation/service_id_constants.h" #include "connections/listeners.h" #include "connections/medium_selector.h" @@ -53,6 +54,7 @@ namespace connections { namespace { using ::location::nearby::analytics::proto::ConnectionsLog; +using ::location::nearby::connections::KeepAliveFrame; using ::location::nearby::connections::OfflineFrame; using ::location::nearby::connections::PayloadTransferFrame; using ::location::nearby::connections::V1Frame; @@ -124,8 +126,8 @@ void EndpointManager::EndpointChannelLoopRunnable( // will retry and attempt to pick another channel. // If channel is deleted (no mapping), or it is still the same channel // (same Medium) on which we got the Exception::kIo, we terminate the loop. - NEARBY_LOGS(INFO) << "Started worker loop name=" << runnable_name - << ", endpoint=" << endpoint_id; + LOG(INFO) << "Started worker loop name=" << runnable_name + << ", endpoint=" << endpoint_id; Medium last_failed_medium = Medium::UNKNOWN_MEDIUM; while (true) { // It's important to keep re-fetching the EndpointChannel for an endpoint @@ -134,7 +136,7 @@ void EndpointManager::EndpointChannelLoopRunnable( std::shared_ptr channel = channel_manager_->GetChannelForEndpoint(endpoint_id); if (channel == nullptr) { - NEARBY_LOGS(INFO) << "Endpoint channel is nullptr, bail out."; + LOG(INFO) << "Endpoint channel is nullptr, bail out."; break; } @@ -142,7 +144,7 @@ void EndpointManager::EndpointChannelLoopRunnable( // EndpointChannel for this endpoint, there's nothing more to do here. if ((last_failed_medium != Medium::UNKNOWN_MEDIUM) && (channel->GetMedium() == last_failed_medium)) { - NEARBY_LOGS(INFO) + LOG(INFO) << "No new endpoint channel is found after a failure, exit loop."; break; } @@ -157,19 +159,17 @@ void EndpointManager::EndpointChannelLoopRunnable( // detail. if (exception.Raised(Exception::kInvalidProtocolBuffer)) { last_failed_medium = channel->GetMedium(); - NEARBY_LOGS(INFO) - << "Received invalid protobuf message, re-fetching endpoint " - "channel; last_failed_medium=" - << location::nearby::proto::connections::Medium_Name( - last_failed_medium); + LOG(INFO) << "Received invalid protobuf message, re-fetching endpoint " + "channel; last_failed_medium=" + << location::nearby::proto::connections::Medium_Name( + last_failed_medium); continue; } if (exception.Raised(Exception::kIo)) { last_failed_medium = channel->GetMedium(); - NEARBY_LOGS(INFO) - << "Endpoint channel IO exception; last_failed_medium=" - << location::nearby::proto::connections::Medium_Name( - last_failed_medium); + LOG(INFO) << "Endpoint channel IO exception; last_failed_medium=" + << location::nearby::proto::connections::Medium_Name( + last_failed_medium); continue; } if (exception.Raised(Exception::kInterrupted)) { @@ -178,9 +178,9 @@ void EndpointManager::EndpointChannelLoopRunnable( } if (!keep_using_channel.result()) { - NEARBY_LOGS(INFO) << "Dropping current channel: last medium=" - << location::nearby::proto::connections::Medium_Name( - last_failed_medium); + LOG(INFO) << "Dropping current channel: last medium=" + << location::nearby::proto::connections::Medium_Name( + last_failed_medium); if (client->IsSafeToDisconnectEnabled(endpoint_id)) { channel_manager_->MarkEndpointStopWaitToDisconnect( endpoint_id, /* is_safe_to_disconnect */ false, @@ -191,13 +191,13 @@ void EndpointManager::EndpointChannelLoopRunnable( } // Indicate we're out of the loop and it is ok to schedule another instance // if needed. - NEARBY_LOGS(INFO) << "Worker going down; worker name=" << runnable_name - << "; endpoint_id=" << endpoint_id; + LOG(INFO) << "Worker going down; worker name=" << runnable_name + << "; endpoint_id=" << endpoint_id; // Always clear out all state related to this endpoint before terminating // this thread. DiscardEndpoint(client, endpoint_id, DisconnectionReason::IO_ERROR); - NEARBY_LOGS(INFO) << "Worker done; worker name=" << runnable_name - << "; endpoint_id=" << endpoint_id; + LOG(INFO) << "Worker done; worker name=" << runnable_name + << "; endpoint_id=" << endpoint_id; } ExceptionOr EndpointManager::TryDecryptFrame( @@ -215,8 +215,7 @@ ExceptionOr EndpointManager::TryDecryptFrame( } auto elapsed = SystemClock::ElapsedRealtime() - start_time; if (elapsed > kDecryptRetryTimeout) { - NEARBY_LOGS(WARNING) << "Can't decrypt the message. Timeout after " - << elapsed; + LOG(WARNING) << "Can't decrypt the message. Timeout after " << elapsed; return Exception::kTimeout; } SystemClock::Sleep(absl::Milliseconds(1)); @@ -236,8 +235,7 @@ ExceptionOr EndpointManager::HandleData( PacketMetaData packet_meta_data; ExceptionOr bytes = endpoint_channel->Read(packet_meta_data); if (!bytes.ok()) { - NEARBY_LOGS(INFO) << "Stop reading on read-time exception: " - << bytes.exception(); + LOG(INFO) << "Stop reading on read-time exception: " << bytes.exception(); return ExceptionOr(bytes.exception()); } ExceptionOr wrapped_frame = parser::FromBytes(bytes.result()); @@ -261,13 +259,12 @@ ExceptionOr EndpointManager::HandleData( if (!wrapped_frame.ok()) { if (wrapped_frame.GetException().Raised( Exception::kInvalidProtocolBuffer)) { - NEARBY_LOGS(INFO) << "Failed to decode; endpoint=" << endpoint_id - << "; channel=" << endpoint_channel->GetType() - << "; skip"; + LOG(INFO) << "Failed to decode; endpoint=" << endpoint_id + << "; channel=" << endpoint_channel->GetType() << "; skip"; continue; } else { - NEARBY_LOGS(INFO) << "Stop reading on parse-time exception: " - << wrapped_frame.exception(); + LOG(INFO) << "Stop reading on parse-time exception: " + << wrapped_frame.exception(); return ExceptionOr(wrapped_frame.exception()); } } @@ -280,14 +277,33 @@ ExceptionOr EndpointManager::HandleData( // report messages without handlers, except KEEP_ALIVE, which has // no explicit handler. if (frame_type == V1Frame::KEEP_ALIVE) { - NEARBY_LOGS(INFO) << "KeepAlive message for endpoint " << endpoint_id; + KeepAliveFrame keep_alive_frame = frame.v1().keep_alive(); + bool ack = keep_alive_frame.has_ack() ? keep_alive_frame.ack() : false; + uint32_t seq_num = + keep_alive_frame.has_seq_num() ? keep_alive_frame.seq_num() : 0; + + LOG(INFO) << "Received a KEEP_ALIVE frame (ack:" << ack + << ",seq:" << seq_num << ") from endpoint " << endpoint_id + << " on channel " << endpoint_channel->GetType() + << (ack ? "" : " and reply a KEEP_ALIVE ACK frame."); + if (!ack && !endpoint_channel->IsPaused()) { + Exception write_exception = endpoint_channel->Write( + parser::ForKeepAlive(/*ack=*/true, /*seq_num=*/seq_num)); + if (!write_exception.Ok()) { + LOG(ERROR) + << "Failed to reply KEEP_ALIVE ack frame (ack:true, seq_num:" + << seq_num << ") to endpoint " << endpoint_id << " on channel " + << endpoint_channel->GetType(); + return ExceptionOr(write_exception); + } + } } else if (frame_type == V1Frame::DISCONNECTION) { - NEARBY_LOGS(INFO) << "Disconnect message for endpoint " << endpoint_id; + LOG(INFO) << "Disconnect message from endpoint " << endpoint_id + << " on channel " << endpoint_channel->GetType(); ProcessDisconnectionFrame(client, endpoint_id, endpoint_channel, frame); } else { - NEARBY_LOGS(ERROR) << "Unhandled message: endpoint_id=" << endpoint_id - << ", frame type=" - << V1Frame::FrameType_Name(frame_type); + LOG(ERROR) << "Unhandled message: endpoint_id=" << endpoint_id + << ", frame type=" << V1Frame::FrameType_Name(frame_type); } continue; } @@ -302,10 +318,9 @@ void EndpointManager::ProcessDisconnectionFrame( ClientProxy* client, const std::string& endpoint_id, EndpointChannel* endpoint_channel, OfflineFrame& frame) { if (!client->IsSafeToDisconnectEnabled(endpoint_id)) { - NEARBY_LOGS(INFO) - << "EndpointManager received a DISCONNECTION frame from endpoint " - << endpoint_id << " on channel " << endpoint_channel->GetType() - << ", disconnecting..."; + LOG(INFO) << "EndpointManager received a DISCONNECTION frame from endpoint " + << endpoint_id << " on channel " << endpoint_channel->GetType() + << ", disconnecting..."; endpoint_channel->Close(DisconnectionReason::REMOTE_DISCONNECTION); return; } @@ -313,14 +328,14 @@ void EndpointManager::ProcessDisconnectionFrame( if (!frame.v1().has_disconnection() || !frame.v1().disconnection().has_request_safe_to_disconnect() || !frame.v1().disconnection().request_safe_to_disconnect()) { - NEARBY_LOGS(INFO) << "[safe-to-disconnect] no need to apply " - "safe-to-disconnect protocol for endpoint " - << endpoint_id << " on channel " - << endpoint_channel->GetType() << ", disconnecting..."; + LOG(INFO) << "[safe-to-disconnect] no need to apply " + "safe-to-disconnect protocol for endpoint " + << endpoint_id << " on channel " << endpoint_channel->GetType() + << ", disconnecting..."; endpoint_channel->Close(DisconnectionReason::REMOTE_DISCONNECTION); return; } - NEARBY_LOGS(INFO) + LOG(INFO) << "[safe-to-disconnect] received a " "DISCONNECTION frame with request safe to disconnect = true and ack = " << frame.v1().disconnection().ack_safe_to_disconnect() @@ -341,15 +356,15 @@ void EndpointManager::ProcessDisconnectionFrame( DisconnectionReason::REMOTE_DISCONNECTION); }); endpoint_channel->Resume(); - NEARBY_LOGS(INFO) << "[safe-to-disconnect] Sending " - "DISCONNECTION frame with request 1, ack 1"; + LOG(INFO) << "[safe-to-disconnect] Sending " + "DISCONNECTION frame with request 1, ack 1"; Exception write_exception = endpoint_channel->Write( parser::ForDisconnection(/* request_safe_to_disconnect= */ true, /* ack_safe_to_disconnect= */ true)); if (!write_exception.Ok()) { - NEARBY_LOGS(INFO) << "[safe-to-disconnect] Failed to send " - "DISCONNECTION frame with ack to endpoint" - << endpoint_id; + LOG(INFO) << "[safe-to-disconnect] Failed to send " + "DISCONNECTION frame with ack to endpoint" + << endpoint_id; } } } @@ -380,11 +395,17 @@ ExceptionOr EndpointManager::HandleKeepAlive( : last_write_time + keep_alive_interval - SystemClock::ElapsedRealtime(); if (duration_until_write_keep_alive <= absl::ZeroDuration()) { - Exception write_exception = endpoint_channel->Write(parser::ForKeepAlive()); + uint32_t seq_num = endpoint_channel->GetNextKeepAliveSeqNo(); + Exception write_exception = endpoint_channel->Write( + parser::ForKeepAlive(/*ack=*/false, /*seq_num=*/seq_num)); if (!write_exception.Ok()) { + LOG(ERROR) << "Failed to send KEEP_ALIVE frame (ack:false, seq_num:" + << seq_num << ") on channel " << endpoint_channel->GetType(); return ExceptionOr(write_exception); } duration_until_write_keep_alive = keep_alive_interval; + LOG(INFO) << "Sent a KEEP_ALIVE frame (ack:false, seq_num:" << seq_num + << ") on channel " << endpoint_channel->GetType(); } absl::Duration wait_for = @@ -423,7 +444,7 @@ EndpointManager::EndpointManager( : channel_manager_(manager), serial_executor_(std::move(serial_executor)) {} EndpointManager::~EndpointManager() { - NEARBY_LOGS(INFO) << "Initiating shutdown of EndpointManager."; + LOG(INFO) << "Initiating shutdown of EndpointManager."; { MutexLock lock(&mutex_); is_shutdown_ = true; @@ -431,33 +452,31 @@ EndpointManager::~EndpointManager() { analytics::ThroughputRecorderContainer::GetInstance().Shutdown(); CountDownLatch latch(1); RunOnEndpointManagerThread("bring-down-endpoints", [this, &latch]() { - NEARBY_LOGS(INFO) << "Bringing down endpoints"; + LOG(INFO) << "Bringing down endpoints"; endpoints_.clear(); latch.CountDown(); }); latch.Await(); - NEARBY_LOGS(INFO) << "Bringing down control thread"; + LOG(INFO) << "Bringing down control thread"; serial_executor_->Shutdown(); - NEARBY_LOGS(INFO) << "EndpointManager is down"; + LOG(INFO) << "EndpointManager is down"; } void EndpointManager::RegisterFrameProcessor( V1Frame::FrameType frame_type, EndpointManager::FrameProcessor* processor) { if (auto frame_processor = GetFrameProcessor(frame_type)) { - NEARBY_LOGS(INFO) << "EndpointManager received request to update " - "registration of frame processor " - << processor << " for frame type " - << V1Frame::FrameType_Name(frame_type) << ", self" - << this; + LOG(INFO) << "EndpointManager received request to update " + "registration of frame processor " + << processor << " for frame type " + << V1Frame::FrameType_Name(frame_type) << ", self" << this; frame_processor.set(processor); } else { MutexLock lock(&frame_processors_lock_); - NEARBY_LOGS(INFO) << "EndpointManager received request to add registration " - "of frame processor " - << processor << " for frame type " - << V1Frame::FrameType_Name(frame_type) - << ", self=" << this; + LOG(INFO) << "EndpointManager received request to add registration " + "of frame processor " + << processor << " for frame type " + << V1Frame::FrameType_Name(frame_type) << ", self=" << this; frame_processors_.emplace(frame_type, processor); } } @@ -465,26 +484,23 @@ void EndpointManager::RegisterFrameProcessor( void EndpointManager::UnregisterFrameProcessor( V1Frame::FrameType frame_type, const EndpointManager::FrameProcessor* processor) { - NEARBY_LOGS(INFO) << "UnregisterFrameProcessor [enter]: processor =" - << processor; + LOG(INFO) << "UnregisterFrameProcessor [enter]: processor =" << processor; if (processor == nullptr) return; if (auto frame_processor = GetFrameProcessor(frame_type)) { if (frame_processor.get() == processor) { frame_processor.reset(); - NEARBY_LOGS(INFO) << "EndpointManager unregister frame processor " - << processor << " for frame type " - << V1Frame::FrameType_Name(frame_type) - << ", self=" << this; + LOG(INFO) << "EndpointManager unregister frame processor " << processor + << " for frame type " << V1Frame::FrameType_Name(frame_type) + << ", self=" << this; } else { - NEARBY_LOGS(INFO) << "EndpointManager cannot unregister frame processor " - << processor - << " because it is not registered for frame type " - << V1Frame::FrameType_Name(frame_type) - << ", expected=" << frame_processor.get(); + LOG(INFO) << "EndpointManager cannot unregister frame processor " + << processor << " because it is not registered for frame type " + << V1Frame::FrameType_Name(frame_type) + << ", expected=" << frame_processor.get(); } } else { - NEARBY_LOGS(INFO) << "UnregisterFrameProcessor [not found]: processor=" - << processor; + LOG(INFO) << "UnregisterFrameProcessor [not found]: processor=" + << processor; } } @@ -502,13 +518,13 @@ void EndpointManager::RemoveEndpointState(const std::string& endpoint_id) { NEARBY_VLOG(1) << "EnsureWorkersTerminated for endpoint " << endpoint_id; auto item = endpoints_.find(endpoint_id); if (item != endpoints_.end()) { - NEARBY_LOGS(INFO) << "EndpointState found for endpoint " << endpoint_id; + LOG(INFO) << "EndpointState found for endpoint " << endpoint_id; // If another instance of data and keep-alive handlers is running, it will // terminate soon. Removing EndpointState waits for workers to complete. endpoints_.erase(item); NEARBY_VLOG(1) << "Workers terminated for endpoint " << endpoint_id; } else { - NEARBY_LOGS(INFO) << "EndpointState not found for endpoint " << endpoint_id; + LOG(INFO) << "EndpointState not found for endpoint " << endpoint_id; } } @@ -527,100 +543,98 @@ void EndpointManager::RegisterEndpoint( // which is copyalbe. We ignore the risk of job not scheduled (and an // associated risk of memory leak), because this may only happen during // service shutdown. - RunOnEndpointManagerThread("register-endpoint", [this, client, - channel = channel.release(), - &endpoint_id, &info, - &connection_options, - &listener, &connection_token, - &latch]() { - if (endpoints_.contains(endpoint_id)) { - NEARBY_LOGS(WARNING) << "Registering duplicate endpoint " << endpoint_id; - // We must remove old endpoint state before registering a new one - // for the same endpoint_id. - RemoveEndpointState(endpoint_id); - } - - absl::Duration keep_alive_interval = - absl::Milliseconds(connection_options.keep_alive_interval_millis); - absl::Duration keep_alive_timeout = - absl::Milliseconds(connection_options.keep_alive_timeout_millis); - NEARBY_LOGS(INFO) << "Registering endpoint " << endpoint_id - << " for client " << client->GetClientId() - << " with keep-alive frame as interval=" - << absl::FormatDuration(keep_alive_interval) - << ", timeout=" - << absl::FormatDuration(keep_alive_timeout); - - // Pass ownership of channel to EndpointChannelManager - NEARBY_LOGS(INFO) << "Registering endpoint with channel manager: endpoint " - << endpoint_id; - channel_manager_->RegisterChannelForEndpoint( - client, endpoint_id, std::unique_ptr(channel)); - - EndpointState& endpoint_state = - endpoints_ - .emplace(endpoint_id, EndpointState(endpoint_id, channel_manager_)) - .first->second; - - NEARBY_LOGS(INFO) << "Starting workers: endpoint " << endpoint_id; - // For every endpoint, there's normally only one Read handler instance - // running on a dedicated thread. This instance reads data from the - // endpoint and delegates incoming frames to various FrameProcessors. - // Once the frame has been properly handled, it starts reading again - // for the next frame. If the handler fails its read and no other - // EndpointChannels are available for this endpoint, a disconnection - // will be initiated. - endpoint_state.StartEndpointReader([this, client, endpoint_id]() { - EndpointChannelLoopRunnable( - "Read", client, endpoint_id, - [this, client, endpoint_id](EndpointChannel* channel) { - return HandleData(endpoint_id, client, channel); - }); - }); - - // For every endpoint, there's only one KeepAliveManager instance - // running on a dedicated thread. This instance will periodically send - // out a ping* to the endpoint while listening for an incoming pong**. - // If it fails to send the ping, or if no pong is heard within - // keep_alive_timeout, it initiates a disconnection. - // - // (*) Bluetooth requires a constant outgoing stream of messages. If - // there's silence, Android will break the socket. This is why we - // ping. - // (**) Wifi Hotspots can fail to notice a connection has been lost, - // and they will happily keep writing to /dev/null. This is why we - // listen for the pong. - NEARBY_VLOG(1) << "EndpointManager enabling KeepAlive for endpoint " - << endpoint_id; - endpoint_state.StartEndpointKeepAliveManager( - [this, client, endpoint_id, keep_alive_interval, keep_alive_timeout]( - Mutex* keep_alive_waiter_mutex, - ConditionVariable* keep_alive_waiter) { + RunOnEndpointManagerThread( + "register-endpoint", + [this, client, channel = channel.release(), &endpoint_id, &info, + &connection_options, &listener, &connection_token, &latch]() { + if (endpoints_.contains(endpoint_id)) { + LOG(WARNING) << "Registering duplicate endpoint " << endpoint_id; + // We must remove old endpoint state before registering a new one + // for the same endpoint_id. + RemoveEndpointState(endpoint_id); + } + + absl::Duration keep_alive_interval = + absl::Milliseconds(connection_options.keep_alive_interval_millis); + absl::Duration keep_alive_timeout = + absl::Milliseconds(connection_options.keep_alive_timeout_millis); + LOG(INFO) << "Registering endpoint " << endpoint_id << " for client " + << client->GetClientId() + << " with keep-alive frame as interval=" + << absl::FormatDuration(keep_alive_interval) + << ", timeout=" << absl::FormatDuration(keep_alive_timeout); + + // Pass ownership of channel to EndpointChannelManager + LOG(INFO) << "Registering endpoint with channel manager: endpoint " + << endpoint_id; + channel_manager_->RegisterChannelForEndpoint( + client, endpoint_id, std::unique_ptr(channel)); + + EndpointState& endpoint_state = + endpoints_ + .emplace(endpoint_id, + EndpointState(endpoint_id, channel_manager_)) + .first->second; + + LOG(INFO) << "Starting workers: endpoint " << endpoint_id; + // For every endpoint, there's normally only one Read handler instance + // running on a dedicated thread. This instance reads data from the + // endpoint and delegates incoming frames to various FrameProcessors. + // Once the frame has been properly handled, it starts reading again + // for the next frame. If the handler fails its read and no other + // EndpointChannels are available for this endpoint, a disconnection + // will be initiated. + endpoint_state.StartEndpointReader([this, client, endpoint_id]() { EndpointChannelLoopRunnable( - "KeepAliveManager", client, endpoint_id, - [this, keep_alive_interval, keep_alive_timeout, - keep_alive_waiter_mutex, - keep_alive_waiter](EndpointChannel* channel) { - return HandleKeepAlive( - channel, keep_alive_interval, keep_alive_timeout, - keep_alive_waiter_mutex, keep_alive_waiter); + "Read", client, endpoint_id, + [this, client, endpoint_id](EndpointChannel* channel) { + return HandleData(endpoint_id, client, channel); }); }); - NEARBY_LOGS(INFO) << "Registering endpoint " << endpoint_id - << ", workers started and notifying client."; - // It's now time to let the client know of this new connection so that - // they can accept or reject it. - client->OnConnectionInitiated(endpoint_id, info, connection_options, - listener, connection_token); - latch.CountDown(); - }); + // For every endpoint, there's only one KeepAliveManager instance + // running on a dedicated thread. This instance will periodically send + // out a ping* to the endpoint while listening for an incoming pong**. + // If it fails to send the ping, or if no pong is heard within + // keep_alive_timeout, it initiates a disconnection. + // + // (*) Bluetooth requires a constant outgoing stream of messages. If + // there's silence, Android will break the socket. This is why we + // ping. + // (**) Wifi Hotspots can fail to notice a connection has been lost, + // and they will happily keep writing to /dev/null. This is why we + // listen for the pong. + NEARBY_VLOG(1) << "EndpointManager enabling KeepAlive for endpoint " + << endpoint_id; + endpoint_state.StartEndpointKeepAliveManager( + [this, client, endpoint_id, keep_alive_interval, + keep_alive_timeout](Mutex* keep_alive_waiter_mutex, + ConditionVariable* keep_alive_waiter) { + EndpointChannelLoopRunnable( + "KeepAliveManager", client, endpoint_id, + [this, keep_alive_interval, keep_alive_timeout, + keep_alive_waiter_mutex, + keep_alive_waiter](EndpointChannel* channel) { + return HandleKeepAlive( + channel, keep_alive_interval, keep_alive_timeout, + keep_alive_waiter_mutex, keep_alive_waiter); + }); + }); + LOG(INFO) << "Registering endpoint " << endpoint_id + << ", workers started and notifying client."; + + // It's now time to let the client know of this new connection so that + // they can accept or reject it. + client->OnConnectionInitiated(endpoint_id, info, connection_options, + listener, connection_token); + latch.CountDown(); + }); latch.Await(); } void EndpointManager::UnregisterEndpoint(ClientProxy* client, const std::string& endpoint_id) { - NEARBY_LOGS(INFO) << "UnregisterEndpoint for endpoint " << endpoint_id; + LOG(INFO) << "UnregisterEndpoint for endpoint " << endpoint_id; CountDownLatch latch(1); RunOnEndpointManagerThread( "unregister-endpoint", [this, client, endpoint_id, &latch]() { @@ -664,7 +678,7 @@ std::vector EndpointManager::SendPayloadChunk( void EndpointManager::DiscardEndpoint(ClientProxy* client, const std::string& endpoint_id, DisconnectionReason reason) { - NEARBY_LOGS(INFO) << "DiscardEndpoint for endpoint " << endpoint_id; + LOG(INFO) << "DiscardEndpoint for endpoint " << endpoint_id; if (reason == DisconnectionReason::IO_ERROR) { channel_manager_->MarkEndpointStopWaitToDisconnect( endpoint_id, /* is_safe_to_disconnect */ false, @@ -737,8 +751,8 @@ std::vector EndpointManager::SendControlMessage( void EndpointManager::RemoveEndpoint(ClientProxy* client, const std::string& endpoint_id, bool notify, DisconnectionReason reason) { - NEARBY_LOGS(INFO) << "RemoveEndpoint for endpoint: " << endpoint_id - << ", reason: " << reason; + LOG(INFO) << "RemoveEndpoint for endpoint: " << endpoint_id + << ", reason: " << reason; SafeDisconnectionResult safe_disconnect_result = ConnectionsLog::EstablishedConnection::SAFE_DISCONNECTION; @@ -757,8 +771,8 @@ void EndpointManager::RemoveEndpoint(ClientProxy* client, is_safe_disconnection ? ConnectionsLog::EstablishedConnection::SAFE_DISCONNECTION : ConnectionsLog::EstablishedConnection::UNSAFE_DISCONNECTION; - NEARBY_LOGS(INFO) << "[safe-to-disconnect] safe_disconnect_result:" - << (safe_disconnect_result ? "true" : "false"); + LOG(INFO) << "[safe-to-disconnect] safe_disconnect_result:" + << (safe_disconnect_result ? "true" : "false"); } } if (safe_disconnect_result == @@ -780,7 +794,7 @@ void EndpointManager::RemoveEndpoint(ClientProxy* client, reason); client->OnDisconnected(endpoint_id, notify); - NEARBY_LOGS(INFO) << "Removed endpoint for endpoint " << endpoint_id; + LOG(INFO) << "Removed endpoint for endpoint " << endpoint_id; } RemoveEndpointState(endpoint_id); } @@ -788,8 +802,7 @@ void EndpointManager::RemoveEndpoint(ClientProxy* client, bool EndpointManager::ApplySafeToDisconnect(const std::string& endpoint_id, EndpointChannel* endpoint_channel, DisconnectionReason reason) { - NEARBY_LOGS(INFO) << "[safe-to-disconnect] ApplySafeToDisconnect reason: " - << reason; + LOG(INFO) << "[safe-to-disconnect] ApplySafeToDisconnect reason: " << reason; // TODO(b/303544913): clean up the safe-to-disconnect logic bool is_safe_disconnection = false; bool send_disconnection_frame = true; @@ -826,25 +839,24 @@ bool EndpointManager::ApplySafeToDisconnect(const std::string& endpoint_id, // If the channel was paused (i.e. during a bandwidth upgrade negotiation) // we resume to ensure the thread won't hang when trying to write to it. endpoint_channel->Resume(); - NEARBY_LOGS(INFO) << "[safe-to-disconnect] Sending " - "DISCONNECTION frame with request 1, ack 0"; + LOG(INFO) << "[safe-to-disconnect] Sending " + "DISCONNECTION frame with request 1, ack 0"; Exception write_exception = endpoint_channel->Write( parser::ForDisconnection(/* request_safe_to_disconnect= */ true, /* ack_safe_to_disconnect= */ false)); if (!write_exception.Ok()) { - NEARBY_LOGS(WARNING) << "[safe-to-disconnect] Failed to send " - "DISCONNECTION frame to endpoint" - << endpoint_id << " for reason: " << reason; + LOG(WARNING) << "[safe-to-disconnect] Failed to send " + "DISCONNECTION frame to endpoint" + << endpoint_id << " for reason: " << reason; return is_safe_disconnection; } } - NEARBY_LOGS(WARNING) << "[safe-to-disconnect] Wait for " - << (is_wait_for_ack ? "ack" : "disconnection") - << " from endpoint: " << endpoint_id - << " for reason: " << reason << ", timeout in " - << timeout_millis; + LOG(WARNING) << "[safe-to-disconnect] Wait for " + << (is_wait_for_ack ? "ack" : "disconnection") + << " from endpoint: " << endpoint_id << " for reason: " << reason + << ", timeout in " << timeout_millis; bool state = channel_manager_->CreateNewTimeoutDisconnectedState( endpoint_id, timeout_millis); if (!state) return is_safe_disconnection; @@ -857,41 +869,38 @@ bool EndpointManager::ApplySafeToDisconnect(const std::string& endpoint_id, void EndpointManager::WaitForEndpointDisconnectionProcessing( ClientProxy* client, const std::string& service_id, const std::string& endpoint_id, DisconnectionReason reason) { - NEARBY_LOGS(INFO) << "Wait: client=" << client - << "; service_id=" << service_id - << "; endpoint_id=" << endpoint_id; + LOG(INFO) << "Wait: client=" << client << "; service_id=" << service_id + << "; endpoint_id=" << endpoint_id; CountDownLatch barrier = NotifyFrameProcessorsOnEndpointDisconnect( client, service_id, endpoint_id, reason); - NEARBY_LOGS(INFO) - << "Waiting for frame processors to disconnect from endpoint " - << endpoint_id; + LOG(INFO) << "Waiting for frame processors to disconnect from endpoint " + << endpoint_id; if (!barrier.Await(kProcessEndpointDisconnectionTimeout).result()) { - NEARBY_LOGS(INFO) << "Failed to disconnect frame processors from endpoint " - << endpoint_id; + LOG(INFO) << "Failed to disconnect frame processors from endpoint " + << endpoint_id; } else { - NEARBY_LOGS(INFO) << "Finished waiting for frame processors to " - "disconnect from endpoint " - << endpoint_id; + LOG(INFO) << "Finished waiting for frame processors to " + "disconnect from endpoint " + << endpoint_id; } } CountDownLatch EndpointManager::NotifyFrameProcessorsOnEndpointDisconnect( ClientProxy* client, const std::string& service_id, const std::string& endpoint_id, DisconnectionReason reason) { - NEARBY_LOGS(INFO) << "NotifyFrameProcessorsOnEndpointDisconnect: client=" - << client << "; service_id=" << service_id - << "; endpoint_id=" << endpoint_id; + LOG(INFO) << "NotifyFrameProcessorsOnEndpointDisconnect: client=" << client + << "; service_id=" << service_id << "; endpoint_id=" << endpoint_id; MutexLock lock(&frame_processors_lock_); auto total_size = frame_processors_.size(); - NEARBY_LOGS(INFO) << "Total frame processors: " << total_size; + LOG(INFO) << "Total frame processors: " << total_size; CountDownLatch barrier(total_size); int valid = 0; for (auto& item : frame_processors_) { LockedFrameProcessor processor(&item.second); - NEARBY_LOGS(INFO) << "processor=" << processor.get() - << "; frame type=" << V1Frame::FrameType_Name(item.first); + LOG(INFO) << "processor=" << processor.get() + << "; frame type=" << V1Frame::FrameType_Name(item.first); if (processor) { valid++; processor->OnEndpointDisconnect(client, service_id, endpoint_id, barrier, @@ -902,9 +911,9 @@ CountDownLatch EndpointManager::NotifyFrameProcessorsOnEndpointDisconnect( } if (!valid) { - NEARBY_LOGS(INFO) << "No valid frame processors."; + LOG(INFO) << "No valid frame processors."; } else { - NEARBY_LOGS(INFO) << "Valid frame processors: " << valid; + LOG(INFO) << "Valid frame processors: " << valid; } return barrier; } @@ -935,11 +944,10 @@ std::vector EndpointManager::SendTransferFrameBytes( // We no longer know about this endpoint (it was either explicitly // unregistered, or a read/write error made us unregister it // internally). - NEARBY_LOGS(ERROR) << "EndpointManager failed to find EndpointChannel " - "over which to write " - << packet_type << " at offset " << offset - << " of Payload " << payload_id << " to endpoint " - << endpoint_id; + LOG(ERROR) << "EndpointManager failed to find EndpointChannel " + "over which to write " + << packet_type << " at offset " << offset << " of Payload " + << payload_id << " to endpoint " << endpoint_id; failed_endpoint_ids.push_back(endpoint_id); continue; @@ -948,7 +956,7 @@ std::vector EndpointManager::SendTransferFrameBytes( Exception write_exception = channel->Write(bytes, packet_meta_data); if (!write_exception.Ok()) { failed_endpoint_ids.push_back(endpoint_id); - NEARBY_LOGS(INFO) << "Failed to send packet; endpoint_id=" << endpoint_id; + LOG(INFO) << "Failed to send packet; endpoint_id=" << endpoint_id; continue; } analytics::ThroughputRecorderContainer::GetInstance() diff --git a/connections/implementation/endpoint_manager_test.cc b/connections/implementation/endpoint_manager_test.cc index 8c10c07f32..a3c8093c01 100644 --- a/connections/implementation/endpoint_manager_test.cc +++ b/connections/implementation/endpoint_manager_test.cc @@ -98,6 +98,7 @@ class MockEndpointChannel : public EndpointChannel { MOCK_METHOD(void, Resume, (), (override)); MOCK_METHOD(absl::Time, GetLastReadTimestamp, (), (const, override)); MOCK_METHOD(absl::Time, GetLastWriteTimestamp, (), (const, override)); + MOCK_METHOD(uint32_t, GetNextKeepAliveSeqNo, (), (const, override)); MOCK_METHOD(void, SetAnalyticsRecorder, (analytics::AnalyticsRecorder*, const std::string&), (override)); @@ -133,8 +134,8 @@ class MockFrameProcessor : public EndpointManager::FrameProcessor { class SetSafeToDisconnect { public: SetSafeToDisconnect(bool safe_to_disconnect, bool auto_reconnect, - bool payload_received_ack, - std::int32_t safe_to_disconnect_version) { + bool payload_received_ack, + std::int32_t safe_to_disconnect_version) { NearbyFlags::GetInstance().OverrideBoolFlagValue( config_package_nearby::nearby_connections_feature:: kEnableSafeToDisconnect, @@ -236,9 +237,9 @@ TEST_F(EndpointManagerTest, RegisterEndpointCallsOnConnectionInitiated) { } TEST_F(EndpointManagerTest, UnregisterEndpointCallsOnDisconnected) { -// auto endpoint_channel = std::make_unique(); -// EXPECT_CALL(*endpoint_channel, Read()) -// .WillRepeatedly(Return(ExceptionOr(Exception::kIo))); + // auto endpoint_channel = std::make_unique(); + // EXPECT_CALL(*endpoint_channel, Read()) + // .WillRepeatedly(Return(ExceptionOr(Exception::kIo))); RegisterEndpoint(std::make_unique()); // NOTE: disconnect_cb is not called, because we did not reach fully connected // state. On top of that, UnregisterEndpoint is suppressing this notification. @@ -348,8 +349,8 @@ TEST_F(EndpointManagerTest, SendControlMessageAndPayloadAckWorks) { auto failed_ids_1 = em_.SendControlMessage(header, control, std::vector{endpoint_id_}); EXPECT_EQ(failed_ids_1, std::vector{}); - auto failed_ids_2 = em_.SendPayloadAck(header.id(), - std::vector{endpoint_id_}); + auto failed_ids_2 = + em_.SendPayloadAck(header.id(), std::vector{endpoint_id_}); EXPECT_EQ(failed_ids_2, std::vector{}); NEARBY_LOGS(INFO) << "Will unregister endpoint now"; em_.UnregisterEndpoint(client_.get(), endpoint_id_); diff --git a/connections/implementation/fake_endpoint_channel.h b/connections/implementation/fake_endpoint_channel.h index ffdd0453dc..795903f19f 100644 --- a/connections/implementation/fake_endpoint_channel.h +++ b/connections/implementation/fake_endpoint_channel.h @@ -15,8 +15,12 @@ #ifndef NEARBY_CONNECTIONS_IMPLEMENTATION_FAKE_ENDPOINT_CHANNEL_H_ #define NEARBY_CONNECTIONS_IMPLEMENTATION_FAKE_ENDPOINT_CHANNEL_H_ +#include +#include #include +#include "absl/time/time.h" +#include "connections/implementation/analytics/analytics_recorder.h" #include "connections/implementation/endpoint_channel.h" #include "internal/platform/byte_array.h" #include "internal/platform/exception.h" @@ -94,6 +98,9 @@ class FakeEndpointChannel : public EndpointChannel { void Resume() override { is_paused_ = false; } absl::Time GetLastReadTimestamp() const override { return read_timestamp_; } absl::Time GetLastWriteTimestamp() const override { return write_timestamp_; } + uint32_t GetNextKeepAliveSeqNo() const override { + return next_keep_alive_seq_no_++; + } void SetAnalyticsRecorder(analytics::AnalyticsRecorder* analytics_recorder, const std::string& endpoint_id) override {} @@ -116,6 +123,7 @@ class FakeEndpointChannel : public EndpointChannel { bool is_paused_ = false; location::nearby::proto::connections::DisconnectionReason disconnection_reason_; + mutable uint32_t next_keep_alive_seq_no_ = 0; }; } // namespace connections diff --git a/connections/implementation/offline_frames.cc b/connections/implementation/offline_frames.cc index 680bae51a5..f6eb7e1066 100644 --- a/connections/implementation/offline_frames.cc +++ b/connections/implementation/offline_frames.cc @@ -19,6 +19,7 @@ #include #include +#include "connections/connection_options.h" #include "connections/implementation/flags/nearby_connections_feature_flags.h" #include "connections/implementation/internal_payload.h" #include "connections/implementation/offline_frames_validator.h" @@ -27,6 +28,7 @@ #include "connections/status.h" #include "internal/flags/nearby_flags.h" #include "internal/platform/byte_array.h" +#include "internal/platform/exception.h" namespace nearby { namespace connections { @@ -35,15 +37,16 @@ namespace { using ExceptionOrOfflineFrame = ExceptionOr<::location::nearby::connections::OfflineFrame>; +using ::location::nearby::connections::AutoReconnectFrame; using ::location::nearby::connections::BandwidthUpgradeNegotiationFrame; using ::location::nearby::connections::ConnectionRequestFrame; using ::location::nearby::connections::ConnectionResponseFrame; +using ::location::nearby::connections::KeepAliveFrame; using ::location::nearby::connections::LocationHint; using ::location::nearby::connections::OfflineFrame; using ::location::nearby::connections::OsInfo; using ::location::nearby::connections::PayloadTransferFrame; using ::location::nearby::connections::V1Frame; -using ::location::nearby::connections::AutoReconnectFrame; ByteArray ToBytes(OfflineFrame&& frame) { ByteArray bytes(frame.ByteSizeLong()); @@ -178,7 +181,7 @@ ByteArray ForConnectionResponse(std::int32_t status, const OsInfo& os_info, v1_frame->set_type(V1Frame::CONNECTION_RESPONSE); auto* sub_frame = v1_frame->mutable_connection_response(); - // For backward compatiblility, here still sets both status and response + // For backward compatibility, here still sets both status and response // parameters until the response feature is roll out in all supported // devices. sub_frame->set_status(status); @@ -474,6 +477,18 @@ ByteArray ForKeepAlive() { return ToBytes(std::move(frame)); } +ByteArray ForKeepAlive(bool ack, uint32_t seq_num) { + OfflineFrame frame; + + frame.set_version(OfflineFrame::V1); + auto* v1_frame = frame.mutable_v1(); + v1_frame->set_type(V1Frame::KEEP_ALIVE); + KeepAliveFrame* keep_alive = v1_frame->mutable_keep_alive(); + keep_alive->set_ack(ack); + keep_alive->set_seq_num(seq_num); + return ToBytes(std::move(frame)); +} + ByteArray ForDisconnection(bool request_safe_to_disconnect, bool ack_safe_to_disconnect) { OfflineFrame frame; diff --git a/connections/implementation/offline_frames.h b/connections/implementation/offline_frames.h index ccc821147b..9a80eaabe9 100644 --- a/connections/implementation/offline_frames.h +++ b/connections/implementation/offline_frames.h @@ -21,6 +21,7 @@ #include "connections/implementation/proto/offline_wire_formats.pb.h" #include "connections/connection_options.h" +#include "connections/medium_selector.h" #include "internal/platform/byte_array.h" #include "internal/platform/exception.h" @@ -101,6 +102,7 @@ ByteArray ForBwuLastWrite(); ByteArray ForBwuSafeToClose(); ByteArray ForKeepAlive(); +ByteArray ForKeepAlive(bool ack, uint32_t seq_num); ByteArray ForDisconnection(bool request_safe_to_disconnect, bool ack_safe_to_disconnect); ByteArray ForAutoReconnectIntroduction(const std::string& endpoint_id);