Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expose GrpcTimeout middleware #2162

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,22 @@ tls-aws-lc = ["_tls-any", "tokio-rustls/aws-lc-rs"]
tls-native-roots = ["_tls-any", "channel", "dep:rustls-native-certs"]
tls-webpki-roots = ["_tls-any","channel", "dep:webpki-roots"]
router = ["dep:axum", "dep:tower", "tower?/util"]
timeout = ["dep:tokio", "tokio?/time"]
server = [
"timeout",
"dep:h2",
"dep:hyper", "hyper?/server",
"dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto",
"dep:socket2",
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
"dep:tokio", "tokio?/macros", "tokio?/net",
"tokio-stream/net",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
"timeout",
"dep:hyper", "hyper?/client",
"dep:hyper-util", "hyper-util?/client-legacy",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
"dep:tokio", "tokio?/time",
"dep:hyper-timeout",
]
transport = ["server", "channel"]
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
//! - `server`: Enables just the full featured server portion of the `transport` feature.
//! - `channel`: Enables just the full featured channel portion of the `transport` feature.
//! - `router`: Enables the [`axum`] based service router. Enabled by default.
//! - `timeout`: Enables timeout related feature including `GrpcTimeout` middleware. Enabled
//! by default.
//! - `codegen`: Enables all the required exports and optional dependencies required
//! for [`tonic-build`]. Enabled by default.
//! - `tls-ring`: Enables the [`rustls`] based TLS options for the `transport` feature using
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
//! Middleware which implements gRPC timeout.

use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
use http::{HeaderMap, HeaderValue, Request};
use pin_project::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
time::Duration,
};
use tokio::time::Sleep;
use tower_layer::Layer;
use tower_service::Service;

/// Layer which applies the [`GrpcTimeout`] middleware.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
pub struct GrpcTimeoutLayer {
server_timeout: Option<Duration>,
}

impl<S> Layer<S> for GrpcTimeoutLayer {
type Service = GrpcTimeout<S>;

fn layer(&self, inner: S) -> Self::Service {
GrpcTimeout::new(inner, self.server_timeout)
}
}

impl GrpcTimeoutLayer {
/// Create a new `GrpcTimeoutLayer`.
pub fn new(server_timeout: Option<Duration>) -> Self {
Self { server_timeout }
}
}

/// Middleware which implements gRPC timeout.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
inner: S,
server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
/// Create a new [`GrpcTimeout`] middleware.
pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
Self {
inner,
server_timeout,
Expand Down Expand Up @@ -62,14 +89,21 @@ where
}
}

/// Response future for [`GrpcTimeout`].
#[pin_project]
pub(crate) struct ResponseFuture<F> {
pub struct ResponseFuture<F> {
#[pin]
inner: F,
#[pin]
sleep: Option<Sleep>,
}

impl<F> fmt::Debug for ResponseFuture<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}

impl<F, Res, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Res, E>>,
Expand Down
5 changes: 5 additions & 0 deletions tonic/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ pub use axum::{body::Body as AxumBody, Router as AxumRouter};

pub mod recover_error;
pub use self::recover_error::{RecoverError, RecoverErrorLayer};

#[cfg(feature = "timeout")]
pub mod grpc_timeout;
#[cfg(feature = "timeout")]
pub use self::grpc_timeout::{GrpcTimeout, GrpcTimeoutLayer};
5 changes: 3 additions & 2 deletions tonic/src/transport/channel/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
use crate::{
body::Body,
transport::{channel::BoxFuture, service::GrpcTimeout, Endpoint},
service::GrpcTimeoutLayer,
transport::{channel::BoxFuture, Endpoint},
};
use http::{Request, Response, Uri};
use hyper::rt;
Expand Down Expand Up @@ -62,7 +63,7 @@ impl Connection {
AddOrigin::new(s, origin)
})
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.layer(GrpcTimeoutLayer::new(endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
.into_inner();
Expand Down
5 changes: 2 additions & 3 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ pub use incoming::TcpIncoming;
use crate::transport::Error;

use self::service::{ConnectInfoLayer, ServerIo};
use super::service::GrpcTimeout;
use crate::body::Body;
use crate::service::RecoverErrorLayer;
use crate::service::{GrpcTimeoutLayer, RecoverErrorLayer};
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
Expand Down Expand Up @@ -1090,7 +1089,7 @@ where
let svc = ServiceBuilder::new()
.layer(RecoverErrorLayer::new())
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.layer(GrpcTimeoutLayer::new(timeout))
.service(svc);

let svc = ServiceBuilder::new()
Expand Down
3 changes: 0 additions & 3 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
pub(crate) mod grpc_timeout;
#[cfg(feature = "_tls-any")]
pub(crate) mod tls;

pub(crate) use self::grpc_timeout::GrpcTimeout;