Skip to content

Commit

Permalink
KNOX-2575 - Add kid and jku claims to JWT tokens issues by Knox (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
moresandeep authored Apr 20, 2021
1 parent 4a44fb7 commit c1e8a3c
Show file tree
Hide file tree
Showing 15 changed files with 261 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.apache.knox.gateway.GatewayResources;
import org.apache.knox.gateway.config.GatewayConfig;
import org.apache.knox.gateway.i18n.messages.MessagesFactory;
import org.apache.knox.gateway.i18n.resources.ResourcesFactory;
import org.apache.knox.gateway.services.Service;
import org.apache.knox.gateway.services.ServiceLifecycleException;
Expand Down Expand Up @@ -73,6 +75,7 @@

public class DefaultTokenAuthorityService implements JWTokenAuthority, Service {
private static final GatewayResources RESOURCES = ResourcesFactory.get(GatewayResources.class);
private static final TokenAuthorityServiceMessages LOG = MessagesFactory.get(TokenAuthorityServiceMessages.class);

// Only standard RSA and HMAC signature algorithms are accepted
// https://tools.ietf.org/html/rfc7518
Expand All @@ -86,6 +89,8 @@ public class DefaultTokenAuthorityService implements JWTokenAuthority, Service {
private byte[] cachedSigningHmacSecret;
private RSAPrivateKey signingKey;

private Optional<String> cachedSigningKeyID = Optional.empty();

public void setKeystoreService(KeystoreService ks) {
this.keystoreService = ks;
}
Expand All @@ -96,7 +101,7 @@ public void setAliasService(AliasService as) {

@Override
public JWT issueToken(JWTokenAttributes jwtAttributes) throws TokenServiceException {
String[] claimArray = new String[4];
String[] claimArray = new String[6];
claimArray[0] = "KNOXSSO";
claimArray[1] = jwtAttributes.getPrincipal().getName();
claimArray[2] = null;
Expand All @@ -106,8 +111,14 @@ public JWT issueToken(JWTokenAttributes jwtAttributes) throws TokenServiceExcept
else {
claimArray[3] = String.valueOf(jwtAttributes.getExpires());
}

final String algorithm = jwtAttributes.getAlgorithm();
if(SUPPORTED_HMAC_SIG_ALGS.contains(algorithm)) {
claimArray[4] = null;
claimArray[5] = null;
} else {
claimArray[4] = cachedSigningKeyID.isPresent() ? cachedSigningKeyID.get() : null;
claimArray[5] = jwtAttributes.getJku();
}
final JWT token = SUPPORTED_PKI_SIG_ALGS.contains(algorithm) || SUPPORTED_HMAC_SIG_ALGS.contains(algorithm) ? new JWTToken(algorithm, claimArray, jwtAttributes.getAudiences(), jwtAttributes.isManaged()) : null;
if (token != null) {
if (SUPPORTED_HMAC_SIG_ALGS.contains(algorithm)) {
Expand Down Expand Up @@ -289,8 +300,13 @@ public void start() throws ServiceLifecycleException {
else if (! (publicKey instanceof RSAPublicKey)) {
throw new ServiceLifecycleException(RESOURCES.publicSigningKeyWrongType(signingKeyAlias));
}
cachedSigningKeyID = Optional.of(TokenUtils.getThumbprint((RSAPublicKey) publicKey, "SHA-256"));
} catch (KeyStoreException e) {
throw new ServiceLifecycleException(RESOURCES.publicSigningKeyNotFound(signingKeyAlias), e);
} catch (final JOSEException e) {
/* in case there is an error getting KID log and move one */
LOG.errorGettingKid(e.toString());
cachedSigningKeyID = Optional.empty();
}

// Ensure that the private signing keys is available
Expand All @@ -311,4 +327,8 @@ else if (! (key instanceof RSAPrivateKey)) {
@Override
public void stop() throws ServiceLifecycleException {
}

protected Optional<String> getCachedSigningKeyID() {
return cachedSigningKeyID;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.knox.gateway.services.token.impl;

import org.apache.knox.gateway.i18n.messages.Message;
import org.apache.knox.gateway.i18n.messages.MessageLevel;
import org.apache.knox.gateway.i18n.messages.Messages;

@Messages(logger = "org.apache.knox.gateway.services.token.state")
public interface TokenAuthorityServiceMessages {
@Message(level = MessageLevel.ERROR, text = "There was an error getting kid, cause: {0}")
void errorGettingKid(String message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.HashMap;
import java.util.Optional;

import org.apache.knox.gateway.config.GatewayConfig;
import org.apache.knox.gateway.services.ServiceLifecycleException;
Expand Down Expand Up @@ -337,6 +338,7 @@ public void testTokenCreationCustomSigningKey() throws Exception {
ta.setAliasService(as);
ta.setKeystoreService(ks);
ta.init(config, new HashMap<>());
ta.start();

final JWTokenAttributes jwtAttributes = new JWTokenAttributesBuilder().setPrincipal(principal).setAudiences(Collections.emptyList()).setAlgorithm("RS256").setExpires(-1)
.setSigningKeystoreName(customSigningKeyName).setSigningKeystoreAlias(customSigningKeyAlias).setSigningKeystorePassphrase(customSigningKeyPassphrase.toCharArray()).build();
Expand Down Expand Up @@ -584,4 +586,61 @@ public void testServiceInvalidKeyPassword() throws Exception {

EasyMock.verify(config, ms, as);
}

/**
* Test getSigningCertKid() function
* @throws Exception
*/
@Test
public void testGetSigningCertKid() throws Exception {
Principal principal = EasyMock.createNiceMock(Principal.class);
EasyMock.expect(principal.getName()).andReturn("[email protected]");

GatewayConfig config = EasyMock.createNiceMock(GatewayConfig.class);
String basedir = System.getProperty("basedir");
if (basedir == null) {
basedir = new File(".").getCanonicalPath();
}

EasyMock.expect(config.getGatewaySecurityDir()).andReturn(basedir + "/target/test-classes").anyTimes();
EasyMock.expect(config.getGatewayKeystoreDir()).andReturn(basedir + "/target/test-classes/keystores").anyTimes();
EasyMock.expect(config.getSigningKeystoreName()).andReturn("server-keystore.jks").anyTimes();
EasyMock.expect(config.getSigningKeystorePath()).andReturn(basedir + "/target/test-classes/keystores/server-keystore.jks").anyTimes();
EasyMock.expect(config.getSigningKeystorePasswordAlias()).andReturn(GatewayConfig.DEFAULT_SIGNING_KEYSTORE_PASSWORD_ALIAS).anyTimes();
EasyMock.expect(config.getSigningKeyPassphraseAlias()).andReturn(GatewayConfig.DEFAULT_SIGNING_KEY_PASSPHRASE_ALIAS).anyTimes();
EasyMock.expect(config.getSigningKeystoreType()).andReturn("jks").anyTimes();
EasyMock.expect(config.getSigningKeyAlias()).andReturn("server").anyTimes();
EasyMock.expect(config.getCredentialStoreType()).andReturn(GatewayConfig.DEFAULT_CREDENTIAL_STORE_TYPE).anyTimes();
EasyMock.expect(config.getCredentialStoreAlgorithm()).andReturn(GatewayConfig.DEFAULT_CREDENTIAL_STORE_ALG).anyTimes();

MasterService ms = EasyMock.createNiceMock(MasterService.class);
EasyMock.expect(ms.getMasterSecret()).andReturn("horton".toCharArray());

AliasService as = EasyMock.createNiceMock(AliasService.class);
EasyMock.expect(as.getSigningKeyPassphrase()).andReturn("horton".toCharArray()).anyTimes();

EasyMock.replay(principal, config, ms, as);

DefaultKeystoreService ks = new DefaultKeystoreService();
ks.setMasterService(ms);

ks.init(config, new HashMap<>());

DefaultTokenAuthorityService ta = new DefaultTokenAuthorityService();

/* negative test */
/* expectation that that the exception is eaten up in case where there was an exception getting kid */
Optional<String> opt = ta.getCachedSigningKeyID();
assertFalse(opt.isPresent());

/* now test for cases where we expect to get kid */
ta.setAliasService(as);
ta.setKeystoreService(ks);

ta.init(config, new HashMap<>());
ta.start();

opt = ta.getCachedSigningKeyID();
assertTrue("Missing expected KID value", opt.isPresent());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,15 @@ protected TokenStateService createTokenStateService() throws Exception {

/* create a test JWT token */
protected JWT getJWTToken(final long expiry) {
String[] claims = new String[4];
String[] claims = new String[6];
claims[0] = "KNOXSSO";
claims[1] = "[email protected]";
claims[2] = "https://login.example.com";
if(expiry > 0) {
claims[3] = Long.toString(expiry);
}

claims[4] = "E0LDZulQ0XE_otJ5aoQtQu-RnXv8hU-M9U4dD7vDioA";
claims[5] = null;
JWT token = new JWTToken("RS256", claims);
// Sign the token
JWSSigner signer = new RSASSASigner(privateKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ public boolean verifyToken(JWT token) throws TokenServiceException {
@Override
public JWT issueToken(JWTokenAttributes jwtAttributes)
throws TokenServiceException {
String[] claimArray = new String[4];
String[] claimArray = new String[6];
claimArray[0] = "KNOXSSO";
claimArray[1] = jwtAttributes.getPrincipal().getName();
claimArray[2] = null;
Expand All @@ -777,6 +777,8 @@ public JWT issueToken(JWTokenAttributes jwtAttributes)
} else {
claimArray[3] = String.valueOf(jwtAttributes.getExpires());
}
claimArray[4] = "E0LDZulQ0XE_otJ5aoQtQu-RnXv8hU-M9U4dD7vDioA";
claimArray[5] = null;

JWT token = new JWTToken(jwtAttributes.getAlgorithm(), claimArray, jwtAttributes.getAudiences());
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.knox.gateway.services.ServiceType;
import org.apache.knox.gateway.services.security.KeystoreService;
import org.apache.knox.gateway.services.security.KeystoreServiceException;
import org.apache.knox.gateway.services.security.token.TokenUtils;

import javax.annotation.PostConstruct;
import javax.inject.Singleton;
Expand All @@ -47,9 +48,9 @@
@Singleton
@Path(JWKSResource.RESOURCE_PATH)
public class JWKSResource {

public static final String JWKS_PATH = "/jwks.json";
static final String RESOURCE_PATH = "knoxtoken/api/v1";
static final String JWKS_PATH = "/jwks.json";

@Context
HttpServletRequest request;
@Context
Expand Down Expand Up @@ -80,10 +81,11 @@ private Response getJwks(final String keystore) {
.entity(new JWKSet().toJSONObject().toString()).build();
}

final String kid = TokenUtils.getThumbprint(rsa, "SHA-256");
final RSAKey.Builder builder = new RSAKey.Builder(rsa)
.keyUse(KeyUse.SIGNATURE)
.algorithm(new JWSAlgorithm(rsa.getAlgorithm()))
.keyIDFromThumbprint();
.keyID(kid);

jwks = new JWKSet(builder.build());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ public class TokenResource {
private static final String TOKEN_EXP_RENEWAL_MAX_LIFETIME = "knox.token.exp.max-lifetime";
private static final String TOKEN_RENEWER_WHITELIST = "knox.token.renewer.whitelist";
private static final long TOKEN_TTL_DEFAULT = 30000L;
static final String RESOURCE_PATH = "knoxtoken/api/v1/token";
static final String TOKEN_API_PATH = "knoxtoken/api/v1";
static final String RESOURCE_PATH = TOKEN_API_PATH + "/token";
static final String RENEW_PATH = "/renew";
static final String REVOKE_PATH = "/revoke";
private static final String TARGET_ENDPOINT_PULIC_CERT_PEM = "knox.token.target.endpoint.cert.pem";
Expand Down Expand Up @@ -393,7 +394,6 @@ private Response getAuthenticationToken() {
try {
Certificate cert = ks.getCertificateForGateway();
byte[] bytes = cert.getEncoded();
//Base64 encoder = new Base64(76, "\n".getBytes("ASCII"));
endpointPublicCert = Base64.encodeBase64String(bytes);
} catch (KeyStoreException | KeystoreServiceException | CertificateEncodingException e) {
// assuming that certs will be properly provisioned across all clients
Expand All @@ -402,19 +402,31 @@ private Response getAuthenticationToken() {
}
}

String jku = null;
/* remove .../token and replace it with ..../jwks.json */
final int idx = request.getRequestURL().lastIndexOf("/");
if(idx > 1) {
jku = request.getRequestURL().substring(0, idx) + JWKSResource.JWKS_PATH;
}

try {
final boolean managedToken = tokenStateService != null;
JWT token;
JWTokenAttributes jwtAttributes;
if (targetAudiences.isEmpty()) {
jwtAttributes = new JWTokenAttributesBuilder().setPrincipal(p).setAlgorithm(signatureAlgorithm).setExpires(expires).setManaged(managedToken).build();
token = ts.issueToken(jwtAttributes);
} else {
jwtAttributes = new JWTokenAttributesBuilder().setPrincipal(p).setAudiences(targetAudiences).setAlgorithm(signatureAlgorithm).setExpires(expires)
.setManaged(managedToken).build();
token = ts.issueToken(jwtAttributes);
final JWTokenAttributesBuilder jwtAttributesBuilder = new JWTokenAttributesBuilder();
jwtAttributesBuilder
.setPrincipal(p)
.setAlgorithm(signatureAlgorithm)
.setExpires(expires)
.setManaged(managedToken)
.setJku(jku);
if (!targetAudiences.isEmpty()) {
jwtAttributesBuilder.setAudiences(targetAudiences);
}

jwtAttributes = jwtAttributesBuilder.build();
token = ts.issueToken(jwtAttributes);

if (token != null) {
String accessToken = token.toString();
String tokenId = TokenUtils.getTokenId(token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ public void testE2E()
}

private JWT getTestToken(final String algorithm) {
String[] claimArray = new String[4];
String[] claimArray = new String[6];
claimArray[0] = "KNOXSSO";
claimArray[1] = "[email protected]";
claimArray[2] = null;
claimArray[3] = null;
claimArray[4] = "E0LDZulQ0XE_otJ5aoQtQu-RnXv8hU-M9U4dD7vDioA";
claimArray[5] = null;

final JWT token = new JWTToken(algorithm, claimArray,
Collections.singletonList("aud"), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ public class TokenServiceResourceTest {
private static RSAPublicKey publicKey;
private static RSAPrivateKey privateKey;

private static String TOKEN_API_PATH = "https://gateway-host:8443/gateway/sandbox/knoxtoken/api/v1";
private static String TOKEN_PATH = "/token";
private static String JKWS_PATH = "/jwks.json";

private ServletContext context;
private HttpServletRequest request;
private JWTokenAuthority authority;
Expand Down Expand Up @@ -124,6 +128,7 @@ private void configureCommonExpectations(Map<String, String> contextExpectations
Principal principal = EasyMock.createNiceMock(Principal.class);
EasyMock.expect(principal.getName()).andReturn("alice").anyTimes();
EasyMock.expect(request.getUserPrincipal()).andReturn(principal).anyTimes();
EasyMock.expect(request.getRequestURL()).andReturn(new StringBuffer(TOKEN_API_PATH+TOKEN_PATH)).anyTimes();

GatewayServices services = EasyMock.createNiceMock(GatewayServices.class);
EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE)).andReturn(services).anyTimes();
Expand Down Expand Up @@ -759,6 +764,35 @@ public void testTokenRevocation_Enabled_WithRenewersWithValidSubject() throws Ex
validateSuccessfulRevocationResponse(renewalResponse);
}

@Test
public void testKidJkuClaims() throws Exception {
final Map<String, String> contextExpectations = new HashMap<>();
contextExpectations.put("knox.token.ttl", "60000");
configureCommonExpectations(contextExpectations);

TokenResource tr = new TokenResource();
tr.request = request;
tr.context = context;
tr.init();

// Issue a token
Response retResponse = tr.doGet();

assertEquals(200, retResponse.getStatus());

// Parse the response
final String retString = retResponse.getEntity().toString();
final String accessToken = getTagValue(retString, "access_token");
assertNotNull(accessToken);

// Verify the token
final JWT parsedToken = new JWTToken(accessToken);
assertEquals("alice", parsedToken.getSubject());
assertTrue(authority.verifyToken(parsedToken));

assertNotNull(parsedToken.getClaim("kid"));
assertEquals(TOKEN_API_PATH+JKWS_PATH, parsedToken.getClaim("jku"));
}

/**
*
Expand Down Expand Up @@ -1170,7 +1204,7 @@ public boolean verifyToken(JWT token) {

@Override
public JWT issueToken(JWTokenAttributes jwtAttributes) {
String[] claimArray = new String[4];
String[] claimArray = new String[6];
claimArray[0] = "KNOXSSO";
claimArray[1] = jwtAttributes.getPrincipal().getName();
claimArray[2] = null;
Expand All @@ -1179,6 +1213,8 @@ public JWT issueToken(JWTokenAttributes jwtAttributes) {
} else {
claimArray[3] = String.valueOf(jwtAttributes.getExpires());
}
claimArray[4] = "E0LDZulQ0XE_otJ5aoQtQu-RnXv8hU-M9U4dD7vDioA";
claimArray[5] = jwtAttributes.getJku();

JWT token = new JWTToken(jwtAttributes.getAlgorithm(), claimArray, jwtAttributes.getAudiences());
JWSSigner signer = new RSASSASigner(privateKey);
Expand Down
Loading

0 comments on commit c1e8a3c

Please sign in to comment.