Skip to content

Commit

Permalink
KNOX-2603 - Caching passcode token verifications (apache#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
smolnar82 authored May 17, 2021
1 parent b55e9ce commit 2bd305d
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ private HadoopAuthFilter testIfJwtSupported(String supportJwt) throws Exception
expect(filterConfig.getInitParameter(JWTFederationFilter.JWT_UNAUTHENTICATED_PATHS_PARAM)).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter(AbstractJWTFilter.JWT_EXPECTED_ISSUER)).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter(AbstractJWTFilter.JWT_EXPECTED_SIGALG)).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter(SignatureVerificationCache.JWT_VERIFIED_CACHE_MAX)).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter(SignatureVerificationCache.TOKENS_VERIFIED_CACHE_MAX)).andReturn(null).anyTimes();
}

final ServletContext servletContext = createMock(ServletContext.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,18 @@ protected boolean validateToken(final HttpServletRequest request,
try {
if (tokenId != null) {
if (tokenIsStillValid(tokenId)) {
if (validatePasscode(tokenId, passcode)) {
if (hasSignatureBeenVerified(passcode) || validatePasscode(tokenId, passcode)) {
return true;
} else {
log.wrongPasscodeToken(tokenId);
handleValidationError(request, response, HttpServletResponse.SC_BAD_REQUEST, "Bad request: wrong passcode");
}
} else {
log.tokenHasExpired(Tokens.getTokenIDDisplayText(tokenId));
// Explicitly evict the record of this token's signature verification (if present).
// There is no value in keeping this record for expired tokens, and explicitly removing them may prevent
// records for other valid tokens from being prematurely evicted from the cache.
removeSignatureVerificationRecord(passcode);
handleValidationError(request, response, HttpServletResponse.SC_BAD_REQUEST,
"Bad request: token has expired");
}
Expand All @@ -412,7 +416,11 @@ private boolean validatePasscode(String tokenId, String passcode) throws Unknown
final TokenMetadata tokenMetadata = tokenStateService.getTokenMetadata(tokenId);
final String userName = tokenMetadata == null ? "" : tokenMetadata.getUserName();
final byte[] storedPasscode = tokenMetadata == null ? null : tokenMetadata.getPasscode().getBytes(UTF_8);
return Arrays.equals(tokenMAC.hash(tokenId, issueTime, userName, passcode).getBytes(UTF_8), storedPasscode);
final boolean validPasscode = Arrays.equals(tokenMAC.hash(tokenId, issueTime, userName, passcode).getBytes(UTF_8), storedPasscode);
if (validPasscode) {
recordSignatureVerification(passcode);
}
return validPasscode;
}

protected boolean verifyTokenSignature(final JWT token) {
Expand Down Expand Up @@ -459,32 +467,32 @@ protected boolean verifyTokenSignature(final JWT token) {
}

/**
* Determine if the specified JWT signature has previously been successfully verified.
* Determine if the specified JWT or Passcode token signature has previously been successfully verified.
*
* @param jwt A serialized JWT String.
* @param token A serialized JWT String or Passcode token.
*
* @return true, if the specified token has been previously verified; Otherwise, false.
*/
protected boolean hasSignatureBeenVerified(final String jwt) {
return signatureVerificationCache.hasSignatureBeenVerified(jwt);
protected boolean hasSignatureBeenVerified(final String token) {
return signatureVerificationCache.hasSignatureBeenVerified(token);
}

/**
* Record a successful JWT signature verification.
* Record a successful JWT or Passcode token signature verification.
*
* @param jwt The serialized String for a JWT which has been successfully verified.
* @param token The serialized String for a JWT or Passcode token which has been successfully verified.
*/
protected void recordSignatureVerification(final String jwt) {
signatureVerificationCache.recordSignatureVerification(jwt);
protected void recordSignatureVerification(final String token) {
signatureVerificationCache.recordSignatureVerification(token);
}

/**
* Explicitly evict the signature verification record for the specified JWT from the cache if it exists.
*
* @param jwt The serialized String for a JWT whose signature verification record should be evicted.
* @param token The serialized String for a JWT or Passcode token whose signature verification record should be evicted.
*/
protected void removeSignatureVerificationRecord(final String jwt) {
signatureVerificationCache.removeSignatureVerificationRecord(jwt);
protected void removeSignatureVerificationRecord(final String token) {
signatureVerificationCache.removeSignatureVerificationRecord(token);
}

protected abstract void handleValidationError(HttpServletRequest request, HttpServletResponse response, int status,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
} else if (TokenType.Passcode.equals(tokenType)) {
// Validate the token based on the server-managed metadata
// The received token value must be a Base64 encoded value of Base64(tokenId)::Base64(rawPasscode)
String tokenId = null, passcode = null;
String tokenId = null;
String passcode = null;
try {
final String[] base64DecodedTokenIdAndPasscode = decodeBase64(tokenValue).split("::");
tokenId = decodeBase64(base64DecodedTokenIdAndPasscode[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
*/
public class SignatureVerificationCache {

public static final String JWT_VERIFIED_CACHE_MAX = "jwt.verified.cache.max";
public static final int JWT_VERIFIED_CACHE_MAX_DEFAULT = 250;
public static final String TOKENS_VERIFIED_CACHE_MAX = "tokens.verified.cache.max";
private static final int TOKENS_VERIFIED_CACHE_MAX_DEFAULT = 250;

static final String DEFAULT_CACHE_ID = "default-cache";

Expand Down Expand Up @@ -71,9 +71,9 @@ private SignatureVerificationCache(final FilterConfig config) {
* @param config The configuration of the provider employing this cache.
*/
private void initializeVerifiedTokensCache(final FilterConfig config) {
int maxCacheSize = JWT_VERIFIED_CACHE_MAX_DEFAULT;
int maxCacheSize = TOKENS_VERIFIED_CACHE_MAX_DEFAULT;

String configValue = config.getInitParameter(JWT_VERIFIED_CACHE_MAX);
String configValue = config.getInitParameter(TOKENS_VERIFIED_CACHE_MAX);
if (configValue != null && !configValue.isEmpty()) {
try {
maxCacheSize = Integer.parseInt(configValue);
Expand All @@ -86,32 +86,32 @@ private void initializeVerifiedTokensCache(final FilterConfig config) {
}

/**
* Determine if the specified JWT's signature has previously been successfully verified.
* Determine if the specified token's signature has previously been successfully verified.
*
* @param jwt A serialized JWT.
* @param token A serialized JWT or Passcode token.
*
* @return true, if the specified token has been previously verified; Otherwise, false.
*/
public boolean hasSignatureBeenVerified(final String jwt) {
return (verifiedTokens.getIfPresent(jwt) != null);
public boolean hasSignatureBeenVerified(final String token) {
return (verifiedTokens.getIfPresent(token) != null);
}

/**
* Record a successful token signature verification.
*
* @param jwt A serialized JWT for which the signature has been successfully verified.
* @param token A serialized JWT or Passcode token for which the signature has been successfully verified.
*/
public void recordSignatureVerification(final String jwt) {
verifiedTokens.put(jwt, true);
public void recordSignatureVerification(final String token) {
verifiedTokens.put(token, true);
}

/**
* Explicitly evict the signature verification record from the cache if it exists.
*
* @param jwt The serialized JWT for which the associated signature verification record should be evicted.
* @param token The serialized JWT or Passcode token for which the associated signature verification record should be evicted.
*/
public void removeSignatureVerificationRecord(final String jwt) {
verifiedTokens.asMap().remove(jwt);
public void removeSignatureVerificationRecord(final String token) {
verifiedTokens.asMap().remove(token);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
public abstract class AbstractJWTFilterTest {
protected static final String SERVICE_URL = "https://localhost:8888/resource";
private static final String dnTemplate = "CN={0},OU=Test,O=Hadoop,L=Test,ST=Test,C=US";
protected static final String PASSCODE_CLAIM = "passcode";

protected AbstractJWTFilter handler;
protected static RSAPublicKey publicKey;
Expand Down Expand Up @@ -802,7 +803,7 @@ public void doTestVerificationOptimization(boolean includeTokenId) throws Except

Properties props = getProperties();
props.put(AbstractJWTFilter.JWT_EXPECTED_SIGALG, "RS512");
props.put(SignatureVerificationCache.JWT_VERIFIED_CACHE_MAX, "1");
props.put(SignatureVerificationCache.TOKENS_VERIFIED_CACHE_MAX, "1");
props.put(TestFilterConfig.TOPOLOGY_NAME_PROP, "jwt-verification-optimization-test");
handler.init(new TestFilterConfig(props));
Assert.assertEquals("Expected no token verification calls yet.",
Expand Down Expand Up @@ -856,7 +857,7 @@ public void testExpiredTokensEvictedFromSignatureVerificationCache() throws Exce

Properties props = getProperties();
props.put(AbstractJWTFilter.JWT_EXPECTED_SIGALG, "RS512");
props.put(SignatureVerificationCache.JWT_VERIFIED_CACHE_MAX, "1");
props.put(SignatureVerificationCache.TOKENS_VERIFIED_CACHE_MAX, "1");
props.put(TestFilterConfig.TOPOLOGY_NAME_PROP, "jwt-eviction-test");
handler.init(new TestFilterConfig(props));
Assert.assertEquals("Expected no token verification calls yet.",
Expand Down Expand Up @@ -1009,7 +1010,8 @@ protected SignedJWT getJWT(final String issuer,
.audience(aud)
.expirationTime(expires)
.notBeforeTime(nbf)
.claim("scope", "openid");
.claim("scope", "openid")
.claim(PASSCODE_CLAIM, UUID.randomUUID().toString());
if (knoxId != null) {
builder.claim(JWTToken.KNOX_ID_CLAIM, knoxId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@
*/
package org.apache.knox.gateway.provider.federation;

import com.nimbusds.jwt.SignedJWT;
import static org.junit.Assert.fail;

import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.time.Instant;
import java.util.Date;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.HmacAlgorithms;
Expand All @@ -37,22 +51,7 @@
import org.junit.Assert;
import org.junit.Test;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.time.Instant;
import java.util.Date;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.fail;
import com.nimbusds.jwt.SignedJWT;

@SuppressWarnings({"PMD.JUnit4TestShouldUseBeforeAnnotation", "PMD.JUnit4TestShouldUseTestAnnotation"})
public class TokenIDAsHTTPBasicCredsFederationFilterTest extends JWTAsHTTPBasicCredsFederationFilterTest {
Expand All @@ -74,7 +73,7 @@ protected void setTokenOnRequest(final HttpServletRequest request, final SignedJ
try {
final long issueTime = System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(5);
final String subject = (String) jwt.getJWTClaimsSet().getClaim(JWTToken.SUBJECT);
final String passcode = UUID.randomUUID().toString();
final String passcode = (String) jwt.getJWTClaimsSet().getClaims().get(PASSCODE_CLAIM);
addTokenState(jwt, issueTime, subject, passcode);
setTokenOnRequest(request, TestJWTFederationFilter.PASSCODE, generatePasscodeField(getTokenId(jwt), passcode));
} catch(ParseException e) {
Expand All @@ -101,7 +100,7 @@ protected void setTokenOnRequest(final HttpServletRequest request,
try {
final long issueTime = System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(5);
final String subject = (String) jwt.getJWTClaimsSet().getClaim(JWTToken.SUBJECT);
final String passcode = UUID.randomUUID().toString();
final String passcode = (String) jwt.getJWTClaimsSet().getClaims().get(PASSCODE_CLAIM);
addTokenState(jwt, issueTime, subject, passcode);
setTokenOnRequest(request, authUsername, generatePasscodeField(getTokenId(jwt), passcode));
} catch(ParseException e) {
Expand Down Expand Up @@ -338,11 +337,6 @@ public void testNotBeforeJWT() throws Exception {
// Override to disable N/A test
}

@Override
public void testVerificationOptimization() throws Exception {
// Override to disable N/A test
}

@Override
public void testExpiredTokensEvictedFromSignatureVerificationCache() throws Exception {
// Override to disable N/A test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ public long getTokenIssueTime(String tokenId) throws UnknownTokenException {
}
issueTime = convertCharArrayToLong(issueTimeStr);
// Update the in-memory cache to avoid subsequent keystore look-ups for the same state
super.setIssueTime(tokenId, issueTime);
setIssueTimeInMemory(tokenId, issueTime);
} catch (UnknownTokenException e) {
throw e;
} catch (Exception e) {
Expand Down

0 comments on commit 2bd305d

Please sign in to comment.