Skip to content

Commit

Permalink
add Amazon AWS Cognito JWK link support for token validation and veri…
Browse files Browse the repository at this point in the history
…fication through jwt.LoadAWSCognitoKeys package-level function
  • Loading branch information
kataras committed Feb 3, 2024
1 parent 0fc2a08 commit 25cef28
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 4 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
The MIT License (MIT)

Copyright (c) 2020-2023 Gerasimos Maropoulos <[email protected]>
Copyright (c) 2020-2024 Gerasimos Maropoulos <[email protected]>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Import as `import "github.com/kataras/jwt"` and use it as `jwt.XXX`.
* [Encryption](#encryption)
* [Benchmarks](_benchmarks)
* [Examples](_examples)
* [Amazon AWS Cognito Verification](_examples/aws-cognito-verify/main.go) **NEW**
* [Basic](_examples/basic/main.go)
* [Custom Header](_examples/custom-header/main.go)
* [Multiple Key IDs](_examples/multiple-kids/main.go)
Expand Down
41 changes: 41 additions & 0 deletions _examples/aws-cognito-verify/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package main

import (
"fmt"

"github.com/kataras/jwt"
)

// |=================================================================================|
// | Amazon's AWS Cognito integration example for token validation and verification. |
// |=================================================================================|

func main() {
/*
cognitoConfig := jwt.AWSKeysConfiguration{
Region: "us-west-2",
UserPoolID: "us-west-2_xxx",
}
keys, err := cognitoConfig.Load()
if err != nil {
panic(err)
}
OR:
*/
keys, err := jwt.LoadAWSCognitoKeys("us-west-2" /* region */, "us-west-2_xxx" /* user pool id */)
if err != nil {
panic(err) // handle error, e.g. pool does not exist in the region.
}

var tokenToValidate = `xxx.xxx.xxx` // put a token here issued by your own aws cognito user pool to test it.

var claims jwt.Map // Your own custom claims here.
if err := keys.VerifyToken([]byte(tokenToValidate), &claims); err != nil {
panic(err) // handle error, e.g. token expired, or kid is empty.
}

for k, v := range claims {
fmt.Printf("%s: %v\n", k, v)
}
}
11 changes: 11 additions & 0 deletions alg.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,14 @@ var (
EdDSA,
}
)

// parseAlg returns the algorithm by its name or nil.
func parseAlg(name string) Alg {
for _, alg := range allAlgs {
if alg.Name() == name {
return alg
}
}

return nil
}
6 changes: 3 additions & 3 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package jwt
import (
"bytes"
"encoding/json"
"io/ioutil"
"os"
"reflect"
"time"
)
Expand All @@ -30,8 +30,8 @@ var CompareHeader HeaderValidator = compareHeader
// ReadFile can be used to customize the way the
// Must/Load Key function helpers are loading the filenames from.
// Example of usage: embedded key pairs.
// Defaults to the `ioutil.ReadFile` which reads the file from the physical disk.
var ReadFile = ioutil.ReadFile
// Defaults to the `os.ReadFile` which reads the file from the physical disk.
var ReadFile = os.ReadFile

// Marshal same as json.Marshal.
// This variable can be modified to enable custom encoder behavior
Expand Down
198 changes: 198 additions & 0 deletions kid_keys_aws_cognito.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package jwt

import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
)

// |=========================================================================|
// | Amazon's AWS Cognito integration for token validation and verification. |
// |=========================================================================|

// AWSCognitoKeysConfiguration is a configuration for fetching the JSON Web Key Set from AWS Cognito.
// See `LoadAWSCognitoKeys` and its `Load` and `WithClient` methods.
type AWSCognitoKeysConfiguration struct {
Region string `json:"region" yaml:"Region" toml:"Region" env:"AWS_COGNITO_REGION"` // e.g. "us-west-2"
UserPoolID string `json:"user_pool_id" yaml:"UserPoolID" toml:"Region" env:"AWS_COGNITO_USER_POOL_ID"` // e.g. "us-west-2_XXX"

httpClient HTTPClient
}

// LoadAWSCognitoKeys loads the AWS Cognito JSON Web Key Set from the given region and user pool ID.
// It returns the Keys object or an error if the request fails.
// It uses the default http.Client to fetch the JSON Web Key Set.
// It is a shortcut for the following:
//
// config := jwt.AWSKeysConfiguration{
// Region: region,
// UserPoolID: userPoolID,
// }
// return config.Load()
func LoadAWSCognitoKeys(region, userPoolID string) (Keys, error) {
config := AWSCognitoKeysConfiguration{
Region: region,
UserPoolID: userPoolID,
}
return config.Load()
}

