Skip to content

Commit

Permalink
Thread-safe ServiceLoader usage
Browse files Browse the repository at this point in the history
Blend of pre-0.11.0 behavior that cached implementation instances and post-0.11.0 behavior using the JDK ServiceLoader to find/create instances of an SPI interface.  This change:

- Reinstates the <= 0.10.x behavior of caching application singleton service implementation instances in a thread-safe reference (previously an AtomicReference, but in this change, a ConcurrentMap).  If an app singleton instance is cached and found, it is returned to be (re)used immediately when requested.  This is ok for JJWT's purposes because all service implementations instances must be thread-safe application singletons by API contract/design, so caching them for repeated use is fine.

- Ensures that only if a service implementation instance is not in the app singleton cache, a new instance is located/created using a new JDK ServiceLoader instance, which doesn't require thread-safe considerations since it is used only in a single-threaded model for the short time it is used to discover a service implementation.  This PR/change removes the post-0.11.0 concurrent cache of ServiceLoader instances since they themselves are not designed to be thread-safe.

- Ensures that if a ServiceLoader discovers an implementation and returns a new instance, that instance is then cached as an application singleton in the aforementioned ConcurrentMap for continued reuse.

- Renames Services#loadFirst to Services#get to more accurately reflect calling expectations:  The fact that any 'loading' via the ServiceLoader may occur is not important for Services callers, and the previous method name was unnecessarily exposing internal implementation concepts.  This is safe to do in a point release (0.12.3 -> 0.12.4) because the Services class and its methods, while public, are in the `impl` module, only to be used internally for JJWT's purpose and never intended to be used by application developers.

- Updates all test methods to use the renamed method accordingly.

Fixes #873
  • Loading branch information
