From a46dd8bff00a6a4ad9d2bee04e29ac73bc6e058f Mon Sep 17 00:00:00 2001 From: Pritham Marupaka Date: Mon, 6 Jan 2025 11:55:42 -0500 Subject: [PATCH] wip --- .../annotations/ConjureErrorDecoder.java | 1 + .../java/dialogue/serde/ConjureBodySerDe.java | 153 ++++-------------- .../dialogue/serde/EndpointErrorDecoder.java | 42 +++-- .../dialogue/serde/ConjureBodySerDeTest.java | 16 +- .../EndpointErrorsConjureBodySerDeTest.java | 1 + 5 files changed, 72 insertions(+), 141 deletions(-) diff --git a/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java b/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java index 382c97560..1bf3a01eb 100644 --- a/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java +++ b/dialogue-annotations/src/main/java/com/palantir/dialogue/annotations/ConjureErrorDecoder.java @@ -18,6 +18,7 @@ import com.palantir.dialogue.Response; +// TODO(pm): use the new EndpointErrorDecoder public final class ConjureErrorDecoder implements ErrorDecoder { @Override diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index 8e67cddd2..0ed2af86e 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -45,6 +45,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -52,7 +53,6 @@ /** * items: * - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class. - * - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info */ /** Package private internal API. */ @@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe { private final Deserializer> optionalBinaryInputStreamDeserializer; private final Deserializer emptyBodyDeserializer; private final LoadingCache> serializers; - private final LoadingCache> deserializers; + private final LoadingCache> deserializers; private final EmptyContainerDeserializer emptyContainerDeserializer; /** @@ -75,32 +75,49 @@ final class ConjureBodySerDe implements BodySerDe { */ ConjureBodySerDe( List rawEncodings, - ErrorDecoder errorDecoder, + ErrorDecoder _errorDecoder, EmptyContainerDeserializer emptyContainerDeserializer, CaffeineSpec cacheSpec) { List encodings = decorateEncodings(rawEncodings); this.encodingsSortedByWeight = sortByWeight(encodings); Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required"); + // note(pm): why do the weighted encoding thing? can we just pass in the default encoding? this.defaultEncoding = encodings.get(0).encoding(); this.emptyContainerDeserializer = emptyContainerDeserializer; - this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + this.binaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.MARKER); - this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + BinaryEncoding.MARKER, + DeserializerArgs.builder() + .withBaseType(BinaryEncoding.MARKER) + .withExpectedResult(BinaryEncoding.MARKER) + .build()); + this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.OPTIONAL_MARKER); - this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder); + BinaryEncoding.OPTIONAL_MARKER, + DeserializerArgs.>builder() + .withBaseType(BinaryEncoding.OPTIONAL_MARKER) + .withExpectedResult(BinaryEncoding.OPTIONAL_MARKER) + .build()); + this.emptyBodyDeserializer = + new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Map.of(), encodingsSortedByWeight)); // Class unloading: Not supported, Jackson keeps strong references to the types // it sees: https://github.com/FasterXML/jackson-databind/issues/489 this.serializers = Caffeine.from(cacheSpec) .build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type))); - this.deserializers = Caffeine.from(cacheSpec) - .build(type -> new EncodingDeserializerRegistry<>( - encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type))); + this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type))); + } + + private EncodingDeserializerForEndpointRegistry buildCacheEntry(TypeMarker typeMarker) { + return new EncodingDeserializerForEndpointRegistry<>( + encodingsSortedByWeight, + emptyContainerDeserializer, + typeMarker, + DeserializerArgs.builder() + .withBaseType(typeMarker) + .withExpectedResult(typeMarker) + .build()); } private static List decorateEncodings(List input) { @@ -235,108 +252,7 @@ private static final class EncodingSerializerContainer { } } - private static final class EncodingDeserializerRegistry implements Deserializer { - - private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class); - private final ImmutableList> encodings; - private final ErrorDecoder errorDecoder; - private final Optional acceptValue; - private final Supplier> emptyInstance; - private final TypeMarker token; - - EncodingDeserializerRegistry( - List encodings, - ErrorDecoder errorDecoder, - EmptyContainerDeserializer empty, - TypeMarker token) { - this.encodings = encodings.stream() - .map(encoding -> new EncodingDeserializerContainer<>(encoding, token)) - .collect(ImmutableList.toImmutableList()); - this.errorDecoder = errorDecoder; - this.token = token; - this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); - // Encodings are applied to the accept header in the order of preference based on the provided list. - this.acceptValue = - Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", "))); - } - - @Override - public T deserialize(Response response) { - boolean closeResponse = true; - try { - if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); - } else if (response.code() == 204) { - // TODO(dfox): what if we get a 204 for a non-optional type??? - // TODO(dfox): support http200 & body=null - // TODO(dfox): what if we were expecting an empty list but got {}? - Optional maybeEmptyInstance = emptyInstance.get(); - if (maybeEmptyInstance.isPresent()) { - return maybeEmptyInstance.get(); - } - throw new SafeRuntimeException( - "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token)); - } - - Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); - if (!contentType.isPresent()) { - throw new SafeIllegalArgumentException( - "Response is missing Content-Type header", - SafeArg.of("received", response.headers().keySet())); - } - Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); - T deserialized = deserializer.deserialize(response.body()); - // deserializer has taken on responsibility for closing the response body - closeResponse = false; - return deserialized; - } catch (IOException e) { - throw new SafeRuntimeException( - "Failed to deserialize response stream", - e, - SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)), - SafeArg.of("type", token)); - } finally { - if (closeResponse) { - response.close(); - } - } - } - - @Override - public Optional accepts() { - return acceptValue; - } - - /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */ - @SuppressWarnings("ForLoopReplaceableByForEach") - // performance sensitive code avoids iterator allocation - Encoding.Deserializer getResponseDeserializer(String contentType) { - for (int i = 0; i < encodings.size(); i++) { - EncodingDeserializerContainer container = encodings.get(i); - if (container.encoding.supportsContentType(contentType)) { - return container.deserializer; - } - } - return throwingDeserializer(contentType); - } - - private Encoding.Deserializer throwingDeserializer(String contentType) { - return input -> { - try { - input.close(); - } catch (RuntimeException | IOException e) { - log.warn("Failed to close InputStream", e); - } - throw new SafeRuntimeException( - "Unsupported Content-Type", - SafeArg.of("received", contentType), - SafeArg.of("supportedEncodings", encodings)); - }; - } - } - private static final class EncodingDeserializerForEndpointRegistry implements Deserializer { - private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class); private final ImmutableList> encodings; private final EndpointErrorDecoder endpointErrorDecoder; @@ -367,7 +283,8 @@ public T deserialize(Response response) { boolean closeResponse = true; try { if (endpointErrorDecoder.isError(response)) { - // TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the old + // TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the + // old. We need to introduce this branching logic for rollout. return endpointErrorDecoder.decode(response); } else if (response.code() == 204) { Optional maybeEmptyInstance = emptyInstance.get(); @@ -457,9 +374,9 @@ public String toString() { } private static final class EmptyBodyDeserializer implements Deserializer { - private final ErrorDecoder errorDecoder; + private final EndpointErrorDecoder errorDecoder; - EmptyBodyDeserializer(ErrorDecoder errorDecoder) { + EmptyBodyDeserializer(EndpointErrorDecoder errorDecoder) { this.errorDecoder = errorDecoder; } @@ -469,7 +386,7 @@ public Void deserialize(Response response) { // We should not fail if a server that previously returned nothing starts returning a response try (Response unused = response) { if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); + errorDecoder.decode(response); } return null; } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java index f2ecfdfcc..dff70a07b 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java @@ -127,27 +127,32 @@ private T decodeInternal(Response response) { } Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); - // Use a factory: given contentType, create the deserailizer. + // Use a factory: given contentType, create the deserializer. // We need Encoding.Deserializer here. That depends on the encoding. - if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) { + String jsonContentType = "application/json"; + if (contentType.isPresent() && Encodings.matchesContentType(jsonContentType, contentType.get())) { try { JsonNode node = MAPPER.readTree(body); - if (node.get("errorName") != null) { - // TODO(pm): Update this to use some struct instead of errorName. - TypeMarker container = Optional.ofNullable( - errorNameToTypeMap.get(node.get("errorName").asText())) - .orElseThrow(); - for (int i = 0; i < encodings.size(); i++) { - Encoding encoding = encodings.get(i); - if (encoding.supportsContentType(contentType.get())) { - return encoding.deserializer(container) - .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); - } + if (node.get("errorName") == null) { + throwSerializableError(body, code); + } + // TODO(pm): Update this to use some struct instead of errorName. + Optional> maybeContainer = Optional.ofNullable( + errorNameToTypeMap.get(node.get("errorName").asText())); + if (maybeContainer.isEmpty()) { + // This thrown exception will be caught below. Refactor. + throwSerializableError(body, code); + } + for (int i = 0; i < encodings.size(); i++) { + Encoding encoding = encodings.get(i); + if (encoding.supportsContentType(jsonContentType)) { + return encoding.deserializer(maybeContainer.get()) + .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); } - } else { - SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); - throw new RemoteException(serializableError, code); } + } catch (RemoteException remoteException) { + // rethrow the created remote exception + throw remoteException; } catch (Exception e) { throw new UnknownRemoteException(code, body); } @@ -156,6 +161,11 @@ private T decodeInternal(Response response) { throw new UnknownRemoteException(code, body); } + private static void throwSerializableError(String body, int code) throws IOException { + SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); + throw new RemoteException(serializableError, code); + } + private static String toString(InputStream body) throws IOException { try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) { return CharStreams.toString(reader); diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java index da7ea260c..38ef424de 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java @@ -22,11 +22,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.palantir.conjure.java.api.errors.ErrorType; import com.palantir.conjure.java.api.errors.RemoteException; import com.palantir.conjure.java.api.errors.SerializableError; import com.palantir.conjure.java.api.errors.ServiceException; +import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.BinaryRequestBody; import com.palantir.dialogue.BodySerDe; import com.palantir.dialogue.RequestBody; @@ -47,6 +50,7 @@ @ExtendWith(MockitoExtension.class) public class ConjureBodySerDeTest { + private static final ObjectMapper SERVER_MAPPER = ObjectMappers.newServerObjectMapper(); private static final TypeMarker TYPE = new TypeMarker() {}; private static final TypeMarker> OPTIONAL_TYPE = new TypeMarker>() {}; @@ -137,14 +141,12 @@ public void testRequestUnknownContentType() throws IOException { } @Test - public void testErrorsDecoded() { - TestResponse response = new TestResponse().code(400); - + public void testErrorsDecoded() throws JsonProcessingException { ServiceException serviceException = new ServiceException(ErrorType.INVALID_ARGUMENT); - SerializableError serialized = SerializableError.forException(serviceException); - errorDecoder = mock(ErrorDecoder.class); - when(errorDecoder.isError(response)).thenReturn(true); - when(errorDecoder.decode(response)).thenReturn(new RemoteException(serialized, 400)); + TestResponse response = TestResponse.withBody( + SERVER_MAPPER.writeValueAsString(SerializableError.forException(serviceException))) + .code(400) + .contentType("application/json"); BodySerDe serializers = conjureBodySerDe("text/plain"); diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java index 296bb193d..eb06ec457 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java @@ -144,6 +144,7 @@ public void testDeserializeCustomErrors() throws IOException { EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value = serializers.deserializer(deserializerArgs).deserialize(response); + assertThat(value).isInstanceOf(ErrorForEndpoint.class); assertThat(value) .extracting("errorCode", "errorName", "errorInstanceId", "args") .containsExactly(