// WithClient sets the HTTP client to be used for fetching the JSON Web Key Set from AWS Cognito.
// If not set, the default http.Client is used.
func (c *AWSCognitoKeysConfiguration) WithClient(httpClient HTTPClient) *AWSCognitoKeysConfiguration {
c.httpClient = httpClient
return c
}

// Load fetches the JSON Web Key Set from AWS Cognito and parses it into a jwt.Keys object.
// It returns the Keys object or an error if the request fails.
// If the HTTP client is not set, the default http.Client is used.
//
// Calls the `ParseAWSCognitoKeys` function with the given configuration.
func (c *AWSCognitoKeysConfiguration) Load() (Keys, error) {
httpClient := c.httpClient
if httpClient == nil {
httpClient = http.DefaultClient
}

return ParseAWSCognitoKeys(httpClient, c.Region, c.UserPoolID)
}

// JWKSet represents a JSON Web Key Set.
type JWKSet struct {
Keys []*JWK `json:"keys"`
}

// JWK represents a JSON Web Key.
type JWK struct {
Kty string `json:"kty"`
N string `json:"n"`
E string `json:"e"`
Kid string `json:"kid"`
Alg string `json:"alg"`
Use string `json:"use"`
}

// HTTPClient is an interface that can be used to mock the http.Client.
// It is used to fetch the JSON Web Key Set from AWS Cognito.
type HTTPClient interface {
Get(string) (*http.Response, error)
}

// ParseAWSCognitoKeys fetches the JSON Web Key Set from AWS Cognito and parses it into a jwt.Keys object.
func ParseAWSCognitoKeys(client HTTPClient, region, userPoolID string) (Keys, error) {
set, err := fetchAWSCognitoJWKSet(client, region, userPoolID)
if err != nil {
return nil, err
}

return parseAWSCognitoJWKSet(set)
}

// AWSCognitoError represents an error response from AWS Cognito.
// It implements the error interface.
type AWSCognitoError struct {
StatusCode int
Message string `json:"message"`
}

// Error returns the error message.
func (e AWSCognitoError) Error() string {
return e.Message
}

// fetchAWSCognitoJWKSet fetches the JSON Web Key Set from AWS Cognito.
// It returns the JWKSet object or an error if the request fails.
func fetchAWSCognitoJWKSet(
client HTTPClient,
region string,
userPoolID string,
) (*JWKSet, error) {
url := fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, userPoolID)

resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode >= 400 {
fetchErr := AWSCognitoError{
StatusCode: resp.StatusCode,
}

err = json.NewDecoder(resp.Body).Decode(&fetchErr)
if err != nil {
return nil, fmt.Errorf("jwt: cannot decode error message: %w", err)
}

return nil, fetchErr
}

var jwkSet JWKSet
err = json.NewDecoder(resp.Body).Decode(&jwkSet)
if err != nil {
return nil, err
}

return &jwkSet, nil
}

// parseAWSCognitoJWKSet parses the JWKSet object into a jwt.Keys object.
// It returns the Keys object or an error if the parsing fails.
// It filters out unsupported algorithms.
func parseAWSCognitoJWKSet(set *JWKSet) (Keys, error) {
keys := make(Keys, len(set.Keys))
for _, key := range set.Keys {
alg := parseAlg(key.Alg)
if alg == nil {
continue
}

publicKey, err := convertJWKToPublicKey(key)
if err != nil {
return nil, err
}

keys[key.Kid] = &Key{
ID: key.Kid,
Alg: alg,
Public: publicKey,
}
}

return keys, nil
}

// convertJWKToPublicKey converts a JWK object to a *rsa.PublicKey object.
func convertJWKToPublicKey(jwk *JWK) (*rsa.PublicKey, error) {
// decode the n and e values from base64.
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, err
}
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, err
}

// construct a big.Int from the n bytes.
n := new(big.Int).SetBytes(nBytes)

// construct an int from the e bytes.
var e int
for _, b := range eBytes {
e = e<<8 + int(b)
}

// construct a *rsa.PublicKey from the n and e values.
pubKey := &rsa.PublicKey{
N: n,
E: e,
}

return pubKey, nil
}

0 comments on commit 25cef28

Please sign in to comment.