lhazlewood authored Jan 17, 2024
1 parent 406f2f3 commit d878404
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ public String compact() {

if (this.serializer == null) { // try to find one based on the services available
//noinspection unchecked
json(Services.loadFirst(Serializer.class));
json(Services.get(Serializer.class));
}

if (!Collections.isEmpty(claims)) { // normalize so we have one object to deal with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ public JwtParser build() {

if (this.deserializer == null) {
//noinspection unchecked
json(Services.loadFirst(Deserializer.class));
json(Services.get(Deserializer.class));
}
if (this.signingKeyResolver != null && this.signatureVerificationKey != null) {
String msg = "Both a 'signingKeyResolver and a 'verifyWith' key cannot be configured. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public B json(Deserializer<Map<String, ?>> reader) {
public final Parser<T> build() {
if (this.deserializer == null) {
//noinspection unchecked
this.deserializer = Services.loadFirst(Deserializer.class);
this.deserializer = Services.get(Deserializer.class);
}
return doBuild();
}
Expand Down
100 changes: 35 additions & 65 deletions impl/src/main/java/io/jsonwebtoken/impl/lang/Services.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@
*/
package io.jsonwebtoken.impl.lang;

import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import static io.jsonwebtoken.lang.Collections.arrayToList;

/**
* Helper class for loading services from the classpath, using a {@link ServiceLoader}. Decouples loading logic for
* better separation of concerns and testability.
*/
public final class Services {

private static ConcurrentMap<Class<?>, ServiceLoader<?>> SERVICE_CACHE = new ConcurrentHashMap<>();
private static final ConcurrentMap<Class<?>, Object> SERVICES = new ConcurrentHashMap<>();

private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = arrayToList(new ClassLoaderAccessor[] {
private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = Arrays.asList(new ClassLoaderAccessor[]{
new ClassLoaderAccessor() {
@Override
public ClassLoader getClassLoader() {
Expand All @@ -54,86 +53,57 @@ public ClassLoader getClassLoader() {
}
});

private Services() {}

/**
* Loads and instantiates all service implementation of the given SPI class and returns them as a List.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return An unmodifiable list with an instance of all available implementations of the SPI. No guarantee is given
* on the order of implementations, if more than one.
*/
public static <T> List<T> loadAll(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");

ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {

List<T> implementations = new ArrayList<>();
for (T implementation : serviceLoader) {
implementations.add(implementation);
}
return implementations;
}

throw new UnavailableImplementationException(spi);
private Services() {
}

/**
* Loads the first available implementation the given SPI class from the classpath. Uses the {@link ServiceLoader}
* to find implementations. When multiple implementations are available it will return the first one that it
* encounters. There is no guarantee with regard to ordering.
* Returns the first available implementation for the given SPI class, checking an internal thread-safe cache first,
* and, if not found, using a {@link ServiceLoader} to find implementations. When multiple implementations are
* available it will return the first one that it encounters. There is no guarantee with regard to ordering.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return A new instance of the service.
* @throws UnavailableImplementationException When no implementation the SPI is available on the classpath.
* @return The first available instance of the service.
* @throws UnavailableImplementationException When no implementation of the SPI class can be found.
*/
public static <T> T loadFirst(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");

ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {
return serviceLoader.iterator().next();
public static <T> T get(Class<T> spi) {
// TODO: JDK8, replace this find/putIfAbsent logic with ConcurrentMap.computeIfAbsent
T instance = findCached(spi);
if (instance == null) {
instance = loadFirst(spi); // throws UnavailableImplementationException if not found, which is what we want
SERVICES.putIfAbsent(spi, instance); // cache if not already cached
}

throw new UnavailableImplementationException(spi);
return instance;
}

/**
* Returns a ServiceLoader for <code>spi</code> class, checking multiple classloaders. The ServiceLoader
* will be cached if it contains at least one implementation of the <code>spi</code> class.<BR>
*
* <b>NOTE:</b> Only the first Serviceloader will be cached.
* @param spi The interface or abstract class representing the service loader.
* @return A service loader, or null if no implementations are found
* @param <T> The type of the SPI.
*/
private static <T> ServiceLoader<T> serviceLoader(Class<T> spi) {
// TODO: JDK8, replace this get/putIfAbsent logic with ConcurrentMap.computeIfAbsent
ServiceLoader<T> serviceLoader = (ServiceLoader<T>) SERVICE_CACHE.get(spi);
if (serviceLoader != null) {
return serviceLoader;
private static <T> T findCached(Class<T> spi) {
Assert.notNull(spi, "Service interface cannot be null.");
Object obj = SERVICES.get(spi);
if (obj != null) {
return Assert.isInstanceOf(spi, obj, "Unexpected cached service implementation type.");
}
return null;
}

for (ClassLoaderAccessor classLoaderAccessor : CLASS_LOADER_ACCESSORS) {
serviceLoader = ServiceLoader.load(spi, classLoaderAccessor.getClassLoader());
if (serviceLoader.iterator().hasNext()) {
SERVICE_CACHE.putIfAbsent(spi, serviceLoader);
return serviceLoader;
private static <T> T loadFirst(Class<T> spi) {
for (ClassLoaderAccessor accessor : CLASS_LOADER_ACCESSORS) {
ServiceLoader<T> loader = ServiceLoader.load(spi, accessor.getClassLoader());
Assert.stateNotNull(loader, "JDK ServiceLoader#load should never return null.");
Iterator<T> i = loader.iterator();
Assert.stateNotNull(i, "JDK ServiceLoader#iterator() should never return null.");
if (i.hasNext()) {
return i.next();
}
}

return null;
throw new UnavailableImplementationException(spi);
}

/**
* Clears internal cache of ServiceLoaders. This is useful when testing, or for applications that dynamically
* Clears internal cache of service singletons. This is useful when testing, or for applications that dynamically
* change classloaders.
*/
public static void reload() {
SERVICE_CACHE.clear();
SERVICES.clear();
}

private interface ClassLoaderAccessor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private JwksBridge() {

@SuppressWarnings({"unchecked", "unused"}) // used via reflection by io.jsonwebtoken.security.Jwks
public static String UNSAFE_JSON(Jwk<?> jwk) {
Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer.class);
Serializer<Map<String, ?>> serializer = Services.get(Serializer.class);
Assert.stateNotNull(serializer, "Serializer lookup failed. Ensure JSON impl .jar is in the runtime classpath.");
NamedSerializer ser = new NamedSerializer("JWK", serializer);
ByteArrayOutputStream out = new ByteArrayOutputStream(512);
Expand Down
4 changes: 2 additions & 2 deletions impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class JwtsTest {
}

static def toJson(def o) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
def out = new ByteArrayOutputStream()
serializer.serialize(o, out)
return Strings.utf8(out.toByteArray())
Expand Down Expand Up @@ -1192,7 +1192,7 @@ class JwtsTest {
int j = jws.lastIndexOf('.')
def b64 = jws.substring(i, j)
def json = Strings.utf8(Decoders.BASE64URL.decode(b64))
def deser = Services.loadFirst(Deserializer)
def deser = Services.get(Deserializer)
def m = deser.deserialize(new StringReader(json)) as Map<String,?>

assertEquals aud, m.get('aud') // single string value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import static org.junit.Assert.fail

class RFC7515AppendixETest {

static final Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.loadFirst(Deserializer)
static final Serializer<Map<String, ?>> serializer = Services.get(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.get(Deserializer)

static byte[] ser(def value) {
ByteArrayOutputStream baos = new ByteArrayOutputStream(512)
Expand Down
4 changes: 1 addition & 3 deletions impl/src/test/groovy/io/jsonwebtoken/RFC7797Test.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,9 @@ class RFC7797Test {
def claims = Jwts.claims().subject('me').build()

ByteArrayOutputStream out = new ByteArrayOutputStream()
Services.loadFirst(Serializer).serialize(claims, out)
Services.get(Serializer).serialize(claims, out)
byte[] content = out.toByteArray()

//byte[] content = Services.loadFirst(Serializer).serialize(claims)

String s = Jwts.builder().signWith(key).content(content).encodePayload(false).compact()

// But verify with 3 types of sources: string, byte array, and two different kinds of InputStreams:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class DefaultJwtBuilderTest {
private DefaultJwtBuilder builder

private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()
}

private static Map<String, ?> deser(byte[] data) {
def reader = Streams.reader(data)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(reader) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(reader) as Map<String, ?>
return m
}

Expand Down Expand Up @@ -749,7 +749,7 @@ class DefaultJwtBuilderTest {
// so we need to check the raw payload:
def encoded = new JwtTokenizer().tokenize(Streams.reader(jwt)).getPayload()
byte[] bytes = Decoders.BASE64URL.decode(encoded)
def claims = Services.loadFirst(Deserializer).deserialize(Streams.reader(bytes))
def claims = Services.get(Deserializer).deserialize(Streams.reader(bytes))

assertEquals two, claims.aud
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class DefaultJwtParserTest {
}

private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()
Expand Down
2 changes: 1 addition & 1 deletion impl/src/test/groovy/io/jsonwebtoken/impl/RfcTests.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RfcTests {

static final Map<String, ?> jsonToMap(String json) {
Reader r = new CharSequenceReader(json)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(r) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(r) as Map<String, ?>
return m
}

Expand Down
25 changes: 7 additions & 18 deletions impl/src/test/groovy/io/jsonwebtoken/impl/lang/ServicesTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,21 @@ import io.jsonwebtoken.impl.DefaultStubService
import org.junit.After
import org.junit.Test

import static org.junit.Assert.*
import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertNotNull

class ServicesTest {

@Test
void testSuccessfulLoading() {
def factory = Services.loadFirst(StubService)
assertNotNull factory
assertEquals(DefaultStubService, factory.class)
def service = Services.get(StubService)
assertNotNull service
assertEquals(DefaultStubService, service.class)
}

@Test(expected = UnavailableImplementationException)
void testLoadFirstUnavailable() {
Services.loadFirst(NoService.class)
}

@Test
void testLoadAllAvailable() {
def list = Services.loadAll(StubService.class)
assertEquals 1, list.size()
assertTrue list[0] instanceof StubService
}

@Test(expected = UnavailableImplementationException)
void testLoadAllUnavailable() {
Services.loadAll(NoService.class)
void testLoadUnavailable() {
Services.get(NoService.class)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class RFC7518AppendixCTest {
}

private static final Map<String, ?> fromJson(String s) {
return Services.loadFirst(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
return Services.get(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
}

private static EcPrivateJwk readJwk(String json) {
Expand Down

0 comments on commit d878404

Please sign in to comment.