Skip to content

Commit

Permalink
fix: Propagate error when building a HTTP request
Browse files Browse the repository at this point in the history
- Add error::Error variant
- Add test for an invalid request
  • Loading branch information
threema-donat committed Apr 24, 2024
1 parent e680ba6 commit f53b6c3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Unreleased

- fix: Avoid panics when acquiring a RwLock<_>
- fix: Avoid panics when acquiring a RwLock<_> and when building HTTP requests

## v0.6.2

Expand Down
61 changes: 35 additions & 26 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl Client {
/// See [ErrorReason](enum.ErrorReason.html) for possible errors.
#[cfg_attr(feature = "tracing", ::tracing::instrument)]
pub async fn send<T: PayloadLike>(&self, payload: T) -> Result<Response, Error> {
let request = self.build_request(payload);
let request = self.build_request(payload)?;
let requesting = self.http_client.request(request);

let response = requesting.await?;
Expand Down Expand Up @@ -140,7 +140,7 @@ impl Client {
}
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<Body> {
fn build_request<T: PayloadLike>(&self, payload: T) -> Result<hyper::Request<Body>, Error> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());

let mut builder = hyper::Request::builder()
Expand Down Expand Up @@ -169,17 +169,16 @@ impl Client {
}
if let Some(ref signer) = self.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
.with_signature(|signature| format!("Bearer {}", signature))?;

builder = builder.header(AUTHORIZATION, auth.as_bytes());
}

let payload_json = payload.to_json_string().unwrap();
let payload_json = payload.to_json_string()?;
builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes());

let request_body = Body::from(payload_json);
builder.body(request_body).unwrap()
builder.body(request_body).map_err(Error::BuildRequestError)
}
}

Expand All @@ -206,7 +205,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -217,7 +216,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Sandbox);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -228,17 +227,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(&Method::POST, request.method());
}

#[test]
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);

assert!(matches!(request, Err(Error::BuildRequestError(_))));
}

#[test]
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
}
Expand All @@ -248,7 +257,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();

Expand All @@ -260,7 +269,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -278,7 +287,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), Some(signer), Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_ne!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -292,7 +301,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
};
let payload = builder.build("a_test_id", options);
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();

assert_eq!("background", apns_push_type);
Expand All @@ -303,7 +312,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");

assert_eq!(None, apns_priority);
Expand All @@ -322,7 +331,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("5", apns_priority);
Expand All @@ -341,7 +350,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("10", apns_priority);
Expand All @@ -354,7 +363,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");

assert_eq!(None, apns_id);
Expand All @@ -373,7 +382,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();

assert_eq!("a-test-apns-id", apns_id);
Expand All @@ -386,7 +395,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");

assert_eq!(None, apns_expiration);
Expand All @@ -405,7 +414,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();

assert_eq!("420", apns_expiration);
Expand All @@ -418,7 +427,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");

assert_eq!(None, apns_collapse_id);
Expand All @@ -437,7 +446,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

assert_eq!("a_collapse_id", apns_collapse_id);
Expand All @@ -450,7 +459,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");

assert_eq!(None, apns_topic);
Expand All @@ -469,7 +478,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();

assert_eq!("a_topic", apns_topic);
Expand All @@ -480,7 +489,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();

let body = hyper::body::to_bytes(request).await.unwrap();
let body_str = String::from_utf8(body.to_vec()).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ pub enum Error {
#[error("Error in reading a certificate file: {0}")]
ReadError(#[from] io::Error),

/// Error while creating the HTTP request
#[error("Failed to construct HTTP request: {0}")]
BuildRequestError(#[source] http::Error),

/// Unexpected private key (only EC keys are supported).
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
#[error("Unexpected private key: {0}")]
Expand Down

0 comments on commit f53b6c3

Please sign in to comment.