From 26701397d84338a42c7acbce78368ae8f9d97271 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 15 Sep 2024 18:15:50 +0200 Subject: [PATCH] incremental publisher should handle all response building Replicates graphql/graphql-js@1f30b54edc3f7b8443f4aedc48fc56c0d2be9705 --- docs/conf.py | 9 +- src/graphql/execution/__init__.py | 10 +- src/graphql/execution/execute.py | 278 +-------------- .../execution/incremental_publisher.py | 331 +++++++++++++++--- 4 files changed, 301 insertions(+), 327 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 43766c1b..4655434b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -150,6 +150,7 @@ EnterLeaveVisitor ExperimentalIncrementalExecutionResults FieldGroup +FormattedIncrementalResult FormattedSourceLocation GraphQLAbstractType GraphQLCompositeType @@ -161,19 +162,19 @@ GraphQLTypeResolver GroupedFieldSet IncrementalDataRecord +IncrementalResult InitialResultRecord Middleware SubsequentDataRecord asyncio.events.AbstractEventLoop graphql.execution.collect_fields.FieldsAndPatches -graphql.execution.map_async_iterable.map_async_iterable -graphql.execution.Middleware -graphql.execution.execute.ExperimentalIncrementalExecutionResults graphql.execution.execute.StreamArguments +graphql.execution.map_async_iterable.map_async_iterable +graphql.execution.incremental_publisher.DeferredFragmentRecord graphql.execution.incremental_publisher.IncrementalPublisher graphql.execution.incremental_publisher.InitialResultRecord graphql.execution.incremental_publisher.StreamItemsRecord -graphql.execution.incremental_publisher.DeferredFragmentRecord +graphql.execution.Middleware graphql.language.lexer.EscapeSequence graphql.language.visitor.EnterLeaveVisitor graphql.type.definition.GT_co diff --git a/src/graphql/execution/__init__.py b/src/graphql/execution/__init__.py index aec85be1..2d5225be 100644 --- a/src/graphql/execution/__init__.py +++ b/src/graphql/execution/__init__.py @@ -14,21 +14,21 @@ default_type_resolver, subscribe, ExecutionContext, - ExecutionResult, - ExperimentalIncrementalExecutionResults, - InitialIncrementalExecutionResult, - FormattedExecutionResult, - FormattedInitialIncrementalExecutionResult, Middleware, ) from .incremental_publisher import ( + ExecutionResult, + ExperimentalIncrementalExecutionResults, FormattedSubsequentIncrementalExecutionResult, FormattedIncrementalDeferResult, FormattedIncrementalResult, FormattedIncrementalStreamResult, + FormattedExecutionResult, + FormattedInitialIncrementalExecutionResult, IncrementalDeferResult, IncrementalResult, IncrementalStreamResult, + InitialIncrementalExecutionResult, SubsequentIncrementalExecutionResult, ) from .async_iterables import map_async_iterable diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index d61909a9..ca4df8ff 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -13,20 +13,14 @@ Awaitable, Callable, Iterable, - Iterator, List, NamedTuple, Optional, - Sequence, Tuple, Union, cast, ) -try: - from typing import TypedDict -except ImportError: # Python < 3.8 - from typing_extensions import TypedDict try: from typing import TypeAlias, TypeGuard except ImportError: # Python < 3.10 @@ -37,7 +31,7 @@ except ImportError: # Python < 3.7 from concurrent.futures import TimeoutError -from ..error import GraphQLError, GraphQLFormattedError, located_error +from ..error import GraphQLError, located_error from ..language import ( DocumentNode, FragmentDefinitionNode, @@ -82,14 +76,13 @@ ) from .incremental_publisher import ( ASYNC_DELAY, - FormattedIncrementalResult, + ExecutionResult, + ExperimentalIncrementalExecutionResults, IncrementalDataRecord, IncrementalPublisher, - IncrementalResult, InitialResultRecord, StreamItemsRecord, SubsequentDataRecord, - SubsequentIncrementalExecutionResult, ) from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -112,12 +105,7 @@ async def anext(iterator: AsyncIterator) -> Any: "execute_sync", "experimental_execute_incrementally", "subscribe", - "ExecutionResult", "ExecutionContext", - "ExperimentalIncrementalExecutionResults", - "FormattedExecutionResult", - "FormattedInitialIncrementalExecutionResult", - "InitialIncrementalExecutionResult", "Middleware", ] @@ -144,181 +132,7 @@ async def anext(iterator: AsyncIterator) -> Any: # 3) inline fragment "spreads" e.g. "...on Type { a }" -class FormattedExecutionResult(TypedDict, total=False): - """Formatted execution result""" - - data: dict[str, Any] | None - errors: list[GraphQLFormattedError] - extensions: dict[str, Any] - - -class ExecutionResult: - """The result of GraphQL execution. - - - ``data`` is the result of a successful execution of the query. - - ``errors`` is included when any errors occurred as a non-empty list. - - ``extensions`` is reserved for adding non-standard properties. - """ - - __slots__ = "data", "errors", "extensions" - - data: dict[str, Any] | None - errors: list[GraphQLError] | None - extensions: dict[str, Any] | None - - def __init__( - self, - data: dict[str, Any] | None = None, - errors: list[GraphQLError] | None = None, - extensions: dict[str, Any] | None = None, - ) -> None: - self.data = data - self.errors = errors - self.extensions = extensions - - def __repr__(self) -> str: - name = self.__class__.__name__ - ext = "" if self.extensions is None else f", extensions={self.extensions}" - return f"{name}(data={self.data!r}, errors={self.errors!r}{ext})" - - def __iter__(self) -> Iterator[Any]: - return iter((self.data, self.errors)) - - @property - def formatted(self) -> FormattedExecutionResult: - """Get execution result formatted according to the specification.""" - formatted: FormattedExecutionResult = {"data": self.data} - if self.errors is not None: - formatted["errors"] = [error.formatted for error in self.errors] - if self.extensions is not None: - formatted["extensions"] = self.extensions - return formatted - - def __eq__(self, other: object) -> bool: - if isinstance(other, dict): - if "extensions" not in other: - return other == {"data": self.data, "errors": self.errors} - return other == { - "data": self.data, - "errors": self.errors, - "extensions": self.extensions, - } - if isinstance(other, tuple): - if len(other) == 2: - return other == (self.data, self.errors) - return other == (self.data, self.errors, self.extensions) - return ( - isinstance(other, self.__class__) - and other.data == self.data - and other.errors == self.errors - and other.extensions == self.extensions - ) - - def __ne__(self, other: object) -> bool: - return not self == other - - -class FormattedInitialIncrementalExecutionResult(TypedDict, total=False): - """Formatted initial incremental execution result""" - - data: dict[str, Any] | None - errors: list[GraphQLFormattedError] - hasNext: bool - incremental: list[FormattedIncrementalResult] - extensions: dict[str, Any] - - -class InitialIncrementalExecutionResult: - """Initial incremental execution result. - - - ``has_next`` is True if a future payload is expected. - - ``incremental`` is a list of the results from defer/stream directives. - """ - - data: dict[str, Any] | None - errors: list[GraphQLError] | None - incremental: Sequence[IncrementalResult] | None - has_next: bool - extensions: dict[str, Any] | None - - __slots__ = "data", "errors", "has_next", "incremental", "extensions" - - def __init__( - self, - data: dict[str, Any] | None = None, - errors: list[GraphQLError] | None = None, - incremental: Sequence[IncrementalResult] | None = None, - has_next: bool = False, - extensions: dict[str, Any] | None = None, - ) -> None: - self.data = data - self.errors = errors - self.incremental = incremental - self.has_next = has_next - self.extensions = extensions - - def __repr__(self) -> str: - name = self.__class__.__name__ - args: list[str] = [f"data={self.data!r}, errors={self.errors!r}"] - if self.incremental: - args.append(f"incremental[{len(self.incremental)}]") - if self.has_next: - args.append("has_next") - if self.extensions: - args.append(f"extensions={self.extensions}") - return f"{name}({', '.join(args)})" - - @property - def formatted(self) -> FormattedInitialIncrementalExecutionResult: - """Get execution result formatted according to the specification.""" - formatted: FormattedInitialIncrementalExecutionResult = {"data": self.data} - if self.errors is not None: - formatted["errors"] = [error.formatted for error in self.errors] - if self.incremental: - formatted["incremental"] = [result.formatted for result in self.incremental] - formatted["hasNext"] = self.has_next - if self.extensions is not None: - formatted["extensions"] = self.extensions - return formatted - - def __eq__(self, other: object) -> bool: - if isinstance(other, dict): - return ( - other.get("data") == self.data - and other.get("errors") == self.errors - and ( - "incremental" not in other - or other["incremental"] == self.incremental - ) - and ("hasNext" not in other or other["hasNext"] == self.has_next) - and ( - "extensions" not in other or other["extensions"] == self.extensions - ) - ) - if isinstance(other, tuple): - size = len(other) - return ( - 1 < size < 6 - and ( - self.data, - self.errors, - self.incremental, - self.has_next, - self.extensions, - )[:size] - == other - ) - return ( - isinstance(other, self.__class__) - and other.data == self.data - and other.errors == self.errors - and other.incremental == self.incremental - and other.has_next == self.has_next - and other.extensions == self.extensions - ) - - def __ne__(self, other: object) -> bool: - return not self == other +Middleware: TypeAlias = Optional[Union[Tuple, List, MiddlewareManager]] class StreamArguments(NamedTuple): @@ -328,16 +142,6 @@ class StreamArguments(NamedTuple): label: str | None -class ExperimentalIncrementalExecutionResults(NamedTuple): - """Execution results when retrieved incrementally.""" - - initial_result: InitialIncrementalExecutionResult - subsequent_results: AsyncGenerator[SubsequentIncrementalExecutionResult, None] - - -Middleware: TypeAlias = Optional[Union[Tuple, List, MiddlewareManager]] - - class ExecutionContext: """Data that must be available at all points during query execution. @@ -482,24 +286,6 @@ def build( is_awaitable, ) - @staticmethod - def build_response( - data: dict[str, Any] | None, errors: list[GraphQLError] - ) -> ExecutionResult: - """Build response. - - Given a completed execution context and data, build the (data, errors) response - defined by the "Response" section of the GraphQL spec. - """ - if not errors: - return ExecutionResult(data, None) - # Sort the error list in order to make it deterministic, since we might have - # been using parallel execution. - errors.sort( - key=lambda error: (error.locations or [], error.path or [], error.message) - ) - return ExecutionResult(data, errors) - def build_per_event_execution_context(self, payload: Any) -> ExecutionContext: """Create a copy of the execution context for usage with subscribe events.""" return self.__class__( @@ -1882,57 +1668,29 @@ def execute_impl( # in this case is the entire response. incremental_publisher = context.incremental_publisher initial_result_record = incremental_publisher.prepare_initial_result_record() - build_response = context.build_response try: - result = context.execute_operation(initial_result_record) + data = context.execute_operation(initial_result_record) + if context.is_awaitable(data): - if context.is_awaitable(result): - # noinspection PyShadowingNames - async def await_result() -> Any: + async def await_response() -> ( + ExecutionResult | ExperimentalIncrementalExecutionResults + ): try: - errors = incremental_publisher.get_initial_errors( - initial_result_record - ) - initial_result = build_response( - await result, # type: ignore - errors, + return incremental_publisher.build_data_response( + initial_result_record, + await data, # type: ignore ) - incremental_publisher.publish_initial(initial_result_record) - if incremental_publisher.has_next(): - return ExperimentalIncrementalExecutionResults( - initial_result=InitialIncrementalExecutionResult( - initial_result.data, - initial_result.errors, - has_next=True, - ), - subsequent_results=incremental_publisher.subscribe(), - ) except GraphQLError as error: - incremental_publisher.add_field_error(initial_result_record, error) - errors = incremental_publisher.get_initial_errors( - initial_result_record + return incremental_publisher.build_error_response( + initial_result_record, error ) - return build_response(None, errors) - return initial_result - return await_result() + return await_response() + + return incremental_publisher.build_data_response(initial_result_record, data) # type: ignore - initial_result = build_response(result, initial_result_record.errors) # type: ignore - incremental_publisher.publish_initial(initial_result_record) - if incremental_publisher.has_next(): - return ExperimentalIncrementalExecutionResults( - initial_result=InitialIncrementalExecutionResult( - initial_result.data, - initial_result.errors, - has_next=True, - ), - subsequent_results=incremental_publisher.subscribe(), - ) except GraphQLError as error: - incremental_publisher.add_field_error(initial_result_record, error) - errors = incremental_publisher.get_initial_errors(initial_result_record) - return build_response(None, errors) - return initial_result + return incremental_publisher.build_error_response(initial_result_record, error) def assume_not_awaitable(_value: Any) -> bool: diff --git a/src/graphql/execution/incremental_publisher.py b/src/graphql/execution/incremental_publisher.py index bf145da3..fdc35fff 100644 --- a/src/graphql/execution/incremental_publisher.py +++ b/src/graphql/execution/incremental_publisher.py @@ -11,6 +11,7 @@ AsyncIterator, Awaitable, Collection, + Iterator, NamedTuple, Sequence, Union, @@ -21,7 +22,6 @@ except ImportError: # Python < 3.8 from typing_extensions import TypedDict - if TYPE_CHECKING: from ..error import GraphQLError, GraphQLFormattedError from ..pyutils import Path @@ -29,10 +29,15 @@ __all__ = [ "ASYNC_DELAY", "DeferredFragmentRecord", + "ExecutionResult", + "ExperimentalIncrementalExecutionResults", + "FormattedExecutionResult", "FormattedIncrementalDeferResult", "FormattedIncrementalResult", "FormattedIncrementalStreamResult", + "FormattedInitialIncrementalExecutionResult", "FormattedSubsequentIncrementalExecutionResult", + "InitialIncrementalExecutionResult", "InitialResultRecord", "IncrementalDataRecord", "IncrementalDeferResult", @@ -49,6 +54,190 @@ suppress_key_error = suppress(KeyError) +class FormattedExecutionResult(TypedDict, total=False): + """Formatted execution result""" + + data: dict[str, Any] | None + errors: list[GraphQLFormattedError] + extensions: dict[str, Any] + + +class ExecutionResult: + """The result of GraphQL execution. + + - ``data`` is the result of a successful execution of the query. + - ``errors`` is included when any errors occurred as a non-empty list. + - ``extensions`` is reserved for adding non-standard properties. + """ + + __slots__ = "data", "errors", "extensions" + + data: dict[str, Any] | None + errors: list[GraphQLError] | None + extensions: dict[str, Any] | None + + def __init__( + self, + data: dict[str, Any] | None = None, + errors: list[GraphQLError] | None = None, + extensions: dict[str, Any] | None = None, + ) -> None: + self.data = data + self.errors = errors + self.extensions = extensions + + def __repr__(self) -> str: + name = self.__class__.__name__ + ext = "" if self.extensions is None else f", extensions={self.extensions}" + return f"{name}(data={self.data!r}, errors={self.errors!r}{ext})" + + def __iter__(self) -> Iterator[Any]: + return iter((self.data, self.errors)) + + @property + def formatted(self) -> FormattedExecutionResult: + """Get execution result formatted according to the specification.""" + formatted: FormattedExecutionResult = {"data": self.data} + if self.errors is not None: + formatted["errors"] = [error.formatted for error in self.errors] + if self.extensions is not None: + formatted["extensions"] = self.extensions + return formatted + + def __eq__(self, other: object) -> bool: + if isinstance(other, dict): + if "extensions" not in other: + return other == {"data": self.data, "errors": self.errors} + return other == { + "data": self.data, + "errors": self.errors, + "extensions": self.extensions, + } + if isinstance(other, tuple): + if len(other) == 2: + return other == (self.data, self.errors) + return other == (self.data, self.errors, self.extensions) + return ( + isinstance(other, self.__class__) + and other.data == self.data + and other.errors == self.errors + and other.extensions == self.extensions + ) + + def __ne__(self, other: object) -> bool: + return not self == other + + +class FormattedInitialIncrementalExecutionResult(TypedDict, total=False): + """Formatted initial incremental execution result""" + + data: dict[str, Any] | None + errors: list[GraphQLFormattedError] + hasNext: bool + incremental: list[FormattedIncrementalResult] + extensions: dict[str, Any] + + +class InitialIncrementalExecutionResult: + """Initial incremental execution result. + + - ``has_next`` is True if a future payload is expected. + - ``incremental`` is a list of the results from defer/stream directives. + """ + + data: dict[str, Any] | None + errors: list[GraphQLError] | None + incremental: Sequence[IncrementalResult] | None + has_next: bool + extensions: dict[str, Any] | None + + __slots__ = "data", "errors", "has_next", "incremental", "extensions" + + def __init__( + self, + data: dict[str, Any] | None = None, + errors: list[GraphQLError] | None = None, + incremental: Sequence[IncrementalResult] | None = None, + has_next: bool = False, + extensions: dict[str, Any] | None = None, + ) -> None: + self.data = data + self.errors = errors + self.incremental = incremental + self.has_next = has_next + self.extensions = extensions + + def __repr__(self) -> str: + name = self.__class__.__name__ + args: list[str] = [f"data={self.data!r}, errors={self.errors!r}"] + if self.incremental: + args.append(f"incremental[{len(self.incremental)}]") + if self.has_next: + args.append("has_next") + if self.extensions: + args.append(f"extensions={self.extensions}") + return f"{name}({', '.join(args)})" + + @property + def formatted(self) -> FormattedInitialIncrementalExecutionResult: + """Get execution result formatted according to the specification.""" + formatted: FormattedInitialIncrementalExecutionResult = {"data": self.data} + if self.errors is not None: + formatted["errors"] = [error.formatted for error in self.errors] + if self.incremental: + formatted["incremental"] = [result.formatted for result in self.incremental] + formatted["hasNext"] = self.has_next + if self.extensions is not None: + formatted["extensions"] = self.extensions + return formatted + + def __eq__(self, other: object) -> bool: + if isinstance(other, dict): + return ( + other.get("data") == self.data + and other.get("errors") == self.errors + and ( + "incremental" not in other + or other["incremental"] == self.incremental + ) + and ("hasNext" not in other or other["hasNext"] == self.has_next) + and ( + "extensions" not in other or other["extensions"] == self.extensions + ) + ) + if isinstance(other, tuple): + size = len(other) + return ( + 1 < size < 6 + and ( + self.data, + self.errors, + self.incremental, + self.has_next, + self.extensions, + )[:size] + == other + ) + return ( + isinstance(other, self.__class__) + and other.data == self.data + and other.errors == self.errors + and other.incremental == self.incremental + and other.has_next == self.has_next + and other.extensions == self.extensions + ) + + def __ne__(self, other: object) -> bool: + return not self == other + + +class ExperimentalIncrementalExecutionResults(NamedTuple): + """Execution results when retrieved incrementally.""" + + initial_result: InitialIncrementalExecutionResult + subsequent_results: AsyncGenerator[SubsequentIncrementalExecutionResult, None] + + class FormattedIncrementalDeferResult(TypedDict, total=False): """Formatted incremental deferred execution result""" @@ -363,53 +552,6 @@ def __init__(self) -> None: self._resolve = None # lazy initialization self._tasks: set[Awaitable] = set() - def has_next(self) -> bool: - """Check whether there is a next incremental result.""" - return bool(self._pending) - - async def subscribe( - self, - ) -> AsyncGenerator[SubsequentIncrementalExecutionResult, None]: - """Subscribe to the incremental results.""" - is_done = False - pending = self._pending - - try: - while not is_done: - released = self._released - for item in released: - with suppress_key_error: - del pending[item] - self._released = {} - - result = self._get_incremental_result(released) - - if not self.has_next(): - is_done = True - - if result is not None: - yield result - else: - resolve = self._resolve - if resolve is None: - self._resolve = resolve = Event() - await resolve.wait() - finally: - close_async_iterators = [] - for incremental_data_record in pending: - if isinstance( - incremental_data_record, StreamItemsRecord - ): # pragma: no cover - async_iterator = incremental_data_record.async_iterator - if async_iterator: - try: - close_async_iterator = async_iterator.aclose() # type: ignore - except AttributeError: - pass - else: - close_async_iterators.append(close_async_iterator) - await gather(*close_async_iterators) - def prepare_initial_result_record(self) -> InitialResultRecord: """Prepare a new initial result record.""" return InitialResultRecord(errors=[], children={}) @@ -471,18 +613,47 @@ def add_field_error( """Add a field error to the given incremental data record.""" incremental_data_record.errors.append(error) - def publish_initial(self, initial_result: InitialResultRecord) -> None: - """Publish the initial result.""" - for child in initial_result.children: + def build_data_response( + self, initial_result_record: InitialResultRecord, data: dict[str, Any] | None + ) -> ExecutionResult | ExperimentalIncrementalExecutionResults: + """Build response for the given data.""" + for child in initial_result_record.children: if child.filtered: continue self._publish(child) - def get_initial_errors( - self, initial_result: InitialResultRecord - ) -> list[GraphQLError]: - """Get the errors from the given initial result.""" - return initial_result.errors + errors = initial_result_record.errors or None + if errors: + errors.sort( + key=lambda error: ( + error.locations or [], + error.path or [], + error.message, + ) + ) + if self._pending: + return ExperimentalIncrementalExecutionResults( + initial_result=InitialIncrementalExecutionResult( + data, + errors, + has_next=True, + ), + subsequent_results=self._subscribe(), + ) + return ExecutionResult(data, errors) + + def build_error_response( + self, initial_result_record: InitialResultRecord, error: GraphQLError + ) -> ExecutionResult: + """Build response for the given error.""" + errors = initial_result_record.errors + errors.append(error) + # Sort the error list in order to make it deterministic, since we might have + # been using parallel execution. + errors.sort( + key=lambda error: (error.locations or [], error.path or [], error.message) + ) + return ExecutionResult(None, errors) def filter( self, @@ -510,6 +681,49 @@ def filter( else: self._add_task(close_async_iterator) + async def _subscribe( + self, + ) -> AsyncGenerator[SubsequentIncrementalExecutionResult, None]: + """Subscribe to the incremental results.""" + is_done = False + pending = self._pending + + try: + while not is_done: + released = self._released + for item in released: + with suppress_key_error: + del pending[item] + self._released = {} + + result = self._get_incremental_result(released) + + if not self._pending: + is_done = True + + if result is not None: + yield result + else: + resolve = self._resolve + if resolve is None: + self._resolve = resolve = Event() + await resolve.wait() + finally: + close_async_iterators = [] + for incremental_data_record in pending: + if isinstance( + incremental_data_record, StreamItemsRecord + ): # pragma: no cover + async_iterator = incremental_data_record.async_iterator + if async_iterator: + try: + close_async_iterator = async_iterator.aclose() # type: ignore + except AttributeError: + pass + else: + close_async_iterators.append(close_async_iterator) + await gather(*close_async_iterators) + def _trigger(self) -> None: """Trigger the resolve event.""" resolve = self._resolve @@ -572,11 +786,12 @@ def _get_incremental_result( ) append_result(incremental_result) + has_next = bool(self._pending) if incremental_results: return SubsequentIncrementalExecutionResult( - incremental=incremental_results, has_next=self.has_next() + incremental=incremental_results, has_next=has_next ) - if encountered_completed_async_iterator and not self.has_next(): + if encountered_completed_async_iterator and not has_next: return SubsequentIncrementalExecutionResult(has_next=False) return None