Skip to content

Commit

Permalink
streaming body decompression
Browse files Browse the repository at this point in the history
Client request body was fully loaded in memory before decompressing, we
are now decompressing it as it goes
  • Loading branch information
Geal authored and abernix committed Mar 6, 2024
1 parent e06cf5f commit 9e9527c
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 108 deletions.
47 changes: 34 additions & 13 deletions apollo-router/src/axum_factory/axum_http_server_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Instant;

use axum::error_handling::HandleErrorLayer;
use axum::extract::Extension;
use axum::extract::State;
use axum::http::StatusCode;
Expand All @@ -32,14 +33,15 @@ use tokio::sync::mpsc;
use tokio_rustls::TlsAcceptor;
use tower::service_fn;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
use tower_http::decompression::DecompressionBody;
use tower_http::trace::TraceLayer;

use super::listeners::ensure_endpoints_consistency;
use super::listeners::ensure_listenaddrs_consistency;
use super::listeners::extra_endpoints;
use super::listeners::ListenersAndRouters;
use super::utils::decompress_request_body;
use super::utils::PropagatingMakeSpan;
use super::ListenAddrAndRouter;
use super::ENDPOINT_CALLBACK;
Expand All @@ -57,6 +59,7 @@ use crate::plugins::traffic_shaping::RateLimited;
use crate::router::ApolloRouterError;
use crate::router_factory::Endpoint;
use crate::router_factory::RouterFactory;
use crate::services::http::service::BodyStream;
use crate::services::router;
use crate::uplink::license_enforcement::LicenseState;
use crate::uplink::license_enforcement::APOLLO_ROUTER_LICENSE_EXPIRED;
Expand Down Expand Up @@ -173,11 +176,9 @@ where
tracing::trace!(?health, request = ?req.router_request, "health check");
async move {
Ok(router::Response {
response: http::Response::builder()
.status(status_code)
.body::<hyper::Body>(
serde_json::to_vec(&health).map_err(BoxError::from)?.into(),
)?,
response: http::Response::builder().status(status_code).body::<Body>(
serde_json::to_vec(&health).map_err(BoxError::from)?.into(),
)?,
context: req.context,
})
}
Expand Down Expand Up @@ -422,6 +423,10 @@ pub(crate) fn span_mode(configuration: &Configuration) -> SpanMode {
.unwrap_or_default()
}

async fn decompression_error(_error: BoxError) -> axum::response::Response {
(StatusCode::BAD_REQUEST, "cannot decompress request body").into_response()
}

fn main_endpoint<RF>(
service_factory: RF,
configuration: &Configuration,
Expand All @@ -436,8 +441,16 @@ where
})?;
let span_mode = span_mode(configuration);

