From f53b6c336dd4886db0363dd64129b3a90f953260 Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:39:11 +0200 Subject: [PATCH] fix: Propagate error when building a HTTP request - Add error::Error variant - Add test for an invalid request --- CHANGELOG.md | 2 +- src/client.rs | 61 +++++++++++++++++++++++++++++---------------------- src/error.rs | 4 ++++ 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index beb4467..97215b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Unreleased - - fix: Avoid panics when acquiring a RwLock<_> + - fix: Avoid panics when acquiring a RwLock<_> and when building HTTP requests ## v0.6.2 diff --git a/src/client.rs b/src/client.rs index 2e80818..b409228 100644 --- a/src/client.rs +++ b/src/client.rs @@ -111,7 +111,7 @@ impl Client { /// See [ErrorReason](enum.ErrorReason.html) for possible errors. #[cfg_attr(feature = "tracing", ::tracing::instrument)] pub async fn send(&self, payload: T) -> Result { - let request = self.build_request(payload); + let request = self.build_request(payload)?; let requesting = self.http_client.request(request); let response = requesting.await?; @@ -140,7 +140,7 @@ impl Client { } } - fn build_request(&self, payload: T) -> hyper::Request { + fn build_request(&self, payload: T) -> Result, Error> { let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token()); let mut builder = hyper::Request::builder() @@ -169,17 +169,16 @@ impl Client { } if let Some(ref signer) = self.signer { let auth = signer - .with_signature(|signature| format!("Bearer {}", signature)) - .unwrap(); + .with_signature(|signature| format!("Bearer {}", signature))?; builder = builder.header(AUTHORIZATION, auth.as_bytes()); } - let payload_json = payload.to_json_string().unwrap(); + let payload_json = payload.to_json_string()?; builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes()); let request_body = Body::from(payload_json); - builder.body(request_body).unwrap() + builder.body(request_body).map_err(Error::BuildRequestError) } } @@ -206,7 +205,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri); @@ -217,7 +216,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Sandbox); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri); @@ -228,17 +227,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); } + #[test] + fn test_request_invalid() { + let builder = DefaultNotificationBuilder::new(); + let payload = builder.build("\r\n", Default::default()); + let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let request = client.build_request(payload); + + assert!(matches!(request, Err(Error::BuildRequestError(_)))); + } + #[test] fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); } @@ -248,7 +257,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload.clone()); + let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -260,7 +269,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); } @@ -278,7 +287,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), Some(signer), Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); } @@ -292,7 +301,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }; let payload = builder.build("a_test_id", options); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); assert_eq!("background", apns_push_type); @@ -303,7 +312,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); assert_eq!(None, apns_priority); @@ -322,7 +331,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); assert_eq!("5", apns_priority); @@ -341,7 +350,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); assert_eq!("10", apns_priority); @@ -354,7 +363,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); assert_eq!(None, apns_id); @@ -373,7 +382,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); assert_eq!("a-test-apns-id", apns_id); @@ -386,7 +395,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); assert_eq!(None, apns_expiration); @@ -405,7 +414,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); assert_eq!("420", apns_expiration); @@ -418,7 +427,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); assert_eq!(None, apns_collapse_id); @@ -437,7 +446,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); assert_eq!("a_collapse_id", apns_collapse_id); @@ -450,7 +459,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); assert_eq!(None, apns_topic); @@ -469,7 +478,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); assert_eq!("a_topic", apns_topic); @@ -480,7 +489,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); - let request = client.build_request(payload.clone()); + let request = client.build_request(payload.clone()).unwrap(); let body = hyper::body::to_bytes(request).await.unwrap(); let body_str = String::from_utf8(body.to_vec()).unwrap(); diff --git a/src/error.rs b/src/error.rs index af7ffb1..a583d83 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,6 +38,10 @@ pub enum Error { #[error("Error in reading a certificate file: {0}")] ReadError(#[from] io::Error), + /// Error while creating the HTTP request + #[error("Failed to construct HTTP request: {0}")] + BuildRequestError(#[source] http::Error), + /// Unexpected private key (only EC keys are supported). #[cfg(all(not(feature = "openssl"), feature = "ring"))] #[error("Unexpected private key: {0}")]