diff --git a/.palantir/revapi.yml b/.palantir/revapi.yml index fe9ee372f..0a70f8cb5 100644 --- a/.palantir/revapi.yml +++ b/.palantir/revapi.yml @@ -311,3 +311,17 @@ acceptedBreaks: new: "method com.palantir.dialogue.clients.DialogueClients.StickyChannelSession\ \ com.palantir.dialogue.clients.DialogueClients.StickyChannelFactory2::session()" justification: "interface for consumption, not extension" + "5.0.0": + com.palantir.dialogue:dialogue-target: + - code: "java.method.addedToInterface" + new: "method com.palantir.dialogue.Deserializer com.palantir.dialogue.BodySerDe::deserializer(com.palantir.dialogue.DeserializerArgs)" + justification: "Adding new methods to create deserializers in support of endpoint\ + \ associated error deserialization" + - code: "java.method.addedToInterface" + new: "method com.palantir.dialogue.Deserializer com.palantir.dialogue.BodySerDe::inputStreamDeserializer(com.palantir.dialogue.DeserializerArgs)" + justification: "Adding new methods to create deserializers in support of endpoint\ + \ associated error deserialization" + - code: "java.method.addedToInterface" + new: "method com.palantir.dialogue.Deserializer com.palantir.dialogue.BodySerDe::optionalInputStreamDeserializer(com.palantir.dialogue.DeserializerArgs)" + justification: "Adding new methods to create deserializers in support of endpoint\ + \ associated error deserialization" diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index b738b8529..b2c45e017 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -26,6 +26,7 @@ import com.palantir.dialogue.BinaryRequestBody; import com.palantir.dialogue.BodySerDe; import com.palantir.dialogue.Deserializer; +import com.palantir.dialogue.DeserializerArgs; import com.palantir.dialogue.RequestBody; import com.palantir.dialogue.Response; import com.palantir.dialogue.Serializer; @@ -40,13 +41,18 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.lang.reflect.Constructor; +import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Optional; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; +import javax.annotation.Nullable; /** Package private internal API. */ final class ConjureBodySerDe implements BodySerDe { @@ -58,7 +64,8 @@ final class ConjureBodySerDe implements BodySerDe { private final Deserializer> optionalBinaryInputStreamDeserializer; private final Deserializer emptyBodyDeserializer; private final LoadingCache> serializers; - private final LoadingCache> deserializers; + private final LoadingCache> deserializers; + private final EmptyContainerDeserializer emptyContainerDeserializer; /** * Selects the first (based on input order) of the provided encodings that @@ -67,31 +74,46 @@ final class ConjureBodySerDe implements BodySerDe { */ ConjureBodySerDe( List rawEncodings, - ErrorDecoder errorDecoder, EmptyContainerDeserializer emptyContainerDeserializer, CaffeineSpec cacheSpec) { List encodings = decorateEncodings(rawEncodings); this.encodingsSortedByWeight = sortByWeight(encodings); Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required"); this.defaultEncoding = encodings.get(0).encoding(); - this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + this.emptyContainerDeserializer = emptyContainerDeserializer; + this.binaryInputStreamDeserializer = EncodingDeserializerForEndpointRegistry.create( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.MARKER); - this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( + BinaryEncoding.MARKER, + DeserializerArgs.builder() + .baseType(BinaryEncoding.MARKER) + .success(BinaryEncoding.MARKER) + .build()); + this.optionalBinaryInputStreamDeserializer = EncodingDeserializerForEndpointRegistry.create( ImmutableList.of(BinaryEncoding.INSTANCE), - errorDecoder, emptyContainerDeserializer, - BinaryEncoding.OPTIONAL_MARKER); - this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder); + BinaryEncoding.OPTIONAL_MARKER, + DeserializerArgs.>builder() + .baseType(BinaryEncoding.OPTIONAL_MARKER) + .success(BinaryEncoding.OPTIONAL_MARKER) + .build()); + this.emptyBodyDeserializer = new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Collections.emptyMap())); // Class unloading: Not supported, Jackson keeps strong references to the types // it sees: https://github.com/FasterXML/jackson-databind/issues/489 this.serializers = Caffeine.from(cacheSpec) .build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type))); - this.deserializers = Caffeine.from(cacheSpec) - .build(type -> new EncodingDeserializerRegistry<>( - encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type))); + this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type))); + } + + private EncodingDeserializerForEndpointRegistry buildCacheEntry(TypeMarker typeMarker) { + return EncodingDeserializerForEndpointRegistry.create( + encodingsSortedByWeight, + emptyContainerDeserializer, + typeMarker, + DeserializerArgs.builder() + .baseType(typeMarker) + .success(typeMarker) + .build()); } private static List decorateEncodings(List input) { @@ -122,6 +144,13 @@ public Deserializer deserializer(TypeMarker token) { return (Deserializer) deserializers.get(token.getType()); } + @Override + @SuppressWarnings("unchecked") + public Deserializer deserializer(DeserializerArgs deserializerArgs) { + return EncodingDeserializerForEndpointRegistry.create( + encodingsSortedByWeight, emptyContainerDeserializer, deserializerArgs.baseType(), deserializerArgs); + } + @Override public Deserializer emptyBodyDeserializer() { return emptyBodyDeserializer; @@ -132,11 +161,36 @@ public Deserializer inputStreamDeserializer() { return binaryInputStreamDeserializer; } + @Override + @SuppressWarnings("unchecked") + public Deserializer inputStreamDeserializer(DeserializerArgs deserializerArgs) { + return new EncodingDeserializerForEndpointRegistry<>( + ImmutableList.of(BinaryEncoding.INSTANCE), + emptyContainerDeserializer, + deserializerArgs.baseType(), + deserializerArgs, + BinaryEncoding.MARKER, + (Function) createSuccessTypeFunctionForInputStream(deserializerArgs.successType())); + } + @Override public Deserializer> optionalInputStreamDeserializer() { return optionalBinaryInputStreamDeserializer; } + @Override + @SuppressWarnings("unchecked") + public Deserializer optionalInputStreamDeserializer(DeserializerArgs deserializerArgs) { + return new EncodingDeserializerForEndpointRegistry<>( + ImmutableList.of(BinaryEncoding.INSTANCE), + emptyContainerDeserializer, + deserializerArgs.baseType(), + deserializerArgs, + BinaryEncoding.OPTIONAL_MARKER, + (Function, T>) + createSuccessTypeFunctionForOptionalInputStream(deserializerArgs.successType())); + } + @Override public RequestBody serialize(BinaryRequestBody value) { Preconditions.checkNotNull(value, "A BinaryRequestBody value is required"); @@ -216,44 +270,70 @@ private static final class EncodingSerializerContainer { } } - private static final class EncodingDeserializerRegistry implements Deserializer { - - private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class); - private final ImmutableList> encodings; - private final ErrorDecoder errorDecoder; + private static final class EncodingDeserializerForEndpointRegistry implements Deserializer { + private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class); + private final ImmutableList> encodings; + private final EndpointErrorDecoder endpointErrorDecoder; private final Optional acceptValue; - private final Supplier> emptyInstance; + private final Supplier> emptyInstance; private final TypeMarker token; + private final @Nullable Function transform; + + @SuppressWarnings("unchecked") + static EncodingDeserializerForEndpointRegistry create( + List encodingsSortedByWeight, + EmptyContainerDeserializer empty, + TypeMarker token, + DeserializerArgs deserializersForEndpoint) { + return new EncodingDeserializerForEndpointRegistry<>( + encodingsSortedByWeight, + empty, + token, + deserializersForEndpoint, + (TypeMarker) deserializersForEndpoint.successType(), + null); + } - EncodingDeserializerRegistry( - List encodings, - ErrorDecoder errorDecoder, + EncodingDeserializerForEndpointRegistry( + List encodingsSortedByWeight, EmptyContainerDeserializer empty, - TypeMarker token) { - this.encodings = encodings.stream() - .map(encoding -> new EncodingDeserializerContainer<>(encoding, token)) + TypeMarker token, + DeserializerArgs deserializersForEndpoint, + TypeMarker intermediateResult, + @Nullable Function transform) { + this.encodings = encodingsSortedByWeight.stream() + .map(encoding -> new EncodingDeserializerContainer<>(encoding, intermediateResult)) .collect(ImmutableList.toImmutableList()); - this.errorDecoder = errorDecoder; + this.endpointErrorDecoder = new EndpointErrorDecoder<>( + deserializersForEndpoint.errorNameToTypeMarker(), + encodingsSortedByWeight.stream() + .filter(encoding -> encoding.supportsContentType("application/json")) + .findFirst()); this.token = token; - this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); - // Encodings are applied to the accept header in the order of preference based on the provided list. - this.acceptValue = - Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", "))); + this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(intermediateResult)); + this.acceptValue = Optional.of(encodingsSortedByWeight.stream() + .map(Encoding::getContentType) + .collect(Collectors.joining(", "))); + this.transform = transform; } @Override + @SuppressWarnings("unchecked") public T deserialize(Response response) { boolean closeResponse = true; try { - if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); + if (endpointErrorDecoder.isError(response)) { + return endpointErrorDecoder.decode(response); } else if (response.code() == 204) { // TODO(dfox): what if we get a 204 for a non-optional type??? // TODO(dfox): support http200 & body=null // TODO(dfox): what if we were expecting an empty list but got {}? - Optional maybeEmptyInstance = emptyInstance.get(); + Optional maybeEmptyInstance = emptyInstance.get(); if (maybeEmptyInstance.isPresent()) { - return maybeEmptyInstance.get(); + if (transform == null) { + return (T) maybeEmptyInstance.get(); + } + return transform.apply(maybeEmptyInstance.get()); } throw new SafeRuntimeException( "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token)); @@ -265,11 +345,14 @@ public T deserialize(Response response) { "Response is missing Content-Type header", SafeArg.of("received", response.headers().keySet())); } - Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); - T deserialized = deserializer.deserialize(response.body()); + Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); + S deserialized = deserializer.deserialize(response.body()); // deserializer has taken on responsibility for closing the response body closeResponse = false; - return deserialized; + if (transform == null) { + return (T) deserialized; + } + return transform.apply(deserialized); } catch (IOException e) { throw new SafeRuntimeException( "Failed to deserialize response stream", @@ -291,9 +374,9 @@ public Optional accepts() { /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */ @SuppressWarnings("ForLoopReplaceableByForEach") // performance sensitive code avoids iterator allocation - Encoding.Deserializer getResponseDeserializer(String contentType) { + Encoding.Deserializer getResponseDeserializer(String contentType) { for (int i = 0; i < encodings.size(); i++) { - EncodingDeserializerContainer container = encodings.get(i); + EncodingDeserializerContainer container = encodings.get(i); if (container.encoding.supportsContentType(contentType)) { return container.deserializer; } @@ -301,20 +384,17 @@ Encoding.Deserializer getResponseDeserializer(String contentType) { return throwingDeserializer(contentType); } - private Encoding.Deserializer throwingDeserializer(String contentType) { - return new Encoding.Deserializer() { - @Override - public T deserialize(InputStream input) { - try { - input.close(); - } catch (RuntimeException | IOException e) { - log.warn("Failed to close InputStream", e); - } - throw new SafeRuntimeException( - "Unsupported Content-Type", - SafeArg.of("received", contentType), - SafeArg.of("supportedEncodings", encodings)); + private Encoding.Deserializer throwingDeserializer(String contentType) { + return input -> { + try { + input.close(); + } catch (RuntimeException | IOException e) { + log.warn("Failed to close InputStream", e); } + throw new SafeRuntimeException( + "Unsupported Content-Type", + SafeArg.of("received", contentType), + SafeArg.of("supportedEncodings", encodings)); }; } } @@ -337,10 +417,10 @@ public String toString() { } private static final class EmptyBodyDeserializer implements Deserializer { - private final ErrorDecoder errorDecoder; + private final EndpointErrorDecoder endpointErrorDecoder; - EmptyBodyDeserializer(ErrorDecoder errorDecoder) { - this.errorDecoder = errorDecoder; + EmptyBodyDeserializer(EndpointErrorDecoder endpointErrorDecoder) { + this.endpointErrorDecoder = endpointErrorDecoder; } @Override @@ -348,8 +428,8 @@ private static final class EmptyBodyDeserializer implements Deserializer { public Void deserialize(Response response) { // We should not fail if a server that previously returned nothing starts returning a response try (Response unused = response) { - if (errorDecoder.isError(response)) { - throw errorDecoder.decode(response); + if (endpointErrorDecoder.isError(response)) { + endpointErrorDecoder.decode(response); } return null; } @@ -365,4 +445,59 @@ public String toString() { return "EmptyBodyDeserializer{}"; } } + + @SuppressWarnings("unchecked") + private static Function createSuccessTypeFunctionForInputStream(TypeMarker successT) { + return successTypeCreatorFactory(successT, successType -> { + try { + return ((Class) successType.getType()).getConstructor(InputStream.class); + } catch (ReflectiveOperationException ex) { + throw new SafeRuntimeException("Failed to create success type", ex); + } + }); + } + + @SuppressWarnings("unchecked") + private static Function, T> createSuccessTypeFunctionForOptionalInputStream( + TypeMarker successT) { + return successTypeCreatorFactory(successT, successType -> { + try { + Class clazz = (Class) successType.getType(); + for (Constructor ctor : clazz.getConstructors()) { + if (ctor.getParameterCount() != 1) { + continue; + } + Type paramType = ctor.getGenericParameterTypes()[0]; + if (paramType instanceof ParameterizedType parameterizedType) { + if (parameterizedType.getRawType().equals(Optional.class) + && parameterizedType.getActualTypeArguments()[0].equals(InputStream.class)) { + return (Constructor) ctor; + } + } + } + } catch (SecurityException ex) { + throw new SafeRuntimeException("Failed to create success type", ex); + } + throw new SafeRuntimeException( + "Failed to create success type. Could not find constructor with Optional parameter"); + }); + } + + private static Function successTypeCreatorFactory( + TypeMarker successT, Function, Constructor> ctorExtractor) { + if (!(successT.getType() instanceof Class)) { + throw new SafeRuntimeException("Failed to create success type", SafeArg.of("type", successT)); + } + Constructor ctor = ctorExtractor.apply(successT); + if (ctor == null) { + throw new SafeRuntimeException("Failed to create success type", SafeArg.of("type", successT)); + } + return ctorParam -> { + try { + return ctor.newInstance(ctorParam); + } catch (ReflectiveOperationException e) { + throw new SafeRuntimeException("Failed to create success type", e); + } + }; + } } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java index 3e4766fda..befd4c350 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/DefaultConjureRuntime.java @@ -45,7 +45,6 @@ public final class DefaultConjureRuntime implements ConjureRuntime { private DefaultConjureRuntime(Builder builder) { this.bodySerDe = new ConjureBodySerDe( builder.encodings.isEmpty() ? DEFAULT_ENCODINGS : builder.encodings, - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DEFAULT_SERDE_CACHE_SPEC); } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java new file mode 100644 index 000000000..2c1547240 --- /dev/null +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java @@ -0,0 +1,266 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.net.HttpHeaders; +import com.google.common.primitives.Longs; +import com.palantir.conjure.java.api.errors.QosException; +import com.palantir.conjure.java.api.errors.QosReason; +import com.palantir.conjure.java.api.errors.QosReasons; +import com.palantir.conjure.java.api.errors.QosReasons.QosResponseDecodingAdapter; +import com.palantir.conjure.java.api.errors.RemoteException; +import com.palantir.conjure.java.api.errors.SerializableError; +import com.palantir.conjure.java.api.errors.UnknownRemoteException; +import com.palantir.conjure.java.dialogue.serde.Encoding.Deserializer; +import com.palantir.dialogue.Response; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.SafeLoggable; +import com.palantir.logsafe.UnsafeArg; +import com.palantir.logsafe.exceptions.SafeExceptions; +import com.palantir.logsafe.logger.SafeLogger; +import com.palantir.logsafe.logger.SafeLoggerFactory; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import javax.annotation.Nullable; + +/** + * Extracts the error from a {@link Response}. + *

