diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 33fe76af1..9e93008e9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -36,20 +36,22 @@ tls-aws-lc = ["_tls-any", "tokio-rustls/aws-lc-rs"] tls-native-roots = ["_tls-any", "channel", "dep:rustls-native-certs"] tls-webpki-roots = ["_tls-any","channel", "dep:webpki-roots"] router = ["dep:axum", "dep:tower", "tower?/util"] +timeout = ["dep:tokio", "tokio?/time"] server = [ + "timeout", "dep:h2", "dep:hyper", "hyper?/server", "dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto", "dep:socket2", - "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", + "dep:tokio", "tokio?/macros", "tokio?/net", "tokio-stream/net", "dep:tower", "tower?/util", "tower?/limit", ] channel = [ + "timeout", "dep:hyper", "hyper?/client", "dep:hyper-util", "hyper-util?/client-legacy", "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util", - "dep:tokio", "tokio?/time", "dep:hyper-timeout", ] transport = ["server", "channel"] diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 3dcbd2108..e1ca38140 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -22,6 +22,8 @@ //! - `server`: Enables just the full featured server portion of the `transport` feature. //! - `channel`: Enables just the full featured channel portion of the `transport` feature. //! - `router`: Enables the [`axum`] based service router. Enabled by default. +//! - `timeout`: Enables timeout related feature including `GrpcTimeout` middleware. Enabled +//! by default. //! - `codegen`: Enables all the required exports and optional dependencies required //! for [`tonic-build`]. Enabled by default. //! - `tls-ring`: Enables the [`rustls`] based TLS options for the `transport` feature using diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/service/grpc_timeout.rs similarity index 88% rename from tonic/src/transport/service/grpc_timeout.rs rename to tonic/src/service/grpc_timeout.rs index 019a37a2f..320ed4a80 100644 --- a/tonic/src/transport/service/grpc_timeout.rs +++ b/tonic/src/service/grpc_timeout.rs @@ -1,23 +1,50 @@ +//! Middleware which implements gRPC timeout. + use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired}; use http::{HeaderMap, HeaderValue, Request}; use pin_project::pin_project; use std::{ + fmt, future::Future, pin::Pin, task::{ready, Context, Poll}, time::Duration, }; use tokio::time::Sleep; +use tower_layer::Layer; use tower_service::Service; +/// Layer which applies the [`GrpcTimeout`] middleware. #[derive(Debug, Clone)] -pub(crate) struct GrpcTimeout { +pub struct GrpcTimeoutLayer { + server_timeout: Option, +} + +impl Layer for GrpcTimeoutLayer { + type Service = GrpcTimeout; + + fn layer(&self, inner: S) -> Self::Service { + GrpcTimeout::new(inner, self.server_timeout) + } +} + +impl GrpcTimeoutLayer { + /// Create a new `GrpcTimeoutLayer`. + pub fn new(server_timeout: Option) -> Self { + Self { server_timeout } + } +} + +/// Middleware which implements gRPC timeout. +#[derive(Debug, Clone)] +pub struct GrpcTimeout { inner: S, server_timeout: Option, } impl GrpcTimeout { - pub(crate) fn new(inner: S, server_timeout: Option) -> Self { + /// Create a new [`GrpcTimeout`] middleware. + pub fn new(inner: S, server_timeout: Option) -> Self { Self { inner, server_timeout, @@ -62,14 +89,21 @@ where } } +/// Response future for [`GrpcTimeout`]. #[pin_project] -pub(crate) struct ResponseFuture { +pub struct ResponseFuture { #[pin] inner: F, #[pin] sleep: Option, } +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + impl Future for ResponseFuture where F: Future>, diff --git a/tonic/src/service/mod.rs b/tonic/src/service/mod.rs index f1e860637..f34256a8c 100644 --- a/tonic/src/service/mod.rs +++ b/tonic/src/service/mod.rs @@ -16,3 +16,8 @@ pub use axum::{body::Body as AxumBody, Router as AxumRouter}; pub mod recover_error; pub use self::recover_error::{RecoverError, RecoverErrorLayer}; + +#[cfg(feature = "timeout")] +pub mod grpc_timeout; +#[cfg(feature = "timeout")] +pub use self::grpc_timeout::{GrpcTimeout, GrpcTimeoutLayer}; diff --git a/tonic/src/transport/channel/service/connection.rs b/tonic/src/transport/channel/service/connection.rs index 4e84ac92e..27d31a894 100644 --- a/tonic/src/transport/channel/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -1,7 +1,8 @@ use super::{AddOrigin, Reconnect, SharedExec, UserAgent}; use crate::{ body::Body, - transport::{channel::BoxFuture, service::GrpcTimeout, Endpoint}, + service::GrpcTimeoutLayer, + transport::{channel::BoxFuture, Endpoint}, }; use http::{Request, Response, Uri}; use hyper::rt; @@ -62,7 +63,7 @@ impl Connection { AddOrigin::new(s, origin) }) .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) - .layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout)) + .layer(GrpcTimeoutLayer::new(endpoint.timeout)) .option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new)) .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 2427065d6..9923119ce 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -42,9 +42,8 @@ pub use incoming::TcpIncoming; use crate::transport::Error; use self::service::{ConnectInfoLayer, ServerIo}; -use super::service::GrpcTimeout; use crate::body::Body; -use crate::service::RecoverErrorLayer; +use crate::service::{GrpcTimeoutLayer, RecoverErrorLayer}; use bytes::Bytes; use http::{Request, Response}; use http_body_util::BodyExt; @@ -1090,7 +1089,7 @@ where let svc = ServiceBuilder::new() .layer(RecoverErrorLayer::new()) .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new)) - .layer_fn(|s| GrpcTimeout::new(s, timeout)) + .layer(GrpcTimeoutLayer::new(timeout)) .service(svc); let svc = ServiceBuilder::new() diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index b41869c7c..a0ee9c3bc 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,5 +1,2 @@ -pub(crate) mod grpc_timeout; #[cfg(feature = "_tls-any")] pub(crate) mod tls; - -pub(crate) use self::grpc_timeout::GrpcTimeout;