diff --git a/.palantir/revapi.yml b/.palantir/revapi.yml index fe9ee372f..ff2dd1cc5 100644 --- a/.palantir/revapi.yml +++ b/.palantir/revapi.yml @@ -311,3 +311,9 @@ acceptedBreaks: new: "method com.palantir.dialogue.clients.DialogueClients.StickyChannelSession\ \ com.palantir.dialogue.clients.DialogueClients.StickyChannelFactory2::session()" justification: "interface for consumption, not extension" + "4.6.0": + com.palantir.dialogue:dialogue-target: + - code: "java.method.addedToInterface" + new: "method com.palantir.dialogue.Deserializer com.palantir.dialogue.BodySerDe::deserializer(com.palantir.dialogue.DeserializerArgs)" + justification: "Adding a new method to create deserializers in support of endpoint\ + \ associated error deserialization" diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java index b738b8529..13bcb68c3 100644 --- a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/ConjureBodySerDe.java @@ -26,6 +26,7 @@ import com.palantir.dialogue.BinaryRequestBody; import com.palantir.dialogue.BodySerDe; import com.palantir.dialogue.Deserializer; +import com.palantir.dialogue.DeserializerArgs; import com.palantir.dialogue.RequestBody; import com.palantir.dialogue.Response; import com.palantir.dialogue.Serializer; @@ -48,6 +49,12 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +/** + * items: + * - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class. + * - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info + */ + /** Package private internal API. */ final class ConjureBodySerDe implements BodySerDe { @@ -58,7 +65,8 @@ final class ConjureBodySerDe implements BodySerDe { private final Deserializer> optionalBinaryInputStreamDeserializer; private final Deserializer emptyBodyDeserializer; private final LoadingCache> serializers; - private final LoadingCache> deserializers; + private final LoadingCache> deserializers; + private final EmptyContainerDeserializer emptyContainerDeserializer; /** * Selects the first (based on input order) of the provided encodings that @@ -74,6 +82,7 @@ final class ConjureBodySerDe implements BodySerDe { this.encodingsSortedByWeight = sortByWeight(encodings); Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required"); this.defaultEncoding = encodings.get(0).encoding(); + this.emptyContainerDeserializer = emptyContainerDeserializer; this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>( ImmutableList.of(BinaryEncoding.INSTANCE), errorDecoder, @@ -122,6 +131,16 @@ public Deserializer deserializer(TypeMarker token) { return (Deserializer) deserializers.get(token.getType()); } + @Override + @SuppressWarnings("unchecked") + public Deserializer deserializer(DeserializerArgs deserializerArgs) { + return new EncodingDeserializerForEndpointRegistry<>( + encodingsSortedByWeight, + emptyContainerDeserializer, + (TypeMarker) deserializerArgs.baseType(), + deserializerArgs); + } + @Override public Deserializer emptyBodyDeserializer() { return emptyBodyDeserializer; @@ -301,6 +320,105 @@ Encoding.Deserializer getResponseDeserializer(String contentType) { return throwingDeserializer(contentType); } + private Encoding.Deserializer throwingDeserializer(String contentType) { + return input -> { + try { + input.close(); + } catch (RuntimeException | IOException e) { + log.warn("Failed to close InputStream", e); + } + throw new SafeRuntimeException( + "Unsupported Content-Type", + SafeArg.of("received", contentType), + SafeArg.of("supportedEncodings", encodings)); + }; + } + } + + private static final class EncodingDeserializerForEndpointRegistry implements Deserializer { + + private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class); + private final ImmutableList> encodings; + private final EndpointErrorDecoder endpointErrorDecoder; + private final Optional acceptValue; + private final Supplier> emptyInstance; + private final TypeMarker token; + + EncodingDeserializerForEndpointRegistry( + List encodings, + EmptyContainerDeserializer empty, + TypeMarker token, + DeserializerArgs deserializersForEndpoint) { + this.encodings = encodings.stream() + .map(encoding -> new EncodingDeserializerContainer<>( + encoding, deserializersForEndpoint.expectedResultType())) + .collect(ImmutableList.toImmutableList()); + this.endpointErrorDecoder = + new EndpointErrorDecoder<>(deserializersForEndpoint.errorNameToTypeMarker(), encodings); + this.token = token; + this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token)); + // Encodings are applied to the accept header in the order of preference based on the provided list. + this.acceptValue = + Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", "))); + } + + @Override + public T deserialize(Response response) { + boolean closeResponse = true; + try { + if (endpointErrorDecoder.isError(response)) { + return endpointErrorDecoder.decode(response); + } else if (response.code() == 204) { + Optional maybeEmptyInstance = emptyInstance.get(); + if (maybeEmptyInstance.isPresent()) { + return maybeEmptyInstance.get(); + } + throw new SafeRuntimeException( + "Unable to deserialize non-optional response type from 204", SafeArg.of("type", token)); + } + + Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); + if (!contentType.isPresent()) { + throw new SafeIllegalArgumentException( + "Response is missing Content-Type header", + SafeArg.of("received", response.headers().keySet())); + } + Encoding.Deserializer deserializer = getResponseDeserializer(contentType.get()); + T deserialized = deserializer.deserialize(response.body()); + // deserializer has taken on responsibility for closing the response body + closeResponse = false; + return deserialized; + } catch (IOException e) { + throw new SafeRuntimeException( + "Failed to deserialize response stream", + e, + SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)), + SafeArg.of("type", token)); + } finally { + if (closeResponse) { + response.close(); + } + } + } + + @Override + public Optional accepts() { + return acceptValue; + } + + /** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */ + @SuppressWarnings("ForLoopReplaceableByForEach") + // performance sensitive code avoids iterator allocation + Encoding.Deserializer getResponseDeserializer(String contentType) { + for (int i = 0; i < encodings.size(); i++) { + EncodingDeserializerContainer container = encodings.get(i); + if (container.encoding.supportsContentType(contentType)) { + return container.deserializer; + } + } + return throwingDeserializer(contentType); + } + private Encoding.Deserializer throwingDeserializer(String contentType) { return new Encoding.Deserializer() { @Override @@ -320,7 +438,8 @@ public T deserialize(InputStream input) { } /** Effectively just a pair. */ - private static final class EncodingDeserializerContainer { + // TODO(pm): what does saving the deserializer do for us? + static final class EncodingDeserializerContainer { private final Encoding encoding; private final Encoding.Deserializer deserializer; diff --git a/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java new file mode 100644 index 000000000..34b1261cf --- /dev/null +++ b/dialogue-serde/src/main/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorDecoder.java @@ -0,0 +1,229 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.io.CharStreams; +import com.google.common.net.HttpHeaders; +import com.google.common.primitives.Longs; +import com.palantir.conjure.java.api.errors.QosException; +import com.palantir.conjure.java.api.errors.QosReason; +import com.palantir.conjure.java.api.errors.QosReasons; +import com.palantir.conjure.java.api.errors.QosReasons.QosResponseDecodingAdapter; +import com.palantir.conjure.java.api.errors.RemoteException; +import com.palantir.conjure.java.api.errors.SerializableError; +import com.palantir.conjure.java.api.errors.UnknownRemoteException; +import com.palantir.conjure.java.serialization.ObjectMappers; +import com.palantir.dialogue.Response; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.SafeLoggable; +import com.palantir.logsafe.UnsafeArg; +import com.palantir.logsafe.exceptions.SafeExceptions; +import com.palantir.logsafe.logger.SafeLogger; +import com.palantir.logsafe.logger.SafeLoggerFactory; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +// TODO(pm): public because maybe we need to expose this in the dialogue annotations. What does that do? +// T is the base type of the endpoint response. It's a union of the result type and all of the error types. +public final class EndpointErrorDecoder { + private static final SafeLogger log = SafeLoggerFactory.get(EndpointErrorDecoder.class); + private static final ObjectMapper MAPPER = ObjectMappers.newClientObjectMapper(); + private final Map> errorNameToTypeMap; + private final List encodings; + + public EndpointErrorDecoder(Map> errorNameToTypeMap, List encodings) { + this.errorNameToTypeMap = errorNameToTypeMap; + this.encodings = encodings; + } + + public boolean isError(Response response) { + return 300 <= response.code() && response.code() <= 599; + } + + public T decode(Response response) { + if (log.isDebugEnabled()) { + log.debug("Received an error response", diagnosticArgs(response)); + } + try { + return decodeInternal(response); + } catch (Exception e) { + // TODO(pm): do we want to add the diagnostic information to the result type as well? + e.addSuppressed(diagnostic(response)); + throw e; + } + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private T decodeInternal(Response response) { + int code = response.code(); + switch (code) { + case 308: + Optional location = response.getFirstHeader(HttpHeaders.LOCATION); + if (location.isPresent()) { + String locationHeader = location.get(); + try { + UnknownRemoteException remoteException = new UnknownRemoteException(code, ""); + remoteException.initCause( + QosException.retryOther(qosReason(response), new URL(locationHeader))); + throw remoteException; + } catch (MalformedURLException e) { + log.error( + "Failed to parse location header for QosException.RetryOther", + UnsafeArg.of("locationHeader", locationHeader), + e); + } + } else { + log.error("Retrieved HTTP status code 308 without Location header, cannot perform " + + "redirect. This appears to be a server-side protocol violation."); + } + break; + case 429: + throw response.getFirstHeader(HttpHeaders.RETRY_AFTER) + .map(Longs::tryParse) + .map(Duration::ofSeconds) + .map(duration -> QosException.throttle(qosReason(response), duration)) + .orElseGet(() -> QosException.throttle(qosReason(response))); + case 503: + throw QosException.unavailable(qosReason(response)); + } + + String body; + try { + body = toString(response.body()); + } catch (NullPointerException | IOException e) { + UnknownRemoteException exception = new UnknownRemoteException(code, ""); + exception.initCause(e); + throw exception; + } + + Optional contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE); + // Use a factory: given contentType, create the deserailizer. + // We need Encoding.Deserializer here. That depends on the encoding. + if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) { + try { + // TODO(pm): figure out if we can avoid double parsing. + JsonNode node = MAPPER.readTree(body); + if (node.get("errorName") != null) { + // TODO(pm): Update this to use some struct instead of errorName. + TypeMarker container = Optional.ofNullable( + errorNameToTypeMap.get(node.get("errorName").asText())) + .orElseThrow(); + // make this a normal for + for (Encoding encoding : encodings) { + if (encoding.supportsContentType(contentType.get())) { + return encoding.deserializer(container) + .deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + } + } + } else { + SerializableError serializableError = MAPPER.readValue(body, SerializableError.class); + throw new RemoteException(serializableError, code); + } + } catch (Exception e) { + throw new UnknownRemoteException(code, body); + } + } + + throw new UnknownRemoteException(code, body); + } + + private static String toString(InputStream body) throws IOException { + try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) { + return CharStreams.toString(reader); + } + } + + private static ResponseDiagnostic diagnostic(Response response) { + return new ResponseDiagnostic(diagnosticArgs(response)); + } + + private static ImmutableList> diagnosticArgs(Response response) { + ImmutableList.Builder> args = ImmutableList.>builder().add(SafeArg.of("status", response.code())); + recordHeader(HttpHeaders.SERVER, response, args); + recordHeader(HttpHeaders.CONTENT_TYPE, response, args); + recordHeader(HttpHeaders.CONTENT_LENGTH, response, args); + recordHeader(HttpHeaders.CONNECTION, response, args); + recordHeader(HttpHeaders.DATE, response, args); + recordHeader("x-envoy-response-flags", response, args); + recordHeader("x-envoy-response-code-details", response, args); + recordHeader("Response-Flags", response, args); + recordHeader("Response-Code-Details", response, args); + return args.build(); + } + + private static void recordHeader(String header, Response response, ImmutableList.Builder> args) { + response.getFirstHeader(header).ifPresent(server -> args.add(SafeArg.of(header, server))); + } + + private static final class ResponseDiagnostic extends RuntimeException implements SafeLoggable { + + private static final String SAFE_MESSAGE = "Response Diagnostic Information"; + + private final ImmutableList> args; + + ResponseDiagnostic(ImmutableList> args) { + super(SafeExceptions.renderMessage(SAFE_MESSAGE, args.toArray(new Arg[0]))); + this.args = args; + } + + @Override + public String getLogMessage() { + return SAFE_MESSAGE; + } + + @Override + public List> getArgs() { + return args; + } + + @Override + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // nop + public Throwable fillInStackTrace() { + // no-op: stack trace generation is expensive, this type exists + // to simply associate diagnostic information with a failure. + return this; + } + } + + private static QosReason qosReason(Response response) { + return QosReasons.parseFromResponse(response, DialogueQosResponseDecodingAdapter.INSTANCE); + } + + private enum DialogueQosResponseDecodingAdapter implements QosResponseDecodingAdapter { + INSTANCE; + + @Override + public Optional getFirstHeader(Response response, String headerName) { + return response.getFirstHeader(headerName); + } + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java new file mode 100644 index 000000000..69bcb6949 --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorTestUtils.java @@ -0,0 +1,100 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Arg; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; + +final class EndpointErrorTestUtils { + private EndpointErrorTestUtils() {} + + record ConjureError( + @JsonProperty("errorCode") String errorCode, + @JsonProperty("errorName") String errorName, + @JsonProperty("errorInstanceId") String errorInstanceId, + @JsonProperty("parameters") Map parameters) { + static ConjureError fromCheckedServiceException(CheckedServiceException exception) { + Map parameters = new HashMap<>(); + for (Arg arg : exception.getArgs()) { + if (shouldIncludeArgInParameters(arg)) { + parameters.put(arg.getName(), arg.getValue()); + } + } + return new ConjureError( + exception.getErrorType().code().name(), + exception.getErrorType().name(), + exception.getErrorInstanceId(), + parameters); + } + + private static boolean shouldIncludeArgInParameters(Arg arg) { + Object obj = arg.getValue(); + return obj != null + && (!(obj instanceof Optional) || ((Optional) obj).isPresent()) + && (!(obj instanceof OptionalInt) || ((OptionalInt) obj).isPresent()) + && (!(obj instanceof OptionalLong) || ((OptionalLong) obj).isPresent()) + && (!(obj instanceof OptionalDouble) || ((OptionalDouble) obj).isPresent()); + } + } + + /** Deserializes requests as the type. */ + public static final class TypeReturningStubEncoding implements Encoding { + + private final String contentType; + + TypeReturningStubEncoding(String contentType) { + this.contentType = contentType; + } + + @Override + public Encoding.Serializer serializer(TypeMarker _type) { + return (_value, _output) -> { + // nop + }; + } + + @Override + public Encoding.Deserializer deserializer(TypeMarker type) { + return input -> { + return (T) Encodings.json().deserializer(type).deserialize(input); + }; + } + + @Override + public String getContentType() { + return contentType; + } + + @Override + public boolean supportsContentType(String input) { + return contentType.equals(input); + } + + @Override + public String toString() { + return "TypeReturningStubEncoding{" + contentType + '}'; + } + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java new file mode 100644 index 000000000..296bb193d --- /dev/null +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/EndpointErrorsConjureBodySerDeTest.java @@ -0,0 +1,182 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.conjure.java.dialogue.serde; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.palantir.conjure.java.api.errors.CheckedServiceException; +import com.palantir.conjure.java.api.errors.ErrorType; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.ConjureError; +import com.palantir.conjure.java.dialogue.serde.EndpointErrorTestUtils.TypeReturningStubEncoding; +import com.palantir.conjure.java.serialization.ObjectMappers; +import com.palantir.dialogue.BodySerDe; +import com.palantir.dialogue.DeserializerArgs; +import com.palantir.dialogue.TestResponse; +import com.palantir.dialogue.TypeMarker; +import com.palantir.logsafe.Preconditions; +import com.palantir.logsafe.Safe; +import com.palantir.logsafe.SafeArg; +import com.palantir.logsafe.Unsafe; +import com.palantir.logsafe.UnsafeArg; +import java.io.IOException; +import java.util.Arrays; +import java.util.Optional; +import javax.annotation.Nullable; +import javax.annotation.processing.Generated; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class EndpointErrorsConjureBodySerDeTest { + private static final ObjectMapper MAPPER = ObjectMappers.newServerObjectMapper(); + private ErrorDecoder errorDecoder = ErrorDecoder.INSTANCE; + + @Generated("by conjure-java") + private sealed interface EndpointReturnBaseType permits StringReturn, ErrorForEndpoint {} + + @Generated("by conjure-java") + record StringReturn(String value) implements EndpointReturnBaseType { + @JsonCreator + public static StringReturn create(String value) { + return new StringReturn(Preconditions.checkArgumentNotNull(value, "value cannot be null")); + } + } + + abstract static class EndpointError { + @Safe + String errorCode; + + @Safe + String errorName; + + @Safe + String errorInstanceId; + + T args; + + EndpointError(String errorCode, String errorName, String errorInstanceId, T args) { + this.errorCode = errorCode; + this.errorName = errorName; + this.errorInstanceId = errorInstanceId; + this.args = args; + } + } + + record ErrorForEndpointArgs( + @JsonProperty("arg") @Safe String arg, + @JsonProperty("unsafeArg") @Unsafe String unsafeArg, + @JsonProperty("complexArg") @Safe ComplexArg complexArg, + @JsonProperty("optionalArg") @Safe Optional optionalArg) {} + + static final class ErrorForEndpoint extends EndpointError implements EndpointReturnBaseType { + @JsonCreator + ErrorForEndpoint( + @JsonProperty("errorCode") String errorCode, + @JsonProperty("errorName") String errorName, + @JsonProperty("errorInstanceId") String errorInstanceId, + @JsonProperty("parameters") ErrorForEndpointArgs args) { + super(errorCode, errorName, errorInstanceId, args); + } + } + + @Generated("by conjure-java") + record ComplexArg(int foo, String bar) {} + + @Generated("by conjure-java") + public static final class TestEndpointError extends CheckedServiceException { + private TestEndpointError( + @Safe String arg, + @Unsafe String unsafeArg, + @Safe ComplexArg complexArg, + @Safe Optional optionalArg, + @Nullable Throwable cause) { + super( + ErrorType.FAILED_PRECONDITION, + cause, + SafeArg.of("arg", arg), + UnsafeArg.of("unsafeArg", unsafeArg), + SafeArg.of("complexArg", complexArg), + SafeArg.of("optionalArg", optionalArg)); + } + } + + @Test + public void testDeserializeCustomErrors() throws IOException { + TestEndpointError errorThrownByEndpoint = + new TestEndpointError("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2), null); + + ErrorForEndpoint expectedErrorForEndpoint = new ErrorForEndpoint( + "FAILED_PRECONDITION", + "Default:FailedPrecondition", + errorThrownByEndpoint.getErrorInstanceId(), + new ErrorForEndpointArgs("value", "unsafeValue", new ComplexArg(1, "bar"), Optional.of(2))); + + String responseBody = + MAPPER.writeValueAsString(ConjureError.fromCheckedServiceException(errorThrownByEndpoint)); + TestResponse response = TestResponse.withBody(responseBody) + .contentType("application/json") + .code(500); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .withBaseType(new TypeMarker<>() {}) + .withExpectedResult(new TypeMarker() {}) + .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value = + serializers.deserializer(deserializerArgs).deserialize(response); + + assertThat(value) + .extracting("errorCode", "errorName", "errorInstanceId", "args") + .containsExactly( + expectedErrorForEndpoint.errorCode, + expectedErrorForEndpoint.errorName, + expectedErrorForEndpoint.errorInstanceId, + expectedErrorForEndpoint.args); + } + + @Test + public void testDeserializeExpectedValue() { + String expectedString = "expectedString"; + TestResponse response = TestResponse.withBody(String.format("\"%s\"", expectedString)) + .contentType("application/json") + .code(200); + BodySerDe serializers = conjureBodySerDe("application/json", "text/plain"); + DeserializerArgs deserializerArgs = DeserializerArgs.builder() + .withBaseType(new TypeMarker<>() {}) + .withExpectedResult(new TypeMarker() {}) + .withErrorType("Default:FailedPrecondition", new TypeMarker() {}) + .build(); + EndpointReturnBaseType value = + serializers.deserializer(deserializerArgs).deserialize(response); + assertThat(value).isEqualTo(new StringReturn(expectedString)); + } + + private ConjureBodySerDe conjureBodySerDe(String... contentTypes) { + return new ConjureBodySerDe( + Arrays.stream(contentTypes) + .map(c -> WeightedEncoding.of(new TypeReturningStubEncoding(c))) + .collect(ImmutableList.toImmutableList()), + errorDecoder, + Encodings.emptyContainerDeserializer(), + DefaultConjureRuntime.DEFAULT_SERDE_CACHE_SPEC); + } +} diff --git a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java index 9f26a7aa6..510914481 100644 --- a/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java +++ b/dialogue-serde/src/test/java/com/palantir/conjure/java/dialogue/serde/ErrorDecoderTest.java @@ -54,7 +54,8 @@ public final class ErrorDecoderTest { private static String createServiceException(ServiceException exception) { try { - return SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + String ret = SERVER_MAPPER.writeValueAsString(SerializableError.forException(exception)); + return ret; } catch (JsonProcessingException e) { fail("failed to serialize"); return ""; diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java index 8801f0c44..f7c69d264 100644 --- a/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java +++ b/dialogue-target/src/main/java/com/palantir/dialogue/BodySerDe.java @@ -28,6 +28,8 @@ public interface BodySerDe { /** Creates a {@link Deserializer} for the requested type. Deserializer instances should be reused. */ Deserializer deserializer(TypeMarker type); + Deserializer deserializer(DeserializerArgs deserializerArgs); + /** * Returns a {@link Deserializer} that fails if a non-empty reponse body is presented and returns null otherwise. */ diff --git a/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java new file mode 100644 index 000000000..c2690ef79 --- /dev/null +++ b/dialogue-target/src/main/java/com/palantir/dialogue/DeserializerArgs.java @@ -0,0 +1,108 @@ +/* + * (c) Copyright 2024 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue; + +import com.google.common.collect.ImmutableMap; +import com.palantir.logsafe.Preconditions; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nonnull; + +public final class DeserializerArgs { + private final TypeMarker baseType; + private final TypeMarker expectedResultType; + private final ImmutableMap> errorNameToTypeMarker; + + private DeserializerArgs( + TypeMarker baseType, + TypeMarker expectedResultType, + ImmutableMap> map) { + this.baseType = baseType; + this.expectedResultType = expectedResultType; + this.errorNameToTypeMarker = map; + } + + public static Builder builder() { + return new Builder<>(); + } + + public static final class Builder { + private boolean buildInvoked = false; + private TypeMarker baseType; + private boolean baseTypeSet = false; + private TypeMarker expectedResultType; + private boolean expectedResultSet = false; + private final Map> errorNameToTypeMarker; + + @SuppressWarnings("NullAway") + // We ensure that the baseType and expectedResultType are set before building. + private Builder() { + this.errorNameToTypeMarker = new HashMap<>(); + } + + public Builder withBaseType(@Nonnull TypeMarker base) { + checkNotBuilt(); + this.baseType = Preconditions.checkNotNull(base, "base type must be non-null"); + this.baseTypeSet = true; + return this; + } + + public Builder withExpectedResult(TypeMarker expectedResultT) { + checkNotBuilt(); + this.expectedResultType = + Preconditions.checkNotNull(expectedResultT, "expected result type must be non-null"); + this.expectedResultSet = true; + return this; + } + + public Builder withErrorType(@Nonnull String errorName, @Nonnull TypeMarker errorType) { + checkNotBuilt(); + this.errorNameToTypeMarker.put( + Preconditions.checkNotNull(errorName, "error name must be non-null"), + Preconditions.checkNotNull(errorType, "error type must be non-null")); + return this; + } + + public DeserializerArgs build() { + checkNotBuilt(); + checkRequiredArgsSet(); + buildInvoked = true; + return new DeserializerArgs<>(baseType, expectedResultType, ImmutableMap.copyOf(errorNameToTypeMarker)); + } + + private void checkNotBuilt() { + Preconditions.checkState(!buildInvoked, "Build has already been called"); + } + + private void checkRequiredArgsSet() { + Preconditions.checkState(baseTypeSet, "base type must be set"); + Preconditions.checkState(expectedResultSet, "expected result type must be set"); + } + } + + public TypeMarker baseType() { + return baseType; + } + + public TypeMarker expectedResultType() { + return expectedResultType; + } + + public Map> errorNameToTypeMarker() { + return errorNameToTypeMarker; + } +}