let decompression = ServiceBuilder::new()
.layer(HandleErrorLayer::<_, ()>::new(decompression_error))
.layer(
tower_http::decompression::RequestDecompressionLayer::new()
.br(true)
.gzip(true)
.deflate(true),
);
let mut main_route = main_router::<RF>(configuration)
.layer(middleware::from_fn(decompress_request_body))
.layer(decompression)
.layer(middleware::from_fn_with_state(
(license, Instant::now(), Arc::new(AtomicU64::new(0))),
license_handler,
Expand Down Expand Up @@ -530,19 +543,21 @@ async fn license_handler<B>(
}
}

pub(super) fn main_router<RF>(configuration: &Configuration) -> axum::Router
pub(super) fn main_router<RF>(
configuration: &Configuration,
) -> axum::Router<(), DecompressionBody<Body>>
where
RF: RouterFactory,
{
let mut router = Router::new().route(
&configuration.supergraph.sanitized_path(),
get({
move |Extension(service): Extension<RF>, request: Request<Body>| {
move |Extension(service): Extension<RF>, request: Request<DecompressionBody<Body>>| {
handle_graphql(service.create().boxed(), request)
}
})
.post({
move |Extension(service): Extension<RF>, request: Request<Body>| {
move |Extension(service): Extension<RF>, request: Request<DecompressionBody<Body>>| {
handle_graphql(service.create().boxed(), request)
}
}),
Expand All @@ -552,12 +567,14 @@ where
router = router.route(
"/",
get({
move |Extension(service): Extension<RF>, request: Request<Body>| {
move |Extension(service): Extension<RF>,
request: Request<DecompressionBody<Body>>| {
handle_graphql(service.create().boxed(), request)
}
})
.post({
move |Extension(service): Extension<RF>, request: Request<Body>| {
move |Extension(service): Extension<RF>,
request: Request<DecompressionBody<Body>>| {
handle_graphql(service.create().boxed(), request)
}
}),
Expand All @@ -569,10 +586,14 @@ where

async fn handle_graphql(
service: router::BoxService,
http_request: Request<Body>,
http_request: Request<DecompressionBody<Body>>,
) -> impl IntoResponse {
let _guard = SessionCountGuard::start();

let (parts, body) = http_request.into_parts();

let http_request = http::Request::from_parts(parts, Body::wrap_stream(BodyStream::new(body)));

let request: router::Request = http_request.into();
let context = request.context.clone();
let accept_encoding = request
Expand Down
88 changes: 0 additions & 88 deletions apollo-router/src/axum_factory/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,8 @@
use std::net::SocketAddr;

use async_compression::tokio::write::BrotliDecoder;
use async_compression::tokio::write::GzipDecoder;
use async_compression::tokio::write::ZlibDecoder;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::*;
use futures::prelude::*;
use http::header::CONTENT_ENCODING;
use http::Request;
use hyper::Body;
use opentelemetry::global;
use opentelemetry::trace::TraceContextExt;
use tokio::io::AsyncWriteExt;
use tower_http::trace::MakeSpan;
use tower_service::Service;
use tracing::Span;
Expand All @@ -26,83 +15,6 @@ use crate::uplink::license_enforcement::LICENSE_EXPIRED_SHORT_MESSAGE;

pub(crate) const REQUEST_SPAN_NAME: &str = "request";

pub(super) async fn decompress_request_body(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response, Response> {
let (parts, body) = req.into_parts();
let content_encoding = parts.headers.get(&CONTENT_ENCODING);
macro_rules! decode_body {
($decoder: ident, $error_message: expr) => {{
let body_bytes = hyper::body::to_bytes(body)
.map_err(|err| {
(
StatusCode::BAD_REQUEST,
format!("cannot read request body: {err}"),
)
.into_response()
})
.await?;
let mut decoder = $decoder::new(Vec::new());
decoder.write_all(&body_bytes).await.map_err(|err| {
(
StatusCode::BAD_REQUEST,
format!("{}: {err}", $error_message),
)
.into_response()
})?;
decoder.shutdown().await.map_err(|err| {
(
StatusCode::BAD_REQUEST,
format!("{}: {err}", $error_message),
)
.into_response()
})?;

Ok(next
.run(Request::from_parts(parts, Body::from(decoder.into_inner())))
.await)
}};
}

match content_encoding {
Some(content_encoding) => match content_encoding.to_str() {
Ok(content_encoding_str) => match content_encoding_str {
"br" => decode_body!(BrotliDecoder, "cannot decompress (brotli) request body"),
"gzip" => decode_body!(GzipDecoder, "cannot decompress (gzip) request body"),
"deflate" => decode_body!(ZlibDecoder, "cannot decompress (deflate) request body"),
"identity" => Ok(next.run(Request::from_parts(parts, body)).await),
unknown => {
let message = format!("unknown content-encoding header value {unknown:?}");
tracing::error!(message);
u64_counter!(
"apollo_router_http_requests_total",
"Total number of HTTP requests made.",
1,
status = StatusCode::BAD_REQUEST.as_u16() as i64,
error = message.clone()
);

Err((StatusCode::BAD_REQUEST, message).into_response())
}
},

Err(err) => {
let message = format!("cannot read content-encoding header: {err}");
u64_counter!(
"apollo_router_http_requests_total",
"Total number of HTTP requests made.",
1,
status = 400,
error = message.clone()
);
Err((StatusCode::BAD_REQUEST, message).into_response())
}
},
None => Ok(next.run(Request::from_parts(parts, body)).await),
}
}

#[derive(Clone, Default)]
pub(crate) struct PropagatingMakeSpan {
pub(crate) license: LicenseState,
Expand Down
7 changes: 7 additions & 0 deletions apollo-router/src/services/http/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ pin_project! {
}
}

impl<B: hyper::body::HttpBody> BodyStream<B> {
/// Create a new `BodyStream`.
pub(crate) fn new(body: DecompressionBody<B>) -> Self {
Self { inner: body }
}
}

impl<B> Stream for BodyStream<B>
where
B: hyper::body::HttpBody,
Expand Down
11 changes: 11 additions & 0 deletions apollo-router/src/services/layers/content_negotiation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ where
.to_string(),
))
.expect("cannot fail");
u64_counter!(
"apollo_router_http_requests_total",
"Total number of HTTP requests made.",
1,
status = StatusCode::UNSUPPORTED_MEDIA_TYPE.as_u16() as i64,
error = format!(
r#"'content-type' header must be one of: {:?} or {:?}"#,
APPLICATION_JSON.essence_str(),
GRAPHQL_JSON_RESPONSE_HEADER_VALUE,
)
);

return Ok(ControlFlow::Break(response.into()));
}
Expand Down
11 changes: 6 additions & 5 deletions apollo-router/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::time::SystemTime;

use buildstructor::buildstructor;
use http::header::ACCEPT;
use http::header::CONTENT_ENCODING;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use jsonpath_lib::Selector;
Expand Down Expand Up @@ -378,7 +377,7 @@ impl IntegrationTest {
}

#[allow(dead_code)]
pub fn execute_bad_content_encoding(
pub fn execute_bad_content_type(
&self,
) -> impl std::future::Future<Output = (String, reqwest::Response)> {
self.execute_query_internal(&json!({"garbage":{}}), Some("garbage"))
Expand All @@ -387,7 +386,7 @@ impl IntegrationTest {
fn execute_query_internal(
&self,
query: &Value,
content_encoding: Option<&'static str>,
content_type: Option<&'static str>,
) -> impl std::future::Future<Output = (String, reqwest::Response)> {
assert!(
self.router.is_some(),
Expand All @@ -404,8 +403,10 @@ impl IntegrationTest {

let mut request = client
.post("http://localhost:4000")
.header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
.header(CONTENT_ENCODING, content_encoding.unwrap_or("identity"))
.header(
CONTENT_TYPE,
content_type.unwrap_or(APPLICATION_JSON.essence_str()),
)
.header("apollographql-client-name", "custom_name")
.header("apollographql-client-version", "1.0")
.header("x-my-header", "test")
Expand Down
5 changes: 3 additions & 2 deletions apollo-router/tests/telemetry/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ async fn test_bad_queries() {
None,
)
.await;
router.execute_bad_content_encoding().await;
router.execute_bad_content_type().await;

router
.assert_metrics_contains(
r#"apollo_router_http_requests_total{error="unknown content-encoding header value \"garbage\"",status="400",otel_scope_name="apollo/router"}"#,
r#"apollo_router_http_requests_total{error="'content-type' header must be one of: \"application/json\" or \"application/graphql-response+json\"",status="415",otel_scope_name="apollo/router"}"#,
None,
)
.await;
Expand Down

0 comments on commit 9e9527c

Please sign in to comment.