If the error's name is in the {@link #errorNameToJsonDeserializerMap}, this class attempts to deserialize the + * {@link Response} body as JSON, to the error type. Otherwise, a {@link RemoteException} is thrown. If the + * {@link Response} does not adhere to the expected format, an {@link UnknownRemoteException} is thrown. + * + * @param the base type of the endpoint response. It's a union of the result type and all the error types. + */ +final class EndpointErrorDecoder { + private static final SafeLogger log = SafeLoggerFactory.get(EndpointErrorDecoder.class); + // Errors are currently expected to be JSON objects. See + // https://palantir.github.io/conjure/#/docs/spec/wire?id=_55-conjure-errors. As there is greater adoption of + // endpoint associated errors and larger parameters, we may be interested in supporting SMILE/CBOR for more + // performant handling of larger paramater payloads. + private static final Encoding JSON_ENCODING = Encodings.json(); + private static final Deserializer NAMED_ERROR_DESERIALIZER = + JSON_ENCODING.deserializer(new TypeMarker<>() {}); + private static final Deserializer SERIALIZABLE_ERROR_DESERIALIZER = + JSON_ENCODING.deserializer(new TypeMarker<>() {}); + private final Map> errorNameToJsonDeserializerMap; + + EndpointErrorDecoder(Map> errorNameToTypeMap) { + this(errorNameToTypeMap, Optional.empty()); + } + + EndpointErrorDecoder( + Map> errorNameToTypeMap, Optional maybeJsonEncoding) { + this.errorNameToJsonDeserializerMap = ImmutableMap.copyOf( + Maps.transformValues(errorNameToTypeMap, maybeJsonEncoding.orElse(JSON_ENCODING)::deserializer)); + } + + boolean isError(Response response) { + return 300 <= response.code() && response.code() <= 599; + } + + T decode(Response response) { + if (log.isDebugEnabled()) { + log.debug("Received an error response", diagnosticArgs(response)); + } + try { + return decodeInternal(response); + } catch (Exception e) { + e.addSuppressed(diagnostic(response)); + throw e; + } + } + + Optional checkCode(Response response) { + int code = response.code(); + switch (code) { + case 308: + Optional location = response.getFirstHeader(HttpHeaders.LOCATION); + if (location.isPresent()) { + String locationHeader = location.get(); + try { + UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); + remoteException.initCause( + QosException.retryOther(qosReason(response), new URL(locationHeader))); + return Optional.of(remoteException); + } catch (MalformedURLException e) { + log.error( + "Failed to parse location header for QosException.RetryOther", + UnsafeArg.of("locationHeader", locationHeader), + e); + } + } else { + log.error("Retrieved HTTP status code 308 without Location header, cannot perform " + + "redirect. This appears to be a server-side protocol violation."); + } + break; + case 429: + return Optional.of(response.getFirstHeader(HttpHeaders.RETRY_AFTER) + .map(Longs::tryParse) + .map(Duration::ofSeconds) + .map(duration -> QosException.throttle(qosReason(response), duration)) + .orElseGet(() -> QosException.throttle(qosReason(response)))); + case 503: + return Optional.of(QosException.unavailable(qosReason(response))); + } + return Optional.empty(); + } + + private @Nullable String extractErrorName(byte[] body) { + try { + NamedError namedError = NAMED_ERROR_DESERIALIZER.deserialize(new ByteArrayInputStream(body)); + if (namedError == null) { + return null; + } + return namedError.errorName(); + } catch (IOException | RuntimeException e) { + return null; + } + } + + private T decodeInternal(Response response) { + Optional maybeQosException = checkCode(response); + if (maybeQosException.isPresent()) { + throw maybeQosException.get(); + } + int code = response.code(); + + byte[] body; + try { + body = toByteArray(response.body()); + } catch (NullPointerException | IOException e) { + UnknownRemoteException exception = new UnknownRemoteException(code, ""); + exception.initCause(e); + throw exception; + } + + Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); + if (contentType.isPresent() + && Encodings.matchesContentType(JSON_ENCODING.getContentType(), contentType.get())) { + try { + String errorName = extractErrorName(body); + if (errorName == null) { + throw createRemoteException(body, code); + } + Deserializer deserializer = errorNameToJsonDeserializerMap.get(errorName); + if (deserializer == null) { + throw createRemoteException(body, code); + } + return deserializer.deserialize(new ByteArrayInputStream(body)); + } catch (RemoteException remoteException) { + // rethrow the created remote exception + throw remoteException; + } catch (Exception e) { + throw new UnknownRemoteException(code, new String(body, StandardCharsets.UTF_8)); + } + } + + throw new UnknownRemoteException(code, new String(body, StandardCharsets.UTF_8)); + } + + private static RemoteException createRemoteException(byte[] body, int code) throws IOException { + SerializableError serializableError = + SERIALIZABLE_ERROR_DESERIALIZER.deserialize(new ByteArrayInputStream(body)); + return new RemoteException(serializableError, code); + } + + private static byte[] toByteArray(InputStream body) throws IOException { + try (body) { + return body.readAllBytes(); + } + } + + static ResponseDiagnostic diagnostic(Response response) { + return new ResponseDiagnostic(diagnosticArgs(response)); + } + + static ImmutableList> diagnosticArgs(Response response) { + ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); + recordHeader(HttpHeaders.SERVER, response, args); + recordHeader(HttpHeaders.CONTENT_TYPE, response, args); + recordHeader(HttpHeaders.CONTENT_LENGTH, response, args); + recordHeader(HttpHeaders.CONNECTION, response, args); + recordHeader(HttpHeaders.DATE, response, args); + recordHeader("x-envoy-response-flags", response, args); + recordHeader("x-envoy-response-code-details", response, args); + recordHeader("Response-Flags", response, args); + recordHeader("Response-Code-Details", response, args); + return args.build(); + } + + private static void recordHeader(String header, Response response, ImmutableList.Builder> args) { + response.getFirstHeader(header).ifPresent(server -> args.add(SafeArg.of(header, server))); + } + + private static final class ResponseDiagnostic extends RuntimeException implements SafeLoggable { + + private static final String SAFE_MESSAGE = "Response Diagnostic Information"; + + private final ImmutableList> args; + + ResponseDiagnostic(ImmutableList> args) { + super(SafeExceptions.renderMessage(SAFE_MESSAGE, args.toArray(new Arg[0]))); + this.args = args; + } + + @Override + public String getLogMessage() { + return SAFE_MESSAGE; + } + + @Override + public List> getArgs() { + return args; + } + + @Override + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // nop + public Throwable fillInStackTrace() { + // no-op: stack trace generation is expensive, this type exists + // to simply associate diagnostic information with a failure. + return this; + } + } + + private static QosReason qosReason(Response response) { + return QosReasons.parseFromResponse(response, DialogueQosResponseDecodingAdapter.INSTANCE); + } + + private enum DialogueQosResponseDecodingAdapter implements QosResponseDecodingAdapter { + INSTANCE; + + @Override + public Optional getFirstHeader(Response response, String headerName) { + return response.getFirstHeader(headerName); + } + } + + record NamedError(@JsonProperty("errorName") String errorName) {} +} diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java index 50642aea0..91aa488bd 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoder.java @@ -17,35 +17,21 @@ package com.palantir.conjure.java.dialogue.serde; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableList; import com.google.common.io.CharStreams; import com.google.common.net.HttpHeaders; -import com.google.common.primitives.Longs; -import com.palantir.conjure.java.api.errors.QosException; -import com.palantir.conjure.java.api.errors.QosReason; -import com.palantir.conjure.java.api.errors.QosReasons; -import com.palantir.conjure.java.api.errors.QosReasons.QosResponseDecodingAdapter; import com.palantir.conjure.java.api.errors.RemoteException; import com.palantir.conjure.java.api.errors.SerializableError; import com.palantir.conjure.java.api.errors.UnknownRemoteException; import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.Response; -import com.palantir.logsafe.Arg; -import com.palantir.logsafe.SafeArg; -import com.palantir.logsafe.SafeLoggable; -import com.palantir.logsafe.UnsafeArg; -import com.palantir.logsafe.exceptions.SafeExceptions; import com.palantir.logsafe.logger.SafeLogger; import com.palantir.logsafe.logger.SafeLoggerFactory; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; -import java.net.MalformedURLException; -import java.net.URL; import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.List; +import java.util.Collections; import java.util.Optional; /** @@ -59,17 +45,19 @@ public enum ErrorDecoder { private static final SafeLogger log = SafeLoggerFactory.get(ErrorDecoder.class); private static final ObjectMapper MAPPER = ObjectMappers.newClientObjectMapper(); + private static final EndpointErrorDecoder ENDPOINT_ERROR_DECODER = + new EndpointErrorDecoder<>(Collections.emptyMap()); public boolean isError(Response response) { - return 300 <= response.code() && response.code() <= 599; + return ENDPOINT_ERROR_DECODER.isError(response); } public RuntimeException decode(Response response) { if (log.isDebugEnabled()) { - log.debug("Received an error response", diagnosticArgs(response)); + log.debug("Received an error response", EndpointErrorDecoder.diagnosticArgs(response)); } RuntimeException result = decodeInternal(response); - result.addSuppressed(diagnostic(response)); + result.addSuppressed(EndpointErrorDecoder.diagnostic(response)); return result; } @@ -77,38 +65,12 @@ private RuntimeException decodeInternal(Response response) { // TODO(rfink): What about HTTP/101 switching protocols? // TODO(rfink): What about HEAD requests? - int code = response.code(); - switch (code) { - case 308: - Optional location = response.getFirstHeader(HttpHeaders.LOCATION); - if (location.isPresent()) { - String locationHeader = location.get(); - try { - UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); - remoteException.initCause( - QosException.retryOther(qosReason(response), new URL(locationHeader))); - return remoteException; - } catch (MalformedURLException e) { - log.error( - "Failed to parse location header for QosException.RetryOther", - UnsafeArg.of("locationHeader", locationHeader), - e); - } - } else { - log.error("Retrieved HTTP status code 308 without Location header, cannot perform " - + "redirect. This appears to be a server-side protocol violation."); - } - break; - case 429: - return response.getFirstHeader(HttpHeaders.RETRY_AFTER) - .map(Longs::tryParse) - .map(Duration::ofSeconds) - .map(duration -> QosException.throttle(qosReason(response), duration)) - .orElseGet(() -> QosException.throttle(qosReason(response))); - case 503: - return QosException.unavailable(qosReason(response)); + Optional maybeQosException = ENDPOINT_ERROR_DECODER.checkCode(response); + if (maybeQosException.isPresent()) { + return maybeQosException.get(); } + int code = response.code(); String body; try { body = toString(response.body()); @@ -136,69 +98,4 @@ private static String toString(InputStream body) throws IOException { return CharStreams.toString(reader); } } - - private static ResponseDiagnostic diagnostic(Response response) { - return new ResponseDiagnostic(diagnosticArgs(response)); - } - - private static ImmutableList> diagnosticArgs(Response response) { - ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); - recordHeader(HttpHeaders.SERVER, response, args); - recordHeader(HttpHeaders.CONTENT_TYPE, response, args); - recordHeader(HttpHeaders.CONTENT_LENGTH, response, args); - recordHeader(HttpHeaders.CONNECTION, response, args); - recordHeader(HttpHeaders.DATE, response, args); - recordHeader("x-envoy-response-flags", response, args); - recordHeader("x-envoy-response-code-details", response, args); - recordHeader("Response-Flags", response, args); - recordHeader("Response-Code-Details", response, args); - return args.build(); - } - - private static void recordHeader(String header, Response response, ImmutableList.Builder> args) { - response.getFirstHeader(header).ifPresent(server -> args.add(SafeArg.of(header, server))); - } - - private static final class ResponseDiagnostic extends RuntimeException implements SafeLoggable { - - private static final String SAFE_MESSAGE = "Response Diagnostic Information"; - - private final ImmutableList> args; - - ResponseDiagnostic(ImmutableList> args) { - super(SafeExceptions.renderMessage(SAFE_MESSAGE, args.toArray(new Arg[0]))); - this.args = args; - } - - @Override - public String getLogMessage() { - return SAFE_MESSAGE; - } - - @Override - public List> getArgs() { - return args; - } - - @Override - @SuppressWarnings("UnsynchronizedOverridesSynchronized") // nop - public Throwable fillInStackTrace() { - // no-op: stack trace generation is expensive, this type exists - // to simply associate diagnostic information with a failure. - return this; - } - } - - private static QosReason qosReason(Response response) { - return QosReasons.parseFromResponse(response, DialogueQosResponseDecodingAdapter.INSTANCE); - } - - private enum DialogueQosResponseDecodingAdapter implements QosResponseDecodingAdapter { - INSTANCE; - - @Override - public Optional getFirstHeader(Response response, String headerName) { - return response.getFirstHeader(headerName); - } - } } diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/JacksonEmptyContainerLoader.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/JacksonEmptyContainerLoader.java index 5f03e3399..83bdac0e0 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/JacksonEmptyContainerLoader.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/JacksonEmptyContainerLoader.java @@ -24,6 +24,7 @@ import com.palantir.logsafe.logger.SafeLogger; import com.palantir.logsafe.logger.SafeLoggerFactory; import java.io.IOException; +import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -84,9 +85,22 @@ private Optional constructEmptyInstance(Type type, TypeMarker origina } } + Constructor emptyRecordConstructor = getEmptyRecordCanonicalConstructor(type); + if (emptyRecordConstructor != null) { + try { + return Optional.of(emptyRecordConstructor.newInstance()); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + if (log.isDebugEnabled()) { + log.debug("Empty record construction failed", SafeArg.of("type", type), e); + } + return Optional.empty(); + } + } + if (log.isDebugEnabled()) { log.debug( - "Jackson couldn't instantiate an empty instance and also couldn't find a usable @JsonCreator", + "Jackson couldn't instantiate an empty instance and also couldn't find a usable @JsonCreator" + + " or an empty record constructor", SafeArg.of("type", type)); } return Optional.empty(); @@ -152,6 +166,24 @@ private static Method getJsonCreatorStaticMethod(Type type) { return null; } + /** + * If {@code type} is a record with 0 components, return the canonical constructor. + */ + @Nullable + private static Constructor getEmptyRecordCanonicalConstructor(Type type) { + if (type instanceof Class) { + Class clazz = (Class) type; + if (clazz.isRecord() && clazz.getRecordComponents().length == 0) { + for (Constructor ctor : clazz.getDeclaredConstructors()) { + if (0 == ctor.getParameterCount()) { + return ctor; + } + } + } + } + return null; + } + private static Optional invokeStaticFactoryMethod(Method method, Object parameter) { try { return Optional.ofNullable(method.invoke(null, parameter)); diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java index 0d13b0b4a..fe8a9bcd2 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/BinaryEncodingTest.java @@ -34,7 +34,6 @@ public void testBinary() throws IOException { TestResponse response = new TestResponse().code(200).contentType("application/octet-stream"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(new ConjureBodySerDeTest.StubEncoding("application/json"))), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); InputStream deserialized = serializers.inputStreamDeserializer().deserialize(response); @@ -58,7 +57,6 @@ public void testBinary_optional_present() throws IOException { TestResponse response = new TestResponse().code(200).contentType("application/octet-stream"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(new ConjureBodySerDeTest.StubEncoding("application/json"))), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); Optional maybe = diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java index 684c5acdd..e7e2356b7 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDeTest.java @@ -30,6 +30,7 @@ import com.palantir.conjure.java.serialization.ObjectMappers; import com.palantir.dialogue.BinaryRequestBody; import com.palantir.dialogue.BodySerDe; +import com.palantir.dialogue.DeserializerArgs; import com.palantir.dialogue.RequestBody; import com.palantir.dialogue.TestResponse; import com.palantir.dialogue.TypeMarker; @@ -52,8 +53,6 @@ public class ConjureBodySerDeTest { private static final TypeMarker TYPE = new TypeMarker() {}; private static final TypeMarker> OPTIONAL_TYPE = new TypeMarker>() {}; - private ErrorDecoder errorDecoder = ErrorDecoder.INSTANCE; - @Test public void testRequestContentType() throws IOException { @@ -71,12 +70,25 @@ public void testRequestOptionalEmpty() { assertThat(value).isEmpty(); } + @Test + public void testRequestCustomEmpty() { + record EmptyRecord() {} + TestResponse response = new TestResponse().code(204); + BodySerDe serializers = conjureBodySerDe("application/json"); + EmptyRecord value = serializers + .deserializer(DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker<>() {}) + .build()) + .deserialize(response); + assertThat(value).isNotNull(); + } + private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { return new ConjureBodySerDe( Arrays.stream(contentTypes) .map(c -> WeightedEncoding.of(new StubEncoding(c))) .collect(ImmutableList.toImmutableList()), - errorDecoder, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); } @@ -115,7 +127,6 @@ public void testAcceptBasedOnWeight() throws IOException { BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(plain, .5), WeightedEncoding.of(json, 1)), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); // first encoding is default @@ -174,7 +185,6 @@ public void if_deserialize_throws_response_is_still_closed() { TestResponse response = new TestResponse().code(200).contentType("application/json"); BodySerDe serializers = new ConjureBodySerDe( ImmutableList.of(WeightedEncoding.of(BrokenEncoding.INSTANCE)), - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); assertThatThrownBy(() -> serializers.deserializer(TYPE).deserialize(response)) diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java index a8ac9ab0d..58f1e4631 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/DefaultClientsTest.java @@ -80,7 +80,6 @@ public final class DefaultClientsTest { private Response response = new TestResponse(); private BodySerDe bodySerde = new ConjureBodySerDe( DefaultConjureRuntime.DEFAULT_ENCODINGS, - ErrorDecoder.INSTANCE, Encodings.emptyContainerDeserializer(), DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); private final SettableFuture responseFuture = SettableFuture.create(); diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java new file mode 100644 index 000000000..b53de896e --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java @@ -0,0 +1,164 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import com.palantir.logsafe.Safe; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.function.Function; + +final class EndpointErrorTestUtils { + private EndpointErrorTestUtils() {} + + abstract static class EndpointError { + @Safe + String errorCode; + + @Safe + String errorName; + + @Safe + String errorInstanceId; + + T args; + + EndpointError(String errorCode, String errorName, String errorInstanceId, T args) { + this.errorCode = errorCode; + this.errorName = errorName; + this.errorInstanceId = errorInstanceId; + this.args = args; + } + } + + record ConjureError( + @JsonProperty("errorCode") String errorCode, + @JsonProperty("errorName") String errorName, + @JsonProperty("errorInstanceId") String errorInstanceId, + @JsonProperty("parameters") Map parameters) { + static ConjureError fromCheckedServiceException(CheckedServiceException exception) { + Map parameters = new HashMap<>(); + for (Arg arg : exception.getArgs()) { + if (shouldIncludeArgInParameters(arg)) { + parameters.put(arg.getName(), arg.getValue()); + } + } + return new ConjureError( + exception.getErrorType().code().name(), + exception.getErrorType().name(), + exception.getErrorInstanceId(), + parameters); + } + + private static boolean shouldIncludeArgInParameters(Arg arg) { + Object obj = arg.getValue(); + return obj != null + && (!(obj instanceof Optional) || ((Optional) obj).isPresent()) + && (!(obj instanceof OptionalInt) || ((OptionalInt) obj).isPresent()) + && (!(obj instanceof OptionalLong) || ((OptionalLong) obj).isPresent()) + && (!(obj instanceof OptionalDouble) || ((OptionalDouble) obj).isPresent()); + } + } + + /** Deserializes requests as the type. */ + public static final class TypeReturningStubEncoding implements Encoding { + + private final String contentType; + private final Function, Encoding.Deserializer> deserializerFactory; + private final Map, Encoding.Deserializer> deserializers = new HashMap<>(); + + TypeReturningStubEncoding(String contentType) { + this(contentType, typeMarker -> Encodings.json().deserializer(typeMarker)); + } + + TypeReturningStubEncoding( + String contentType, Function, Encoding.Deserializer> deserializerFactory) { + this.contentType = contentType; + this.deserializerFactory = deserializerFactory; + } + + @Override + public Encoding.Serializer serializer(TypeMarker _type) { + return (_value, _output) -> { + // nop + }; + } + + @Override + @SuppressWarnings("unchecked") + public Encoding.Deserializer deserializer(TypeMarker type) { + return input -> { + Deserializer deserializer = + (Deserializer) deserializers.computeIfAbsent(type, deserializerFactory); + return deserializer.deserialize(input); + }; + } + + @Override + public String getContentType() { + return contentType; + } + + @Override + public boolean supportsContentType(String input) { + return contentType.equals(input); + } + + @Override + public String toString() { + return "TypeReturningStubEncoding{" + contentType + '}'; + } + + @SuppressWarnings("unchecked") + public Encoding.Deserializer getDeserializer(TypeMarker type) { + return (Deserializer) deserializers.get(type); + } + } + + public static final class ContentRecordingJsonDeserializer implements Encoding.Deserializer { + private final List deserializedContent = new ArrayList<>(); + private final Encoding.Deserializer delegate; + + ContentRecordingJsonDeserializer(TypeMarker type) { + this.delegate = Encodings.json().deserializer(type); + } + + public List getDeserializedContent() { + return deserializedContent; + } + + @Override + public T deserialize(InputStream input) throws IOException { + String inputString = new String(input.readAllBytes(), StandardCharsets.UTF_8); + deserializedContent.add(inputString); + return delegate.deserialize(new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))); + } + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java new file mode 100644 index 000000000..8d8c47203 --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java @@ -0,0 +1,431 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.MustBeClosed; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.conjure.java.api.errors.ErrorType; +import com.palantir.conjure.java.api.errors.RemoteException; +import com.palantir.conjure.java.api.errors.SerializableError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.ConjureError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.ContentRecordingJsonDeserializer; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.EndpointError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.TypeReturningStubEncoding; +import com.palantir.conjure.java.serialization.ObjectMappers; +import com.palantir.dialogue.BodySerDe; +import com.palantir.dialogue.DeserializerArgs; +import com.palantir.dialogue.TestResponse; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Preconditions; +import com.palantir.logsafe.Safe; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.Unsafe; +import com.palantir.logsafe.UnsafeArg; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.annotation.processing.Generated; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class EndpointErrorsConjureBodySerDeTest { + private static final ObjectMapper MAPPER = ObjectMappers.newServerObjectMapper(); + + @Generated("by conjure-java") + private sealed interface EmptyBodyEndpointReturnBaseType permits EmptyReturnValue, ErrorReturnValue {} + + @Generated("by conjure-java") + record EmptyReturnValue() implements EmptyBodyEndpointReturnBaseType {} + + @Generated("by conjure-java") + private sealed interface EndpointReturnBaseType permits ExpectedReturnValue, ErrorReturnValue {} + + @Generated("by conjure-java") + private sealed interface EndpointBinaryReturnBaseType permits BinaryReturnValue, ErrorReturnValue {} + + @Generated("by conjure-java") + private sealed interface EndpointOptionalBinaryReturnBaseType permits OptionalBinaryReturnValue, ErrorReturnValue {} + + @Generated("by conjure-java") + record ExpectedReturnValue(@JsonValue String value) implements EndpointReturnBaseType { + public ExpectedReturnValue { + Preconditions.checkArgumentNotNull(value, "value cannot be null"); + } + } + + @Generated("by conjure-java") + record BinaryReturnValue(@MustBeClosed @JsonValue InputStream value) + implements EndpointErrorsConjureBodySerDeTest.EndpointBinaryReturnBaseType { + public BinaryReturnValue { + Preconditions.checkArgumentNotNull(value, "value cannot be null"); + } + } + + @Generated("by conjure-java") + record OptionalBinaryReturnValue(@JsonValue Optional value) + implements EndpointOptionalBinaryReturnBaseType { + public OptionalBinaryReturnValue { + Preconditions.checkArgumentNotNull(value, "value cannot be null"); + } + } + + @Generated("by conjure-java") + record ComplexArg(int foo, String bar) {} + + @Generated("by conjure-java") + record ErrorForEndpointArgs( + @JsonProperty("arg") @Safe String arg, + @JsonProperty("unsafeArg") @Unsafe String unsafeArg, + @JsonProperty("complexArg") @Safe ComplexArg complexArg, + @JsonProperty("optionalArg") @Safe Optional optionalArg) {} + + static final class ErrorReturnValue extends EndpointError + implements EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType, + EndpointErrorsConjureBodySerDeTest.EmptyBodyEndpointReturnBaseType, + EndpointErrorsConjureBodySerDeTest.EndpointBinaryReturnBaseType, + EndpointErrorsConjureBodySerDeTest.EndpointOptionalBinaryReturnBaseType { + @JsonCreator + ErrorReturnValue( + @JsonProperty("errorCode") String errorCode, + @JsonProperty("errorName") String errorName, + @JsonProperty("errorInstanceId") String errorInstanceId, + @JsonProperty("parameters") ErrorForEndpointArgs args) { + super(errorCode, errorName, errorInstanceId, args); + } + } + + @Generated("by conjure-java") + public static final class TestEndpointError extends CheckedServiceException { + private TestEndpointError( + @Safe String arg, + @Unsafe String unsafeArg, + @Safe ComplexArg complexArg, + @Safe Optional optionalArg, + @Nullable Throwable cause) { + super( + ErrorType.FAILED_PRECONDITION, + cause, + SafeArg.of("arg", arg), + UnsafeArg.of("unsafeArg", unsafeArg), + SafeArg.of("complexArg", complexArg), + SafeArg.of("optionalArg", optionalArg)); + } + } + + @Test + public void testDeserializeExpectedValue() { + // Given + String expectedString = "expectedString"; + TestResponse response = TestResponse.withBody(String.format("\"%s\"", expectedString)) + .contentType("application/json") + .code(200); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + // When + EndpointReturnBaseType value = + serializers.deserializer(deserializerArgs).deserialize(response); + // Then + assertThat(value).isEqualTo(new ExpectedReturnValue(expectedString)); + } + + // The error should be deserialized using Encodings.json(), when a JSON encoding is not provided. + @ParameterizedTest + @ValueSource(strings = {"application/json", "text/plain"}) + public void testDeserializeCustomError(String supportedContentType) throws IOException { + // Given + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + BodySerDe serializers = conjureBodySerDe(supportedContentType); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + + // When + EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value = + serializers.deserializer(deserializerArgs).deserialize(response); + + // Then + ErrorReturnValue expectedErrorForEndpoint = new ErrorReturnValue( + ErrorType.FAILED_PRECONDITION.code().name(), + ErrorType.FAILED_PRECONDITION.name(), + errorThrownByEndpoint.getErrorInstanceId(), + new ErrorForEndpointArgs("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2))); + assertThat(value).isInstanceOf(ErrorReturnValue.class); + assertThat(value) + .extracting("errorCode", "errorName", "errorInstanceId", "args") + .containsExactly( + expectedErrorForEndpoint.errorCode, + expectedErrorForEndpoint.errorName, + expectedErrorForEndpoint.errorInstanceId, + expectedErrorForEndpoint.args); + } + + // When an error is deserialized, but the error type is not registered, the error should be deserialized as a + // SerializableError and a RemoteException should be thrown. + @Test + public void testDeserializingUndefinedErrorFallsbackToSerializableError() throws IOException { + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + // Note: no error types are registered. + .build(); + + // Then + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> { + serializers.deserializer(deserializerArgs).deserialize(response); + }) + .satisfies(exception -> { + SerializableError error = exception.getError(); + assertThat(error.errorCode()) + .isEqualTo(ErrorType.FAILED_PRECONDITION.code().name()); + assertThat(error.errorInstanceId()).isEqualTo(errorThrownByEndpoint.getErrorInstanceId()); + assertThat(error.errorName()).isEqualTo(ErrorType.FAILED_PRECONDITION.name()); + assertThat(error.parameters()) + .extracting("arg", "unsafeArg", "complexArg", "optionalArg") + .containsExactly( + "value", + "unsafeValue", + MAPPER.writeValueAsString(new ComplexArg(1, "bar")), + MAPPER.writeValueAsString(Optional.of(2))); + }); + } + + @ParameterizedTest + @ArgumentsSource(BinaryBodyArgumentsProvider.class) + public void testDeserializeBinaryValue(byte[] binaryBody) { + // Given + TestResponse response = new TestResponse(binaryBody) + .contentType("application/octet-stream") + .code(200); + + BodySerDe serializers = new ConjureBodySerDe( + ImmutableList.of(WeightedEncoding.of(BinaryEncoding.INSTANCE)), + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + + DeserializerArgs deserializerArgs = + DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + // When + EndpointBinaryReturnBaseType value = + serializers.inputStreamDeserializer(deserializerArgs).deserialize(response); + // Then + assertThat(value).isInstanceOfSatisfying(BinaryReturnValue.class, binaryReturnValue -> { + assertThat(EndpointErrorsConjureBodySerDeTest.readAllBytesUnchecked(binaryReturnValue::value)) + .isEqualTo(binaryBody); + }); + } + + @ParameterizedTest + @ArgumentsSource(BinaryBodyArgumentsProvider.class) + public void testDeserializeOptionalBinaryValuePresent(byte[] binaryBody) { + // Given + TestResponse response = new TestResponse(binaryBody) + .contentType("application/octet-stream") + .code(200); + + BodySerDe serializers = new ConjureBodySerDe( + ImmutableList.of(WeightedEncoding.of(BinaryEncoding.INSTANCE)), + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + + DeserializerArgs deserializerArgs = + DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + // When + EndpointOptionalBinaryReturnBaseType value = + serializers.optionalInputStreamDeserializer(deserializerArgs).deserialize(response); + // Then + assertThat(value).isInstanceOfSatisfying(OptionalBinaryReturnValue.class, optionalBinaryReturnValue -> { + assertThat(optionalBinaryReturnValue.value()).isPresent(); + assertThat(EndpointErrorsConjureBodySerDeTest.readAllBytesUnchecked(optionalBinaryReturnValue.value()::get)) + .isEqualTo(binaryBody); + }); + } + + @Test + public void testDeserializeOptionalBinaryValueError() throws JsonProcessingException { + // Given + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + + BodySerDe serializers = new ConjureBodySerDe( + ImmutableList.of(WeightedEncoding.of(BinaryEncoding.INSTANCE)), + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + + DeserializerArgs deserializerArgs = + DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + // When + EndpointOptionalBinaryReturnBaseType value = + serializers.optionalInputStreamDeserializer(deserializerArgs).deserialize(response); + // Then + ErrorReturnValue expectedErrorForEndpoint = new ErrorReturnValue( + ErrorType.FAILED_PRECONDITION.code().name(), + ErrorType.FAILED_PRECONDITION.name(), + errorThrownByEndpoint.getErrorInstanceId(), + new ErrorForEndpointArgs("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2))); + assertThat(value).isInstanceOf(ErrorReturnValue.class); + assertThat(value) + .extracting("errorCode", "errorName", "errorInstanceId", "args") + .containsExactly( + expectedErrorForEndpoint.errorCode, + expectedErrorForEndpoint.errorName, + expectedErrorForEndpoint.errorInstanceId, + expectedErrorForEndpoint.args); + } + + @Test + public void testDeserializeEmptyBody() { + // Given + TestResponse response = new TestResponse().code(204); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = + DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + // When + EmptyBodyEndpointReturnBaseType value = + serializers.deserializer(deserializerArgs).deserialize(response); + // Then + assertThat(value).isEqualTo(new EmptyReturnValue()); + } + + // Ensure that the supplied JSON encoding is used when available. + @Test + public void testDeserializeWithCustomEncoding() throws JsonProcessingException { + // Given + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + + TypeReturningStubEncoding stubbingEncoding = + new TypeReturningStubEncoding("application/json", ContentRecordingJsonDeserializer::new); + BodySerDe serializers = new ConjureBodySerDe( + List.of(WeightedEncoding.of(stubbingEncoding)), + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + + TypeMarker errorTypeMarker = new TypeMarker<>() {}; + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .baseType(new TypeMarker<>() {}) + .success(new TypeMarker() {}) + .error("Default:FailedPrecondition", errorTypeMarker) + .build(); + + // When + serializers.deserializer(deserializerArgs).deserialize(response); + + // Then + assertThat(stubbingEncoding.getDeserializer(errorTypeMarker)) + .isInstanceOfSatisfying(ContentRecordingJsonDeserializer.class, deserializer -> assertThat( + deserializer.getDeserializedContent()) + .asInstanceOf(InstanceOfAssertFactories.LIST) + .containsExactly(responseBody)); + } + + private static byte[] readAllBytesUnchecked(Supplier stream) { + try (InputStream is = stream.get()) { + return is.readAllBytes(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { + return new ConjureBodySerDe( + Arrays.stream(contentTypes) + .map(c -> WeightedEncoding.of(new TypeReturningStubEncoding(c))) + .collect(ImmutableList.toImmutableList()), + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + } + + private static final class BinaryBodyArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext _context) { + return Stream.of(Arguments.of((Object) new byte[] {1, 2, 3}), Arguments.of((Object) new byte[] {})); + } + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java index 9f26a7aa6..b58fdfc33 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java @@ -17,6 +17,7 @@ package com.palantir.conjure.java.dialogue.serde; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; import com.fasterxml.jackson.core.JsonProcessingException; @@ -38,8 +39,12 @@ import com.palantir.logsafe.Preconditions; import com.palantir.logsafe.SafeArg; import java.time.Duration; +import java.util.Collections; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -54,24 +59,31 @@ public final class ErrorDecoderTest { private static String createServiceException(ServiceException exception) { try { - return SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + String ret = SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + return ret; } catch (JsonProcessingException e) { fail("failed to serialize"); return ""; } } + public enum DecoderType { + LEGACY, + ENDPOINT + } + private static final ErrorDecoder decoder = ErrorDecoder.INSTANCE; + private static final EndpointErrorDecoder endpointErrorDecoder = + new EndpointErrorDecoder<>(Collections.emptyMap()); - @Test - public void extractsRemoteExceptionForAllErrorCodes() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void extractsRemoteExceptionForAllErrorCodes(DecoderType decoderType) { for (int code : ImmutableList.of(300, 400, 404, 500)) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(RemoteException.class, exception -> { + Consumer validationFunction = exception -> { assertThat(exception.getCause()).isNull(); assertThat(exception.getStatus()).isEqualTo(code); assertThat(exception.getError().errorCode()) @@ -90,117 +102,224 @@ public void extractsRemoteExceptionForAllErrorCodes() { + " (" + ErrorType.FAILED_PRECONDITION.name() + ")"); - }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RemoteException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } } - @Test - public void testQos503() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos503(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(503); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Unavailable.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Unavailable.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos503WithMetadata() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos503WithMetadata(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION) .code(503) .withHeader("Qos-Retry-Hint", "do-not-retry") .withHeader("Qos-Due-To", "custom"); - assertThat(decoder.isError(response)).isTrue(); - - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Unavailable.class, exception -> { - assertThat(exception.getReason()) - .isEqualTo(QosReason.builder() - .from(QOS_REASON) - .dueTo(DueTo.CUSTOM) - .retryHint(RetryHint.DO_NOT_RETRY) - .build()); - }); + + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Unavailable.class, qosException -> { + assertThat(qosException.getReason()) + .isEqualTo(QosReason.builder() + .from(QOS_REASON) + .dueTo(DueTo.CUSTOM) + .retryHint(RetryHint.DO_NOT_RETRY) + .build()); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos429(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).isEmpty(); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).isEmpty(); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429_retryAfter() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos429_retryAfter(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429).withHeader(HttpHeaders.RETRY_AFTER, "3"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).hasValue(Duration.ofSeconds(3)); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).hasValue(Duration.ofSeconds(3)); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos429_retryAfter_invalid() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos429_retryAfter_invalid(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(429).withHeader(HttpHeaders.RETRY_AFTER, "bad"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result).isInstanceOfSatisfying(QosException.Throttle.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRetryAfter()).isEmpty(); - }); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(QosException.Throttle.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRetryAfter()).isEmpty(); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308_noLocation() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos308_noLocation(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(308); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOfSatisfying(UnknownRemoteException.class, exception -> assertThat(exception.getStatus()) - .isEqualTo(308)); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(UnknownRemoteException.class, unknownException -> { + assertThat(unknownException.getStatus()).isEqualTo(308); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308_invalidLocation() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void testQos308_invalidLocation(DecoderType decoderType) { Response response = TestResponse.withBody(SERIALIZED_EXCEPTION).code(308).withHeader(HttpHeaders.LOCATION, "invalid"); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOfSatisfying(UnknownRemoteException.class, exception -> assertThat(exception.getStatus()) - .isEqualTo(308)); + Consumer validationFunction = exception -> { + assertThat(exception).isInstanceOfSatisfying(UnknownRemoteException.class, unknownException -> { + assertThat(unknownException.getStatus()).isEqualTo(308); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void testQos308() { + @ParameterizedTest + @EnumSource(ErrorDecoderTest.DecoderType.class) + public void testQos308(DecoderType decoderType) { String expectedLocation = "https://localhost"; Response response = TestResponse.withBody(SERIALIZED_EXCEPTION) .code(308) .withHeader(HttpHeaders.LOCATION, expectedLocation); - assertThat(decoder.isError(response)).isTrue(); - RuntimeException result = decoder.decode(response); - assertThat(result) - .isInstanceOf(UnknownRemoteException.class) - .getRootCause() - .isInstanceOfSatisfying(QosException.RetryOther.class, exception -> { - assertThat(exception.getReason()).isEqualTo(QOS_REASON); - assertThat(exception.getRedirectTo()).asString().isEqualTo(expectedLocation); - }); + Consumer validationFunction = exception -> { + assertThat(exception) + .isInstanceOf(UnknownRemoteException.class) + .getRootCause() + .isInstanceOfSatisfying(QosException.RetryOther.class, qosException -> { + assertThat(qosException.getReason()).isEqualTo(QOS_REASON); + assertThat(qosException.getRedirectTo()).asString().isEqualTo(expectedLocation); + }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.isError(response)).isTrue(); + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(RuntimeException.class, validationFunction); + } else { + assertThat(endpointErrorDecoder.isError(response)).isTrue(); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } @Test @@ -216,16 +335,31 @@ public void cannotDecodeNonJsonMediaTypes() { TestResponse.withBody(SERIALIZED_EXCEPTION).code(500).contentType("text/plain"))) .isInstanceOf(UnknownRemoteException.class) .hasMessage("Response status: 500"); + + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode( + TestResponse.withBody(SERIALIZED_EXCEPTION).code(500).contentType("text/plain"))) + .satisfies(exception -> assertThat(exception.getMessage()).isEqualTo("Response status: 500")); } - @Test - public void doesNotHandleUnparseableBody() { - assertThat(decoder.decode(TestResponse.withBody("not json").code(500).contentType("application/json/"))) - .isInstanceOfSatisfying(UnknownRemoteException.class, expected -> { - assertThat(expected.getStatus()).isEqualTo(500); - assertThat(expected.getBody()).isEqualTo("not json"); - assertThat(expected.getMessage()).isEqualTo("Response status: 500"); - }); + @ParameterizedTest + @EnumSource(DecoderType.class) + public void doesNotHandleUnparseableBody(DecoderType decoderType) { + Response response = TestResponse.withBody("not json").code(500).contentType("application/json/"); + + Consumer validationFunction = exception -> { + assertThat(exception.getStatus()).isEqualTo(500); + assertThat(exception.getBody()).isEqualTo("not json"); + }; + + if (decoderType == DecoderType.LEGACY) { + RuntimeException result = decoder.decode(response); + assertThat(result).isInstanceOfSatisfying(UnknownRemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } @Test @@ -234,32 +368,57 @@ public void doesNotHandleNullBody() { assertThat(decoder.decode(TestResponse.withBody(null).code(500).contentType("application/json"))) .isInstanceOf(UnknownRemoteException.class) .hasMessage("Response status: 500"); + + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode( + TestResponse.withBody(null).code(500).contentType("application/json"))) + .satisfies(exception -> assertThat(exception.getMessage()).isEqualTo("Response status: 500")); } - @Test - public void handlesUnexpectedJson() { - assertThat(decoder.decode(TestResponse.withBody("{\"error\":\"some-unknown-json\"}") - .code(502) - .contentType("application/json"))) - .isInstanceOfSatisfying(UnknownRemoteException.class, expected -> { - assertThat(expected.getStatus()).isEqualTo(502); - assertThat(expected.getBody()).isEqualTo("{\"error\":\"some-unknown-json\"}"); - assertThat(expected.getMessage()).isEqualTo("Response status: 502"); - }); + @ParameterizedTest + @EnumSource(DecoderType.class) + public void handlesUnexpectedJson(DecoderType decoderType) { + Response response = TestResponse.withBody("{\"error\":\"some-unknown-json\"}") + .code(502) + .contentType("application/json"); + + Consumer validationFunction = expected -> { + assertThat(expected.getStatus()).isEqualTo(502); + assertThat(expected.getBody()).isEqualTo("{\"error\":\"some-unknown-json\"}"); + assertThat(expected.getMessage()).isEqualTo("Response status: 502"); + }; + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.decode(response)) + .isInstanceOfSatisfying(UnknownRemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(UnknownRemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } - @Test - public void handlesJsonWithEncoding() { + @ParameterizedTest + @EnumSource(DecoderType.class) + public void handlesJsonWithEncoding(DecoderType decoderType) { int code = 500; - RuntimeException result = decoder.decode( - TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json; charset=utf-8")); - assertThat(result).isInstanceOfSatisfying(RemoteException.class, exception -> { + Response response = + TestResponse.withBody(SERIALIZED_EXCEPTION).code(code).contentType("application/json; charset=utf-8"); + + Consumer validationFunction = exception -> { assertThat(exception.getCause()).isNull(); assertThat(exception.getStatus()).isEqualTo(code); assertThat(exception.getError().errorCode()) .isEqualTo(ErrorType.FAILED_PRECONDITION.code().name()); assertThat(exception.getError().errorName()).isEqualTo(ErrorType.FAILED_PRECONDITION.name()); - }); + }; + + if (decoderType == DecoderType.LEGACY) { + assertThat(decoder.decode(response)).isInstanceOfSatisfying(RemoteException.class, validationFunction); + } else { + assertThatExceptionOfType(RemoteException.class) + .isThrownBy(() -> endpointErrorDecoder.decode(response)) + .satisfies(validationFunction); + } } private static RemoteException encodeAndDecode(Exception exception) { diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java index 8801f0c44..6529e9662 100644 --- a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java +++ b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java @@ -28,6 +28,8 @@ public interface BodySerDe { /** Creates a {@link Deserializer} for the requested type. Deserializer instances should be reused. */ Deserializer deserializer(TypeMarker type); + Deserializer deserializer(DeserializerArgs deserializerArgs); + /** * Returns a {@link Deserializer} that fails if a non-empty reponse body is presented and returns null otherwise. */ @@ -41,9 +43,13 @@ public interface BodySerDe { */ Deserializer inputStreamDeserializer(); + Deserializer inputStreamDeserializer(DeserializerArgs deserializerArgs); + /** Same as {@link #inputStreamDeserializer()} with support for 204 responses. */ Deserializer> optionalInputStreamDeserializer(); + Deserializer optionalInputStreamDeserializer(DeserializerArgs deserializerArgs); + /** Serializes a {@link BinaryRequestBody} to
application/octet-stream
. */ RequestBody serialize(BinaryRequestBody value); } diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java new file mode 100644 index 000000000..66891ce1f --- /dev/null +++ b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java @@ -0,0 +1,98 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue; + +import com.google.common.collect.ImmutableMap; +import com.palantir.logsafe.Preconditions; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public final class DeserializerArgs { + private final TypeMarker baseType; + private final TypeMarker successType; + private final ImmutableMap> errorNameToTypeMarker; + + private DeserializerArgs( + TypeMarker baseType, + TypeMarker successType, + ImmutableMap> map) { + this.baseType = baseType; + this.successType = successType; + this.errorNameToTypeMarker = map; + } + + public static Builder builder() { + return new Builder<>(); + } + + public static final class Builder { + private boolean buildInvoked = false; + private @Nullable TypeMarker baseType; + private @Nullable TypeMarker successType; + private final Map> errorNameToTypeMarker; + + private Builder() { + this.errorNameToTypeMarker = new HashMap<>(); + } + + public Builder baseType(@Nonnull TypeMarker baseT) { + checkNotBuilt(); + this.baseType = Preconditions.checkNotNull(baseT, "base type must be non-null"); + return this; + } + + public Builder success(@Nonnull TypeMarker successT) { + checkNotBuilt(); + this.successType = Preconditions.checkNotNull(successT, "success type must be non-null"); + return this; + } + + public Builder error(@Nonnull String errorName, @Nonnull TypeMarker errorT) { + checkNotBuilt(); + this.errorNameToTypeMarker.put( + Preconditions.checkNotNull(errorName, "error name must be non-null"), + Preconditions.checkNotNull(errorT, "error type must be non-null")); + return this; + } + + public DeserializerArgs build() { + checkNotBuilt(); + Preconditions.checkNotNull(baseType, "base type must be set"); + Preconditions.checkNotNull(successType, "success type must be set"); + buildInvoked = true; + return new DeserializerArgs<>(baseType, successType, ImmutableMap.copyOf(errorNameToTypeMarker)); + } + + private void checkNotBuilt() { + Preconditions.checkState(!buildInvoked, "build has already been called"); + } + } + + public TypeMarker baseType() { + return baseType; + } + + public TypeMarker successType() { + return successType; + } + + public Map> errorNameToTypeMarker() { + return errorNameToTypeMarker; + } +} diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/EndpointError.java b/dialogue-target/src/main/java/com/palantir/dialogue/EndpointError.java new file mode 100644 index 000000000..d2d595a27 --- /dev/null +++ b/dialogue-target/src/main/java/com/palantir/dialogue/EndpointError.java @@ -0,0 +1,55 @@ +/* + * (c) Copyright 2025 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue; + +import com.palantir.logsafe.Safe; + +public abstract class EndpointError { + @Safe + private final String errorCode; + + @Safe + private final String errorName; + + @Safe + private final String errorInstanceId; + + private final T params; + + protected EndpointError(String errorCode, String errorName, String errorInstanceId, T params) { + this.errorCode = errorCode; + this.errorName = errorName; + this.errorInstanceId = errorInstanceId; + this.params = params; + } + + public final String getErrorCode() { + return errorCode; + } + + public final String getErrorName() { + return errorName; + } + + public final String getErrorInstanceId() { + return errorInstanceId; + } + + public final T getParams() { + return params; + } +}