diff --git a/benches/jwt.rs b/benches/jwt.rs index d2fee79e..8d6d9e9c 100644 --- a/benches/jwt.rs +++ b/benches/jwt.rs @@ -1,5 +1,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; +use jsonwebtoken::{ + decode, decode_header, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] @@ -21,7 +23,10 @@ fn bench_decode(c: &mut Criterion) { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"; let key = DecodingKey::from_secret("secret".as_ref()); - c.bench_function("bench_decode", |b| { + let mut group = c.benchmark_group("decode"); + group.throughput(criterion::Throughput::Bytes(token.len() as u64)); + + group.bench_function("str", |b| { b.iter(|| { decode::( black_box(token), @@ -30,6 +35,19 @@ fn bench_decode(c: &mut Criterion) { ) }) }); + + drop(group); + let mut group = c.benchmark_group("header"); + group.throughput(criterion::Throughput::Bytes(token.len() as u64)); + + group.bench_function("str", |b| { + b.iter(|| { + decode_header( + // Simulate the cost of validating &str before decoding + black_box(token), + ) + }) + }); } criterion_group!(benches, bench_encode, bench_decode); diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index c2957dc8..9c89170a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -46,7 +46,7 @@ pub fn sign(message: &[u8], key: &EncodingKey, algorithm: Algorithm) -> Result, message: &[u8], key: &[u8], ) -> Result { @@ -66,16 +66,17 @@ fn verify_ring( /// /// `message` is base64(header) + "." + base64(claims) pub fn verify( - signature: &str, + signature: impl AsRef<[u8]>, message: &[u8], key: &DecodingKey, algorithm: Algorithm, ) -> Result { + let signature = signature.as_ref(); match algorithm { Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => { // we just re-sign the message with the key and compare if they are equal let signed = sign(message, &EncodingKey::from_secret(key.as_bytes()), algorithm)?; - Ok(verify_slices_are_equal(signature.as_ref(), signed.as_ref()).is_ok()) + Ok(verify_slices_are_equal(signature, signed.as_ref()).is_ok()) } Algorithm::ES256 | Algorithm::ES384 => verify_ring( ecdsa::alg_to_ec_verification(algorithm), diff --git a/src/crypto/rsa.rs b/src/crypto/rsa.rs index 4c97db3c..e5af3b29 100644 --- a/src/crypto/rsa.rs +++ b/src/crypto/rsa.rs @@ -51,7 +51,7 @@ pub(crate) fn sign( /// Checks that a signature is valid based on the (n, e) RSA pubkey components pub(crate) fn verify_from_components( alg: &'static signature::RsaParameters, - signature: &str, + signature: impl AsRef<[u8]>, message: &[u8], components: (&[u8], &[u8]), ) -> Result { diff --git a/src/decoding.rs b/src/decoding.rs index 8d87f03d..bcc78c63 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -204,11 +204,11 @@ impl DecodingKey { /// Verify signature of a JWT, and return header object and raw payload /// /// If the token or its signature is invalid, it will return an error. -fn verify_signature<'a>( - token: &'a str, +fn verify_signature_bytes<'a>( + token: &'a [u8], key: &DecodingKey, validation: &Validation, -) -> Result<(Header, &'a str)> { +) -> Result<(Header, &'a [u8])> { if validation.validate_signature && validation.algorithms.is_empty() { return Err(new_error(ErrorKind::MissingAlgorithm)); } @@ -221,15 +221,15 @@ fn verify_signature<'a>( } } - let (signature, message) = expect_two!(token.rsplitn(2, '.')); - let (payload, header) = expect_two!(message.rsplitn(2, '.')); + let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); + let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.')); let header = Header::from_encoded(header)?; if validation.validate_signature && !validation.algorithms.contains(&header.alg) { return Err(new_error(ErrorKind::InvalidAlgorithm)); } - if validation.validate_signature && !verify(signature, message.as_bytes(), key, header.alg)? { + if validation.validate_signature && !verify(signature, message, key, header.alg)? { return Err(new_error(ErrorKind::InvalidSignature)); } @@ -250,16 +250,17 @@ fn verify_signature<'a>( /// company: String /// } /// -/// let token = "a.jwt.token".to_string(); +/// let token = "a.jwt.token"; /// // Claims is a struct that implements Deserialize -/// let token_message = decode::(&token, &DecodingKey::from_secret("secret".as_ref()), &Validation::new(Algorithm::HS256)); +/// let token_message = decode::(token, &DecodingKey::from_secret("secret".as_ref()), &Validation::new(Algorithm::HS256)); /// ``` pub fn decode( - token: &str, + token: impl AsRef<[u8]>, key: &DecodingKey, validation: &Validation, ) -> Result> { - match verify_signature(token, key, validation) { + let token = token.as_ref(); + match verify_signature_bytes(token, key, validation) { Err(e) => Err(e), Ok((header, claims)) => { let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?; @@ -278,11 +279,12 @@ pub fn decode( /// ```rust /// use jsonwebtoken::decode_header; /// -/// let token = "a.jwt.token".to_string(); -/// let header = decode_header(&token); +/// let token = "a.jwt.token"; +/// let header = decode_header(token); /// ``` -pub fn decode_header(token: &str) -> Result
{ - let (_, message) = expect_two!(token.rsplitn(2, '.')); - let (_, header) = expect_two!(message.rsplitn(2, '.')); +pub fn decode_header(token: impl AsRef<[u8]>) -> Result
{ + let token = token.as_ref(); + let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); + let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.')); Header::from_encoded(header) } diff --git a/tests/ecdsa/mod.rs b/tests/ecdsa/mod.rs index 8c06910f..73934764 100644 --- a/tests/ecdsa/mod.rs +++ b/tests/ecdsa/mod.rs @@ -26,7 +26,7 @@ fn round_trip_sign_verification_pk8() { let encrypted = sign(b"hello world", &EncodingKey::from_ec_der(privkey), Algorithm::ES256).unwrap(); let is_valid = - verify(&encrypted, b"hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256) + verify(encrypted, b"hello world", &DecodingKey::from_ec_der(pubkey), Algorithm::ES256) .unwrap(); assert!(is_valid); } @@ -41,7 +41,7 @@ fn round_trip_sign_verification_pem() { sign(b"hello world", &EncodingKey::from_ec_pem(privkey_pem).unwrap(), Algorithm::ES256) .unwrap(); let is_valid = verify( - &encrypted, + encrypted, b"hello world", &DecodingKey::from_ec_pem(pubkey_pem).unwrap(), Algorithm::ES256, diff --git a/tests/eddsa/mod.rs b/tests/eddsa/mod.rs index 85dd0245..d7d89d23 100644 --- a/tests/eddsa/mod.rs +++ b/tests/eddsa/mod.rs @@ -26,7 +26,7 @@ fn round_trip_sign_verification_pk8() { let encrypted = sign(b"hello world", &EncodingKey::from_ed_der(privkey), Algorithm::EdDSA).unwrap(); let is_valid = - verify(&encrypted, b"hello world", &DecodingKey::from_ed_der(pubkey), Algorithm::EdDSA) + verify(encrypted, b"hello world", &DecodingKey::from_ed_der(pubkey), Algorithm::EdDSA) .unwrap(); assert!(is_valid); } @@ -41,7 +41,7 @@ fn round_trip_sign_verification_pem() { sign(b"hello world", &EncodingKey::from_ed_pem(privkey_pem).unwrap(), Algorithm::EdDSA) .unwrap(); let is_valid = verify( - &encrypted, + encrypted, b"hello world", &DecodingKey::from_ed_pem(pubkey_pem).unwrap(), Algorithm::EdDSA,