Skip to content

Commit

Permalink
feat(tls): Add tls handshake timeout support
Browse files Browse the repository at this point in the history
Implement timeout controls in the TLS connection process to prevent the
client from getting stuck due to the server becoming unresponsive while
handling TLS.

Refs: #2072
  • Loading branch information
honsunrise committed Jan 13, 2025
1 parent 5ad89bf commit 1d1dea5
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 8 deletions.
17 changes: 13 additions & 4 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::fmt;
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time;
use tokio_rustls::{
rustls::{
crypto,
Expand All @@ -23,6 +24,7 @@ pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
assume_http2: bool,
timeout: Option<Duration>,
}

impl TlsConnector {
Expand All @@ -34,6 +36,7 @@ impl TlsConnector {
assume_http2: bool,
#[cfg(feature = "tls-native-roots")] with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
timeout: Option<Duration>,
) -> Result<Self, crate::BoxError> {
fn with_provider(
provider: Arc<crypto::CryptoProvider>,
Expand Down Expand Up @@ -92,16 +95,22 @@ impl TlsConnector {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
assume_http2,
timeout,
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;
let conn_fut =
RustlsConnector::from(self.config.clone()).connect(self.domain.as_ref().to_owned(), io);
let io = match self.timeout {
Some(timeout) => time::timeout(timeout, conn_fut)
.await
.map_err(|_| TlsError::HandshakeTimeout)?,
None => conn_fut.await,
}?;

// Generally we require ALPN to be negotiated, but if the user has
// explicitly set `assume_http2` to true, we'll allow it to be missing.
Expand Down
11 changes: 11 additions & 0 deletions tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::transport::{
Error,
};
use http::Uri;
use std::time::Duration;
use tokio_rustls::rustls::pki_types::TrustAnchor;

/// Configures TLS settings for endpoints.
Expand All @@ -18,6 +19,7 @@ pub struct ClientTlsConfig {
with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")]
with_webpki_roots: bool,
timeout: Option<Duration>,
}

impl ClientTlsConfig {
Expand Down Expand Up @@ -112,6 +114,14 @@ impl ClientTlsConfig {
config
}

/// Sets the timeout for the TLS handshake.
pub fn timeout(self, timeout: Duration) -> Self {
ClientTlsConfig {
timeout: Some(timeout),
..self
}
}

pub(crate) fn into_tls_connector(self, uri: &Uri) -> Result<TlsConnector, crate::BoxError> {
let domain = match &self.domain {
Some(domain) => domain,
Expand All @@ -127,6 +137,7 @@ impl ClientTlsConfig {
self.with_native_roots,
#[cfg(feature = "tls-webpki-roots")]
self.with_webpki_roots,
self.timeout,
)
}
}
19 changes: 16 additions & 3 deletions tonic/src/transport/server/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use std::{fmt, sync::Arc};
use std::{fmt, sync::Arc, time::Duration};

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time;
use tokio_rustls::{
rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
server::TlsStream,
TlsAcceptor as RustlsAcceptor,
};

use crate::transport::{
service::tls::{convert_certificate_to_pki_types, convert_identity_to_pki_types, ALPN_H2},
service::tls::{
convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2,
},
Certificate, Identity,
};

#[derive(Clone)]
pub(crate) struct TlsAcceptor {
inner: Arc<ServerConfig>,
timeout: Option<Duration>,
}

impl TlsAcceptor {
Expand All @@ -23,6 +27,7 @@ impl TlsAcceptor {
client_ca_root: Option<&Certificate>,
client_auth_optional: bool,
ignore_client_order: bool,
timeout: Option<Duration>,
) -> Result<Self, crate::BoxError> {
let builder = ServerConfig::builder();

Expand All @@ -48,6 +53,7 @@ impl TlsAcceptor {
config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
inner: Arc::new(config),
timeout,
})
}

Expand All @@ -56,7 +62,14 @@ impl TlsAcceptor {
IO: AsyncRead + AsyncWrite + Unpin,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
acceptor.accept(io).await.map_err(Into::into)
let accept_fut = acceptor.accept(io);
match self.timeout {
Some(timeout) => time::timeout(timeout, accept_fut)
.await
.map_err(|_| TlsError::HandshakeTimeout)?,
None => accept_fut.await,
}
.map_err(Into::into)
}
}

Expand Down
15 changes: 14 additions & 1 deletion tonic/src/transport/server/tls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt;
use std::{fmt, time::Duration};

use super::service::TlsAcceptor;
use crate::transport::tls::{Certificate, Identity};
Expand All @@ -10,6 +10,7 @@ pub struct ServerTlsConfig {
client_ca_root: Option<Certificate>,
client_auth_optional: bool,
ignore_client_order: bool,
timeout: Option<Duration>,
}

impl fmt::Debug for ServerTlsConfig {
Expand Down Expand Up @@ -64,12 +65,24 @@ impl ServerTlsConfig {
}
}

/// Sets the timeout for the TLS handshake.
///
/// # Default
/// By default, this option is set to `None`.
pub fn timeout(self, timeout: Duration) -> Self {
ServerTlsConfig {
timeout: Some(timeout),
..self
}
}

pub(crate) fn tls_acceptor(&self) -> Result<TlsAcceptor, crate::BoxError> {
TlsAcceptor::new(
self.identity.as_ref().unwrap(),
self.client_ca_root.as_ref(),
self.client_auth_optional,
self.ignore_client_order,
self.timeout,
)
}
}
2 changes: 2 additions & 0 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub(crate) enum TlsError {
NativeCertsNotFound,
CertificateParseError,
PrivateKeyParseError,
HandshakeTimeout,
}

impl fmt::Display for TlsError {
Expand All @@ -29,6 +30,7 @@ impl fmt::Display for TlsError {
f,
"Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
),
TlsError::HandshakeTimeout => write!(f, "TLS handshake timeout."),
}
}
}
Expand Down

0 comments on commit 1d1dea5

Please sign in to comment.