From 63c487722c83dcfac8e761e7057c81f3dd0cb4bc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 17 Dec 2024 20:36:14 +0100 Subject: [PATCH] Unify action titles --- .github/workflows/benchmarks.yml | 14 +- .github/workflows/lint.yml | 1 + .github/workflows/publish.yml | 1 + .github/workflows/test.yml | 1 + src/graphql/execution/collect_fields.py | 352 +++++++--- src/graphql/execution/execute.py | 578 +++++++++++----- .../execution/incremental_publisher.py | 624 ++++++++++++------ tests/benchmarks/test_visit.py | 2 +- 8 files changed, 1086 insertions(+), 487 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 35867e43..21f4750a 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -11,18 +11,22 @@ jobs: benchmarks: name: ๐Ÿ“ˆ Benchmarks runs-on: ubuntu-latest + steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 id: setup-python with: python-version: "3.12" architecture: x64 - - run: pipx install poetry - - - run: poetry env use 3.12 - - run: poetry install --with test + - name: Install with poetry + run: | + pipx install poetry + poetry env use 3.12 + poetry install --with test - name: Run benchmarks uses: CodSpeedHQ/action@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 74f14604..703a56aa 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,6 +4,7 @@ on: [push, pull_request] jobs: lint: + name: ๐Ÿงน Lint runs-on: ubuntu-latest steps: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 561b3028..8bd8c296 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -7,6 +7,7 @@ on: jobs: build: + name: ๐Ÿ—๏ธ Build runs-on: ubuntu-latest steps: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8959d0de..298d3dd0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,6 +4,7 @@ on: [push, pull_request] jobs: tests: + name: ๐Ÿงช Tests runs-on: ubuntu-latest strategy: diff --git a/src/graphql/execution/collect_fields.py b/src/graphql/execution/collect_fields.py index 5cb5a723..d94f8079 100644 --- a/src/graphql/execution/collect_fields.py +++ b/src/graphql/execution/collect_fields.py @@ -3,8 +3,7 @@ from __future__ import annotations import sys -from collections import defaultdict -from typing import Any, Dict, List, NamedTuple +from typing import Any, Dict, NamedTuple, cast from ..language import ( FieldNode, @@ -35,31 +34,90 @@ __all__ = [ "collect_fields", "collect_subfields", + "CollectFieldsResult", + "CollectFieldsContext", + "DeferUsage", + "DeferUsageSet", + "FieldDetails", "FieldGroup", - "FieldsAndPatches", - "GroupedFieldSet", + "GroupedFieldSetDetails", + "Target", + "TargetSet", + "NON_DEFERRED_TARGET_SET", ] + +class DeferUsage(NamedTuple): + """An optionally labelled list of ancestor targets.""" + + label: str | None + ancestors: list[Target] + + +Target: TypeAlias = DeferUsage | None + if sys.version_info < (3, 9): - FieldGroup: TypeAlias = List[FieldNode] - GroupedFieldSet = Dict[str, FieldGroup] + TargetSet: TypeAlias = Dict[Target, None] + DeferUsageSet: TypeAlias = Dict[DeferUsage, None] else: # Python >= 3.9 - FieldGroup: TypeAlias = list[FieldNode] - GroupedFieldSet = dict[str, FieldGroup] + TargetSet: TypeAlias = dict[Target, None] + DeferUsageSet: TypeAlias = dict[DeferUsage, None] -class PatchFields(NamedTuple): - """Optionally labelled set of fields to be used as a patch.""" +NON_DEFERRED_TARGET_SET: TargetSet = {} + + +class FieldDetails(NamedTuple): + """A field node and its target.""" + + node: FieldNode + target: Target + + +class FieldGroup(NamedTuple): + """A group of fields that share the same target set.""" + + fields: list[FieldDetails] + targets: TargetSet + + def to_nodes(self) -> list[FieldNode]: + """Return the field nodes in this group.""" + return [field_details.node for field_details in self.fields] + + +if sys.version_info < (3, 9): + GroupedFieldSet: TypeAlias = Dict[str, FieldGroup] +else: # Python >= 3.9 + GroupedFieldSet: TypeAlias = dict[str, FieldGroup] + + +class GroupedFieldSetDetails(NamedTuple): + """A grouped field set with defer info.""" - label: str | None grouped_field_set: GroupedFieldSet + should_initiate_defer: bool -class FieldsAndPatches(NamedTuple): - """Tuple of collected fields and patches to be applied.""" +class CollectFieldsResult(NamedTuple): + """Collected fields and deferred usages.""" grouped_field_set: GroupedFieldSet - patches: list[PatchFields] + new_grouped_field_set_details: dict[DeferUsageSet, GroupedFieldSetDetails] + new_defer_usages: list[DeferUsage] + + +class CollectFieldsContext(NamedTuple): + """Context for collecting fields.""" + + schema: GraphQLSchema + fragments: dict[str, FragmentDefinitionNode] + variable_values: dict[str, Any] + operation: OperationDefinitionNode + runtime_type: GraphQLObjectType + targets_by_key: dict[str, TargetSet] + fields_by_target: dict[Target, dict[str, list[FieldNode]]] + new_defer_usages: list[DeferUsage] + visited_fragment_names: set[str] def collect_fields( @@ -68,7 +126,7 @@ def collect_fields( variable_values: dict[str, Any], runtime_type: GraphQLObjectType, operation: OperationDefinitionNode, -) -> FieldsAndPatches: +) -> CollectFieldsResult: """Collect fields. Given a selection_set, collects all the fields and returns them. @@ -79,20 +137,15 @@ def collect_fields( For internal use only. """ - grouped_field_set: dict[str, list[FieldNode]] = defaultdict(list) - patches: list[PatchFields] = [] - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - runtime_type, - operation.selection_set, - grouped_field_set, - patches, - set(), + context = CollectFieldsContext( + schema, fragments, variable_values, operation, runtime_type, {}, {}, [], set() + ) + collect_fields_impl(context, operation.selection_set) + + return CollectFieldsResult( + *build_grouped_field_sets(context.targets_by_key, context.fields_by_target), + context.new_defer_usages, ) - return FieldsAndPatches(grouped_field_set, patches) def collect_subfields( @@ -102,7 +155,7 @@ def collect_subfields( operation: OperationDefinitionNode, return_type: GraphQLObjectType, field_group: FieldGroup, -) -> FieldsAndPatches: +) -> CollectFieldsResult: """Collect subfields. Given a list of field nodes, collects all the subfields of the passed in fields, @@ -114,47 +167,58 @@ def collect_subfields( For internal use only. """ - sub_grouped_field_set: dict[str, list[FieldNode]] = defaultdict(list) - visited_fragment_names: set[str] = set() - - sub_patches: list[PatchFields] = [] - sub_fields_and_patches = FieldsAndPatches(sub_grouped_field_set, sub_patches) + context = CollectFieldsContext( + schema, fragments, variable_values, operation, return_type, {}, {}, [], set() + ) - for node in field_group: + for field_details in field_group.fields: + node = field_details.node if node.selection_set: - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - return_type, - node.selection_set, - sub_grouped_field_set, - sub_patches, - visited_fragment_names, - ) - return sub_fields_and_patches + collect_fields_impl(context, node.selection_set, field_details.target) + + return CollectFieldsResult( + *build_grouped_field_sets(context.targets_by_key, context.fields_by_target), + context.new_defer_usages, + ) def collect_fields_impl( - schema: GraphQLSchema, - fragments: dict[str, FragmentDefinitionNode], - variable_values: dict[str, Any], - operation: OperationDefinitionNode, - runtime_type: GraphQLObjectType, + context: CollectFieldsContext, selection_set: SelectionSetNode, - grouped_field_set: dict[str, list[FieldNode]], - patches: list[PatchFields], - visited_fragment_names: set[str], + parent_target: Target | None = None, + new_target: Target | None = None, ) -> None: """Collect fields (internal implementation).""" - patch_fields: dict[str, list[FieldNode]] + ( + schema, + fragments, + variable_values, + operation, + runtime_type, + targets_by_key, + fields_by_target, + new_defer_usages, + visited_fragment_names, + ) = context + + ancestors: list[Target] for selection in selection_set.selections: if isinstance(selection, FieldNode): if not should_include_node(variable_values, selection): continue - grouped_field_set[get_field_entry_key(selection)].append(selection) + key = get_field_entry_key(selection) + target = new_target or parent_target + key_targets = targets_by_key.get(key) + if key_targets is None: + targets_by_key[key] = {target: None} + else: + key_targets[target] = None + target_fields = fields_by_target.get(target) + if target_fields is None: + fields_by_target[target] = {key: [selection]} + else: + target_fields[key].append(selection) elif isinstance(selection, InlineFragmentNode): if not should_include_node( variable_values, selection @@ -162,32 +226,19 @@ def collect_fields_impl( continue defer = get_defer_values(operation, variable_values, selection) + if defer: - patch_fields = defaultdict(list) - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - runtime_type, - selection.selection_set, - patch_fields, - patches, - visited_fragment_names, + ancestors = ( + [None] + if parent_target is None + else [parent_target, *parent_target.ancestors] ) - patches.append(PatchFields(defer.label, patch_fields)) + target = DeferUsage(defer.label, ancestors) + new_defer_usages.append(target) else: - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - runtime_type, - selection.selection_set, - grouped_field_set, - patches, - visited_fragment_names, - ) + target = new_target + + collect_fields_impl(context, selection.selection_set, parent_target, target) elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else frag_name = selection.name.value @@ -204,35 +255,19 @@ def collect_fields_impl( ): continue - if not defer: - visited_fragment_names.add(frag_name) - if defer: - patch_fields = defaultdict(list) - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - runtime_type, - fragment.selection_set, - patch_fields, - patches, - visited_fragment_names, + ancestors = ( + [None] + if parent_target is None + else [parent_target, *parent_target.ancestors] ) - patches.append(PatchFields(defer.label, patch_fields)) + target = DeferUsage(defer.label, ancestors) + new_defer_usages.append(target) else: - collect_fields_impl( - schema, - fragments, - variable_values, - operation, - runtime_type, - fragment.selection_set, - grouped_field_set, - patches, - visited_fragment_names, - ) + visited_fragment_names.add(frag_name) + target = new_target + + collect_fields_impl(context, fragment.selection_set, parent_target, target) class DeferValues(NamedTuple): @@ -305,3 +340,108 @@ def does_fragment_condition_match( def get_field_entry_key(node: FieldNode) -> str: """Implement the logic to compute the key of a given field's entry""" return node.alias.value if node.alias else node.name.value + + +def build_grouped_field_sets( + targets_by_key: dict[str, TargetSet], + fields_by_target: dict[Target, dict[str, list[FieldNode]]], + parent_targets: TargetSet = NON_DEFERRED_TARGET_SET, +) -> tuple[GroupedFieldSet, dict[DeferUsageSet, GroupedFieldSetDetails]]: + """... TODO ...""" + parent_target_keys, target_set_details_map = get_target_set_details( + targets_by_key, parent_targets + ) + + grouped_field_set = ( + get_ordered_grouped_field_set( + parent_target_keys, parent_targets, targets_by_key, fields_by_target + ) + if parent_target_keys + else {} + ) + + new_grouped_field_set_details: dict[DeferUsageSet, GroupedFieldSetDetails] = {} + + for masking_targets, target_set_details in target_set_details_map.items(): + keys, should_initiate_defer = target_set_details + + new_grouped_field_set = get_ordered_grouped_field_set( + keys, masking_targets, targets_by_key, fields_by_target + ) + + # All TargetSets that causes new grouped field sets consist only of DeferUsages + # and have should_initiate_defer defined + + new_grouped_field_set_details[cast(DeferUsageSet, masking_targets)] = ( + GroupedFieldSetDetails(new_grouped_field_set, should_initiate_defer) + ) + + return grouped_field_set, new_grouped_field_set_details + + +class TargetSetDetails(NamedTuple): + """A set of target keys with defer info.""" + + keys: set[str] + should_initiate_defer: bool + + +def get_target_set_details( + targets_by_key: dict[str, TargetSet], parent_targets: TargetSet +) -> tuple[set[str], dict[TargetSet, TargetSetDetails]]: + """... TODO ...""" + parent_target_keys: set[str] = set() + target_set_details_map: dict[TargetSet, TargetSetDetails] = {} + + for response_key, targets in targets_by_key.items(): + masking_target_list: list[Target] = [] + for target in targets: + if not target or all( + ancestor not in targets for ancestor in target.ancestors + ): + masking_target_list.append(target) + + masking_targets: TargetSet = dict.fromkeys(masking_target_list) + if masking_targets == parent_targets: + parent_target_keys.add(response_key) + continue + + target_set_details = target_set_details_map.get(masking_targets) + if target_set_details is None: + target_set_details = TargetSetDetails( + {response_key}, + any( + defer_usage not in parent_targets for defer_usage in masking_targets + ), + ) + target_set_details_map[masking_targets] = target_set_details + else: + target_set_details.keys.add(response_key) + + return parent_target_keys, target_set_details_map + + +def get_ordered_grouped_field_set( + keys: set[str], + masking_targets: TargetSet, + targets_by_key: dict[str, TargetSet], + fields_by_target: dict[Target, dict[str, list[FieldNode]]], +) -> GroupedFieldSet: + """... TODO ...""" + grouped_field_set: GroupedFieldSet = {} + + first_target = next(iter(masking_targets)) + first_fields = fields_by_target[first_target] + for key in first_fields: + if key in keys: + field_group = grouped_field_set.get(key) + if field_group is None: + field_group = FieldGroup([], masking_targets) + grouped_field_set[key] = field_group + for target in targets_by_key[key]: + fields_for_target = fields_by_target[target] + nodes = fields_for_target[key] + del fields_for_target[key] + field_group.fields.extend(FieldDetails(node, target) for node in nodes) + + return grouped_field_set diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index ca4df8ff..d5e7ed62 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -3,7 +3,6 @@ from __future__ import annotations from asyncio import ensure_future, gather, shield, wait_for -from collections.abc import Mapping from contextlib import suppress from typing import ( Any, @@ -14,8 +13,10 @@ Callable, Iterable, List, + Mapping, NamedTuple, Optional, + Sequence, Tuple, Union, cast, @@ -68,21 +69,27 @@ ) from .async_iterables import map_async_iterable from .collect_fields import ( + NON_DEFERRED_TARGET_SET, + DeferUsage, + DeferUsageSet, + FieldDetails, FieldGroup, - FieldsAndPatches, GroupedFieldSet, + GroupedFieldSetDetails, collect_fields, collect_subfields, ) from .incremental_publisher import ( ASYNC_DELAY, + DeferredFragmentRecord, + DeferredGroupedFieldSetRecord, ExecutionResult, ExperimentalIncrementalExecutionResults, IncrementalDataRecord, IncrementalPublisher, InitialResultRecord, StreamItemsRecord, - SubsequentDataRecord, + StreamRecord, ) from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -135,11 +142,12 @@ async def anext(iterator: AsyncIterator) -> Any: Middleware: TypeAlias = Optional[Union[Tuple, List, MiddlewareManager]] -class StreamArguments(NamedTuple): - """Arguments of the stream directive""" +class StreamUsage(NamedTuple): + """Stream directive usage information""" - initial_count: int label: str | None + initial_count: int + field_group: FieldGroup class ExecutionContext: @@ -196,6 +204,7 @@ def __init__( self._canceled_iterators: set[AsyncIterator] = set() self._subfields_cache: dict[tuple, FieldsAndPatches] = {} self._tasks: set[Awaitable] = set() + self._stream_usages: dict[FieldGroup, StreamUsage] = {} @classmethod def build( @@ -310,8 +319,8 @@ def execute_operation( Implements the "Executing operations" section of the spec. """ - schema = self.schema operation = self.operation + schema = self.schema root_type = schema.get_root_type(operation.operation) if root_type is None: msg = ( @@ -320,12 +329,26 @@ def execute_operation( ) raise GraphQLError(msg, operation) - grouped_field_set, patches = collect_fields( - schema, - self.fragments, - self.variable_values, - root_type, - operation, + grouped_field_set, new_grouped_field_set_details, new_defer_usages = ( + collect_fields( + schema, self.fragments, self.variable_values, root_type, operation + ) + ) + + incremental_publisher = self.incremental_publisher + new_defer_map = add_new_deferred_fragments( # TODO + incremental_publisher, new_defer_usages, initial_result_record + ) + + path: Path | None = None + + new_deferred_grouped_field_set_records = ( + add_new_deferred_grouped_field_sets( # TODO + incremental_publisher, + new_grouped_field_set_details, + new_defer_map, + path, + ) ) root_value = self.root_value @@ -334,18 +357,22 @@ def execute_operation( self.execute_fields_serially if operation.operation == OperationType.MUTATION else self.execute_fields - )(root_type, root_value, None, grouped_field_set, initial_result_record) - - for patch in patches: - label, patch_grouped_filed_set = patch - self.execute_deferred_fragment( - root_type, - root_value, - patch_grouped_filed_set, - initial_result_record, - label, - None, - ) + )( + root_type, + root_value, + path, + grouped_field_set, + initial_result_record, + new_defer_map, + ) + + self.execute_deferred_grouped_field_sets( + root_type, + root_value, + path, + new_deferred_grouped_field_set_records, + new_defer_map, + ) return result @@ -356,6 +383,7 @@ def execute_fields_serially( path: Path | None, grouped_field_set: GroupedFieldSet, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[dict[str, Any]]: """Execute the given fields serially. @@ -375,6 +403,7 @@ def reducer( field_group, field_path, incremental_data_record, + defer_map, ) if result is Undefined: return results @@ -401,6 +430,7 @@ def execute_fields( path: Path | None, grouped_field_set: GroupedFieldSet, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[dict[str, Any]]: """Execute the given fields concurrently. @@ -419,6 +449,7 @@ def execute_fields( field_group, field_path, incremental_data_record, + defer_map, ) if result is not Undefined: results[response_name] = result @@ -456,6 +487,7 @@ def execute_field( field_group: FieldGroup, path: Path, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[Any]: """Resolve the field on the given source object. @@ -465,7 +497,7 @@ def execute_field( calling its resolve function, then calls complete_value to await coroutine objects, serialize scalars, or execute the sub-selection-set for objects. """ - field_name = field_group[0].name.value + field_name = field_group.fields[0].node.name.value field_def = self.schema.get_field(parent_type, field_name) if not field_def: return Undefined @@ -483,7 +515,9 @@ def execute_field( try: # Build a dictionary of arguments from the field.arguments AST, using the # variables scope to fulfill any variable references. - args = get_argument_values(field_def, field_group[0], self.variable_values) + args = get_argument_values( + field_def, field_group.fields[0].node, self.variable_values + ) # Note that contrary to the JavaScript implementation, we pass the context # value as part of the resolve info. @@ -497,10 +531,17 @@ def execute_field( path, result, incremental_data_record, + defer_map, ) completed = self.complete_value( - return_type, field_group, info, path, result, incremental_data_record + return_type, + field_group, + info, + path, + result, + incremental_data_record, + defer_map, ) if self.is_awaitable(completed): # noinspection PyShadowingNames @@ -547,8 +588,8 @@ def build_resolve_info( # The resolve function's first argument is a collection of information about # the current execution state. return GraphQLResolveInfo( - field_group[0].name.value, - field_group, + field_group.fields[0].node.name.value, + field_group.to_nodes(), field_def.type, parent_type, path, @@ -570,7 +611,7 @@ def handle_field_error( incremental_data_record: IncrementalDataRecord, ) -> None: """Handle error properly according to the field type.""" - error = located_error(raw_error, field_group, path.as_list()) + error = located_error(raw_error, field_group.to_nodes(), path.as_list()) # If the field type is non-nullable, then it is resolved without any protection # from errors, however it still properly locates the error. @@ -589,6 +630,7 @@ def complete_value( path: Path, result: Any, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[Any]: """Complete a value. @@ -626,6 +668,7 @@ def complete_value( path, result, incremental_data_record, + defer_map, ) if completed is None: msg = ( @@ -642,7 +685,13 @@ def complete_value( # If field type is List, complete each item in the list with inner type if is_list_type(return_type): return self.complete_list_value( - return_type, field_group, info, path, result, incremental_data_record + return_type, + field_group, + info, + path, + result, + incremental_data_record, + defer_map, ) # If field type is a leaf type, Scalar or Enum, serialize to a valid value, @@ -654,13 +703,25 @@ def complete_value( # Object type and complete for that type. if is_abstract_type(return_type): return self.complete_abstract_value( - return_type, field_group, info, path, result, incremental_data_record + return_type, + field_group, + info, + path, + result, + incremental_data_record, + defer_map, ) # If field type is Object, execute and complete all sub-selections. if is_object_type(return_type): return self.complete_object_value( - return_type, field_group, info, path, result, incremental_data_record + return_type, + field_group, + info, + path, + result, + incremental_data_record, + defer_map, ) # Not reachable. All possible output types have been considered. @@ -678,6 +739,7 @@ async def complete_awaitable_value( path: Path, result: Any, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> Any: """Complete an awaitable value.""" try: @@ -689,6 +751,7 @@ async def complete_awaitable_value( path, resolved, incremental_data_record, + defer_map, ) if self.is_awaitable(completed): completed = await completed @@ -702,10 +765,10 @@ async def complete_awaitable_value( def get_stream_values( self, field_group: FieldGroup, path: Path - ) -> StreamArguments | None: + ) -> StreamUsage | None: """Get stream values. - Returns an object containing the `@stream` arguments if a field should be + Returns an object containing info for streaming if a field should be streamed based on the experimental flag, stream directive present and not disabled by the "if" argument. """ @@ -713,10 +776,14 @@ def get_stream_values( if isinstance(path.key, int): return None + stream_usage = self._stream_usages.get(field_group) + if stream_usage is not None: + return stream_usage + # validation only allows equivalent streams on multiple fields, so it is # safe to only check the first field_node for the stream directive stream = get_directive_values( - GraphQLStreamDirective, field_group[0], self.variable_values + GraphQLStreamDirective, field_group.fields[0].node, self.variable_values ) if not stream or stream.get("if") is False: @@ -734,8 +801,21 @@ def get_stream_values( ) raise TypeError(msg) - label = stream.get("label") - return StreamArguments(initial_count=initial_count, label=label) + streamed_field_group = FieldGroup( + [ + FieldDetails(field_details.node, None) + for field_details in field_group.fields + ], + NON_DEFERRED_TARGET_SET, + ) + + stream_usage = StreamUsage( + stream.get("label"), stream["initial_count"], streamed_field_group + ) + + self._stream_usages[field_group] = stream_usage + + return stream_usage async def complete_async_iterator_value( self, @@ -745,36 +825,36 @@ async def complete_async_iterator_value( path: Path, async_iterator: AsyncIterator[Any], incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> list[Any]: """Complete an async iterator. Complete an async iterator value by completing the result and calling recursively until all the results are completed. """ - stream = self.get_stream_values(field_group, path) + stream_usage = self.get_stream_usage(field_group, path) complete_list_item_value = self.complete_list_item_value awaitable_indices: list[int] = [] append_awaitable = awaitable_indices.append completed_results: list[Any] = [] index = 0 while True: - if ( - stream - and isinstance(stream.initial_count, int) - and index >= stream.initial_count - ): + if stream_usage and index >= stream_usage.initial_count: + early_return = async_iterator.returnattr # TODO!!! + stream_record = StreamRecord(path, stream_usage.label, early_return) + with suppress_timeout_error: await wait_for( shield( self.execute_stream_async_iterator( index, async_iterator, - field_group, + stream_usage.field_group, info, item_type, path, incremental_data_record, - stream.label, + stream_record, ) ), timeout=ASYNC_DELAY, @@ -789,7 +869,7 @@ async def complete_async_iterator_value( break except Exception as raw_error: raise located_error( - raw_error, field_group, path.as_list() + raw_error, field_group.to_nodes(), path.as_list() ) from raw_error if complete_list_item_value( value, @@ -799,6 +879,7 @@ async def complete_async_iterator_value( info, item_path, incremental_data_record, + defer_map, ): append_awaitable(index) @@ -829,6 +910,7 @@ def complete_list_value( path: Path, result: AsyncIterable[Any] | Iterable[Any], incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[list[Any]]: """Complete a list value. @@ -846,6 +928,7 @@ def complete_list_value( path, async_iterator, incremental_data_record, + defer_map, ) if not is_iterable(result): @@ -855,35 +938,34 @@ def complete_list_value( ) raise GraphQLError(msg) - stream = self.get_stream_values(field_group, path) + stream_usage = self.get_stream_usage(field_group, path) # This is specified as a simple map, however we're optimizing the path where # the list contains no coroutine objects by avoiding creating another coroutine # object. complete_list_item_value = self.complete_list_item_value + current_parents = incremental_data_record awaitable_indices: list[int] = [] append_awaitable = awaitable_indices.append - previous_incremental_data_record = incremental_data_record completed_results: list[Any] = [] + stream_record: StreamRecord | None = None for index, item in enumerate(result): # No need to modify the info object containing the path, since from here on # it is not ever accessed by resolver functions. item_path = path.add_key(index, None) - if ( - stream - and isinstance(stream.initial_count, int) - and index >= stream.initial_count - ): - previous_incremental_data_record = self.execute_stream_field( + if stream_usage and index >= stream_usage.initial_count: + if stream_record is None: + stream_record = StreamRecord(path, stream_usage.label) + current_parents = self.execute_stream_field( path, item_path, item, - field_group, + stream_usage.field_group, info, item_type, - previous_incremental_data_record, - stream.label, + current_parents, + stream_record, ) continue @@ -895,9 +977,15 @@ def complete_list_value( info, item_path, incremental_data_record, + defer_map, ): append_awaitable(index) + if stream_record is not None: + self.incremental_publisher.set_is_final_record( + cast(StreamItemsRecord, current_parents) + ) + if not awaitable_indices: return completed_results @@ -928,6 +1016,7 @@ def complete_list_item_value( info: GraphQLResolveInfo, item_path: Path, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> bool: """Complete a list item value by adding it to the completed results. @@ -944,6 +1033,7 @@ def complete_list_item_value( item_path, item, incremental_data_record, + defer_map, ) ) return True @@ -956,6 +1046,7 @@ def complete_list_item_value( item_path, item, incremental_data_record, + defer_map, ) if is_awaitable(completed_item): @@ -1019,6 +1110,7 @@ def complete_abstract_value( path: Path, result: Any, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[Any]: """Complete an abstract value. @@ -1045,6 +1137,7 @@ async def await_complete_object_value() -> Any: path, result, incremental_data_record, + defer_map, ) if self.is_awaitable(value): return await value # type: ignore @@ -1062,6 +1155,7 @@ async def await_complete_object_value() -> Any: path, result, incremental_data_record, + defer_map, ) def ensure_valid_runtime_type( @@ -1082,7 +1176,7 @@ def ensure_valid_runtime_type( " a 'resolve_type' function or each possible type should provide" " an 'is_type_of' function." ) - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) if is_object_type(runtime_type_name): # pragma: no cover msg = ( @@ -1098,7 +1192,7 @@ def ensure_valid_runtime_type( f" for field '{info.parent_type.name}.{info.field_name}' with value" f" {inspect(result)}, received '{inspect(runtime_type_name)}'." ) - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) runtime_type = self.schema.get_type(runtime_type_name) @@ -1107,21 +1201,21 @@ def ensure_valid_runtime_type( f"Abstract type '{return_type.name}' was resolved to a type" f" '{runtime_type_name}' that does not exist inside the schema." ) - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) if not is_object_type(runtime_type): msg = ( f"Abstract type '{return_type.name}' was resolved" f" to a non-object type '{runtime_type_name}'." ) - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) if not self.schema.is_sub_type(return_type, runtime_type): msg = ( f"Runtime Object type '{runtime_type.name}' is not a possible" f" type for '{return_type.name}'." ) - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) # noinspection PyTypeChecker return runtime_type @@ -1134,6 +1228,7 @@ def complete_object_value( path: Path, result: Any, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[dict[str, Any]]: """Complete an Object value by executing all sub-selections.""" # If there is an `is_type_of()` predicate function, call it with the current @@ -1150,7 +1245,12 @@ async def execute_subfields_async() -> dict[str, Any]: return_type, result, field_group ) return self.collect_and_execute_subfields( - return_type, field_group, path, result, incremental_data_record + return_type, + field_group, + path, + result, + incremental_data_record, + defer_map, ) # type: ignore return execute_subfields_async() @@ -1159,7 +1259,7 @@ async def execute_subfields_async() -> dict[str, Any]: raise invalid_return_type_error(return_type, result, field_group) return self.collect_and_execute_subfields( - return_type, field_group, path, result, incremental_data_record + return_type, field_group, path, result, incremental_data_record, defer_map ) def collect_and_execute_subfields( @@ -1169,26 +1269,41 @@ def collect_and_execute_subfields( path: Path, result: Any, incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> AwaitableOrValue[dict[str, Any]]: """Collect sub-fields to execute to complete this value.""" - sub_grouped_field_set, sub_patches = self.collect_subfields( - return_type, field_group + grouped_field_set, new_grouped_field_set_details, new_defer_usages = ( + self.collect_subfields(return_type, field_group) + ) + + incremental_publisher = self.incremental_publisher + new_defer_map = add_new_deferred_fragments( + incremental_publisher, + new_defer_usages, + incremental_data_record, + defer_map, + path, + ) + new_deferred_grouped_field_set_records = add_new_deferred_grouped_field_sets( + incremental_publisher, new_grouped_field_set_details, new_defer_map, path ) sub_fields = self.execute_fields( - return_type, result, path, sub_grouped_field_set, incremental_data_record + return_type, + result, + path, + grouped_field_set, + incremental_data_record, + new_defer_map, ) - for sub_patch in sub_patches: - label, sub_patch_grouped_field_set = sub_patch - self.execute_deferred_fragment( - return_type, - result, - sub_patch_grouped_field_set, - incremental_data_record, - label, - path, - ) + self.execute_deferred_grouped_field_sets( + return_type, + result, + path, + new_deferred_grouped_field_set_records, + new_defer_map, + ) return sub_fields @@ -1258,57 +1373,90 @@ async def callback(payload: Any) -> ExecutionResult: return map_async_iterable(result_or_stream, callback) - def execute_deferred_fragment( + def execute_deferred_grouped_field_sets( self, parent_type: GraphQLObjectType, source_value: Any, - fields: GroupedFieldSet, - parent_context: IncrementalDataRecord, - label: str | None = None, - path: Path | None = None, + path: Path | None, + new_deferred_grouped_field_set_records: Sequence[DeferredGroupedFieldSetRecord], + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], ) -> None: - """Execute deferred fragment.""" + """Execute deferred grouped field sets.""" + for deferred_grouped_field_set_record in new_deferred_grouped_field_set_records: + if deferred_grouped_field_set_record.should_initiate_defer: + + async def execute_deferred_grouped_field_set( + deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord, + ) -> None: + self.execute_deferred_grouped_field_set( + parent_type, + source_value, + path, + deferred_grouped_field_set_record, + defer_map, + ) + + self.add_task( + execute_deferred_grouped_field_set( + deferred_grouped_field_set_record + ) + ) + + else: + self.execute_deferred_grouped_field_set( + parent_type, + source_value, + path, + deferred_grouped_field_set_record, + defer_map, + ) + + def execute_deferred_grouped_field_set( + self, + parent_type: GraphQLObjectType, + source_value: Any, + path: Path | None, + deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], + ) -> None: + """Execute deferred grouped field set.""" incremental_publisher = self.incremental_publisher - incremental_data_record = ( - incremental_publisher.prepare_new_deferred_fragment_record( - label, path, parent_context - ) - ) try: - awaitable_or_data = self.execute_fields( - parent_type, source_value, path, fields, incremental_data_record + incremental_result = self.execute_fields( + parent_type, + source_value, + path, + deferred_grouped_field_set_record.grouped_field_set, + deferred_grouped_field_set_record, + defer_map, ) - if self.is_awaitable(awaitable_or_data): + if self.is_awaitable(incremental_result): - async def await_data() -> None: + async def await_incremental_result() -> None: try: - data = await awaitable_or_data # type: ignore + result = await incremental_result except GraphQLError as error: - incremental_publisher.add_field_error( - incremental_data_record, error - ) - incremental_publisher.complete_deferred_fragment_record( - incremental_data_record, None + incremental_publisher.mark_errored_deferred_grouped_field_set( + deferred_grouped_field_set_record, error ) else: - incremental_publisher.complete_deferred_fragment_record( - incremental_data_record, data + incremental_publisher.complete_deferred_grouped_field_set( + deferred_grouped_field_set_record, result ) - self.add_task(await_data()) + self.add_task(await_incremental_result()) else: - incremental_publisher.complete_deferred_fragment_record( - incremental_data_record, - awaitable_or_data, # type: ignore + incremental_publisher.complete_deferred_grouped_field_set( + deferred_grouped_field_set_record, + incremental_result, # type: ignore ) + except GraphQLError as error: - incremental_publisher.add_field_error(incremental_data_record, error) - incremental_publisher.complete_deferred_fragment_record( - incremental_data_record, None + incremental_publisher.mark_errored_deferred_grouped_field_set( + deferred_grouped_field_set_record, error ) - awaitable_or_data = None def execute_stream_field( self, @@ -1318,14 +1466,15 @@ def execute_stream_field( field_group: FieldGroup, info: GraphQLResolveInfo, item_type: GraphQLOutputType, - parent_context: IncrementalDataRecord, - label: str | None = None, - ) -> SubsequentDataRecord: + incremental_data_record: IncrementalDataRecord, + stream_record: StreamRecord, + ) -> StreamRecord: """Execute stream field.""" is_awaitable = self.is_awaitable incremental_publisher = self.incremental_publisher - incremental_data_record = incremental_publisher.prepare_new_stream_items_record( - label, item_path, parent_context + stream_items_record = StreamItemsRecord(stream_record, item_path) + incremental_publisher.report_new_stream_items_record( + stream_items_record, incremental_data_record ) completed_item: Any @@ -1339,23 +1488,24 @@ async def await_completed_awaitable_item() -> None: info, item_path, item, - incremental_data_record, + stream_items_record, + {}, ) except GraphQLError as error: incremental_publisher.add_field_error( incremental_data_record, error ) - incremental_publisher.filter(path, incremental_data_record) + incremental_publisher.filter(path, stream_items_record) incremental_publisher.complete_stream_items_record( - incremental_data_record, None + stream_items_record, None ) else: incremental_publisher.complete_stream_items_record( - incremental_data_record, [value] + stream_items_record, [value] ) self.add_task(await_completed_awaitable_item()) - return incremental_data_record + return stream_items_record try: try: @@ -1365,7 +1515,8 @@ async def await_completed_awaitable_item() -> None: info, item_path, item, - incremental_data_record, + stream_items_record, + {}, ) except Exception as raw_error: self.handle_field_error( @@ -1373,17 +1524,16 @@ async def await_completed_awaitable_item() -> None: item_type, field_group, item_path, - incremental_data_record, + stream_items_record, ) completed_item = None - incremental_publisher.filter(item_path, incremental_data_record) + incremental_publisher.filter(item_path, stream_items_record) except GraphQLError as error: - incremental_publisher.add_field_error(incremental_data_record, error) incremental_publisher.filter(path, incremental_data_record) - incremental_publisher.complete_stream_items_record( - incremental_data_record, None + incremental_publisher.mark_errored_stream_items_record( + stream_items_record, error ) - return incremental_data_record + return stream_items_record if is_awaitable(completed_item): @@ -1397,30 +1547,27 @@ async def await_completed_item() -> None: item_type, field_group, item_path, - incremental_data_record, + stream_items_record, ) - incremental_publisher.filter(item_path, incremental_data_record) + incremental_publisher.filter(item_path, stream_items_record) value = None except GraphQLError as error: # pragma: no cover - incremental_publisher.add_field_error( - incremental_data_record, error - ) - incremental_publisher.filter(path, incremental_data_record) - incremental_publisher.complete_stream_items_record( - incremental_data_record, None + incremental_publisher.filter(path, stream_items_record) + incremental_publisher.mark_errored_stream_items_record( + stream_items_record, error ) else: incremental_publisher.complete_stream_items_record( - incremental_data_record, [value] + stream_items_record, [value] ) self.add_task(await_completed_item()) - return incremental_data_record + return stream_items_record incremental_publisher.complete_stream_items_record( - incremental_data_record, [completed_item] + stream_items_record, [completed_item] ) - return incremental_data_record + return stream_items_record async def execute_stream_async_iterator_item( self, @@ -1428,8 +1575,7 @@ async def execute_stream_async_iterator_item( field_group: FieldGroup, info: GraphQLResolveInfo, item_type: GraphQLOutputType, - incremental_data_record: StreamItemsRecord, - path: Path, + stream_items_record: StreamItemsRecord, item_path: Path, ) -> Any: """Execute stream iterator item.""" @@ -1439,14 +1585,21 @@ async def execute_stream_async_iterator_item( item = await anext(async_iterator) except StopAsyncIteration as raw_error: self.incremental_publisher.set_is_completed_async_iterator( - incremental_data_record + stream_items_record ) raise StopAsyncIteration from raw_error except Exception as raw_error: - raise located_error(raw_error, field_group, path.as_list()) from raw_error + raise located_error( + raw_error, + field_group.to_nodes(), + stream_items_record.stream_record.path, + ) from raw_error + else: + if stream_items_record.stream_record.errors: + raise StopAsyncIteration try: completed_item = self.complete_value( - item_type, field_group, info, item_path, item, incremental_data_record + item_type, field_group, info, item_path, item, stream_items_record, {} ) return ( await completed_item @@ -1455,9 +1608,9 @@ async def execute_stream_async_iterator_item( ) except Exception as raw_error: self.handle_field_error( - raw_error, item_type, field_group, item_path, incremental_data_record + raw_error, item_type, field_group, item_path, stream_items_record ) - self.incremental_publisher.filter(item_path, incremental_data_record) + self.incremental_publisher.filter(item_path, stream_items_record) async def execute_stream_async_iterator( self, @@ -1467,21 +1620,20 @@ async def execute_stream_async_iterator( info: GraphQLResolveInfo, item_type: GraphQLOutputType, path: Path, - parent_context: IncrementalDataRecord, - label: str | None = None, + incremental_data_record: IncrementalDataRecord, + stream_record: StreamRecord, ) -> None: """Execute stream iterator.""" incremental_publisher = self.incremental_publisher index = initial_index - previous_incremental_data_record = parent_context + current_incremental_data_record = incremental_data_record done = False while True: item_path = Path(path, index, None) - incremental_data_record = ( - incremental_publisher.prepare_new_stream_items_record( - label, item_path, previous_incremental_data_record, async_iterator - ) + stream_items_record = StreamItemsRecord(stream_record, item_path) + incremental_publisher.report_new_stream_items_record( + stream_items_record, current_incremental_data_record ) try: @@ -1490,15 +1642,13 @@ async def execute_stream_async_iterator( field_group, info, item_type, - incremental_data_record, - path, + stream_items_record, item_path, ) except GraphQLError as error: - incremental_publisher.add_field_error(incremental_data_record, error) incremental_publisher.filter(path, incremental_data_record) - incremental_publisher.complete_stream_items_record( - incremental_data_record, None + incremental_publisher.mark_errored_stream_items_record( + stream_items_record, error ) if async_iterator: # pragma: no cover else with suppress_exceptions: @@ -1509,15 +1659,15 @@ async def execute_stream_async_iterator( break except StopAsyncIteration: done = True - - incremental_publisher.complete_stream_items_record( - incremental_data_record, - [completed_item], - ) + else: + incremental_publisher.complete_stream_items_record( + stream_items_record, + [completed_item], + ) if done: break - previous_incremental_data_record = incremental_data_record + current_incremental_data_record = incremental_data_record index += 1 def add_task(self, awaitable: Awaitable[Any]) -> None: @@ -1667,7 +1817,7 @@ def execute_impl( # at which point we still log the error and null the parent field, which # in this case is the entire response. incremental_publisher = context.incremental_publisher - initial_result_record = incremental_publisher.prepare_initial_result_record() + initial_result_record = InitialResultRecord() try: data = context.execute_operation(initial_result_record) if context.is_awaitable(data): @@ -1759,10 +1909,92 @@ def invalid_return_type_error( """Create a GraphQLError for an invalid return type.""" return GraphQLError( f"Expected value of type '{return_type.name}' but got: {inspect(result)}.", - field_group, + field_group.to_nodes(), ) +def add_new_deferred_fragments( + incremental_publisher: IncrementalPublisher, + new_defer_usages: Sequence[DeferUsage], + incremental_data_record: IncrementalDataRecord, + defer_map: Mapping[DeferUsage, DeferredFragmentRecord] | None = None, + path: Path | None = None, +) -> Mapping[DeferUsage, DeferredFragmentRecord]: + """Add new deferred fragments to the defer map.""" + new_defer_map: Mapping[DeferUsage, DeferredFragmentRecord] + if not new_defer_usages: + return {} if defer_map is None else defer_map + new_defer_map = {} if defer_map is None else dict(defer_map) + for defer_usage in new_defer_usages: + ancestors = defer_usage.ancestors + parent_defer_usage = ancestors[0] if ancestors else None + + parent = ( + cast(InitialResultRecord | StreamItemsRecord, incremental_data_record) + if parent_defer_usage is None + else dferred_fragment_record_from_defer_usage( + parent_defer_usage, new_defer_map + ) + ) + + deferred_fragment_record = DeferredFragmentRecord(path, defer_usage.label) + + incremental_publisher.report_new_defer_fragment_record( + deferred_fragment_record, parent + ) + + new_defer_map[defer_usage] = deferred_fragment_record + + return new_defer_map + + +def deferred_fragment_record_from_defer_usage( + defer_usage: DeferUsage, defer_map: Mapping[DeferUsage, DeferredFragmentRecord] +) -> DeferredFragmentRecord: + """Get the deferred fragment record mapped to the given defer usage.""" + return defer_map[defer_usage] + + +def add_new_deferred_grouped_field_sets( + incremental_publisher: IncrementalPublisher, + new_grouped_field_set_details: Mapping[DeferUsage, GroupedFieldSetDetails], + defer_map: Mapping[DeferUsage, DeferredFragmentRecord], + path: Path | None = None, +) -> list[DeferredGroupedFieldSetRecord]: + """Add new deferred grouped field sets to the defer map.""" + new_deferred_grouped_field_set_records: list[DeferredGroupedFieldSetRecord] = [] + + for ( + new_grouped_field_set_defer_usages, + grouped_field_set_details, + ) in new_grouped_field_set_details.items(): + deferred_fragment_records = get_deferred_fragment_records( + new_grouped_field_set_defer_usages, defer_map + ) + deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord( + deferred_fragment_records, + grouped_field_set_details.grouped_field_set, + grouped_field_set_details.should_initiate_defer, + path, + ) + incremental_publisher.report_new_deferred_grouped_filed_set_record( + deferred_grouped_field_set_record + ) + new_deferred_grouped_field_set_records.append(deferred_grouped_field_set_record) + + return new_deferred_grouped_field_set_records + + +def get_deferred_fragment_records( + defer_usages: DeferUsageSet, defer_map: Mapping[DeferUsage, DeferredFragmentRecord] +) -> list[DeferredFragmentRecord]: + """Get the deferred fragment records for the given defer usages.""" + return [ + deferred_fragment_record_from_defer_usage(defer_usage, defer_map) + for defer_usage in defer_usages + ] + + def get_typename(value: Any) -> str | None: """Get the ``__typename`` property of the given value.""" if isinstance(value, Mapping): @@ -2025,12 +2257,12 @@ def execute_subscription( ).grouped_field_set first_root_field = next(iter(grouped_field_set.items())) response_name, field_group = first_root_field - field_name = field_group[0].name.value + field_name = field_group.fields[0].node.name.value field_def = schema.get_field(root_type, field_name) if not field_def: msg = f"The subscription field '{field_name}' is not defined." - raise GraphQLError(msg, field_group) + raise GraphQLError(msg, field_group.to_nodes()) path = Path(None, response_name, root_type.name) info = context.build_resolve_info(field_def, field_group, root_type, path) @@ -2041,7 +2273,9 @@ def execute_subscription( try: # Build a dictionary of arguments from the field.arguments AST, using the # variables scope to fulfill any variable references. - args = get_argument_values(field_def, field_group[0], context.variable_values) + args = get_argument_values( + field_def, field_group.fields[0].node, context.variable_values + ) # Call the `subscribe()` resolver or the default resolver to produce an # AsyncIterable yielding raw payloads. @@ -2054,14 +2288,16 @@ async def await_result() -> AsyncIterable[Any]: try: return assert_event_stream(await result) except Exception as error: - raise located_error(error, field_group, path.as_list()) from error + raise located_error( + error, field_group.to_nodes(), path.as_list() + ) from error return await_result() return assert_event_stream(result) except Exception as error: - raise located_error(error, field_group, path.as_list()) from error + raise located_error(error, field_group.to_nodes(), path.as_list()) from error def assert_event_stream(result: Any) -> AsyncIterable: diff --git a/src/graphql/execution/incremental_publisher.py b/src/graphql/execution/incremental_publisher.py index fdc35fff..a07d4212 100644 --- a/src/graphql/execution/incremental_publisher.py +++ b/src/graphql/execution/incremental_publisher.py @@ -8,8 +8,8 @@ TYPE_CHECKING, Any, AsyncGenerator, - AsyncIterator, Awaitable, + Callable, Collection, Iterator, NamedTuple, @@ -25,6 +25,7 @@ if TYPE_CHECKING: from ..error import GraphQLError, GraphQLFormattedError from ..pyutils import Path + from .collect_fields import GroupedFieldSet __all__ = [ "ASYNC_DELAY", @@ -54,6 +55,80 @@ suppress_key_error = suppress(KeyError) +class FormattedCompletedResult(TypedDict, total=False): + """Formatted completed execution result""" + + path: list[str | int] + label: str + errors: list[GraphQLFormattedError] + + +class CompletedResult: + """Completed execution result""" + + path: list[str | int] + label: str | None + errors: list[GraphQLError] | None + + __slots__ = "path", "label", "errors" + + def __init__( + self, + path: list[str | int], + label: str | None = None, + errors: list[GraphQLError] | None = None, + ) -> None: + self.path = path + self.label = label + self.errors = errors + + def __repr__(self) -> str: + name = self.__class__.__name__ + args: list[str] = [f"path={self.path!r}"] + if self.label: + args.append(f"label={self.label!r}") + if self.errors: + args.append(f"errors={self.errors!r}") + return f"{name}({', '.join(args)})" + + @property + def formatted(self) -> FormattedCompletedResult: + """Get execution result formatted according to the specification.""" + formatted: FormattedCompletedResult = {"path": self.path} + if self.label is not None: + formatted["label"] = self.label + if self.errors is not None: + formatted["errors"] = [error.formatted for error in self.errors] + return formatted + + def __eq__(self, other: object) -> bool: + if isinstance(other, dict): + return ( + other.get("path") == self.path + and ("label" not in other or other["label"] == self.label) + and ("errors" not in other or other["errors"] == self.errors) + ) + if isinstance(other, tuple): + size = len(other) + return 1 < size < 4 and (self.path, self.label, self.errors)[:size] == other + return ( + isinstance(other, self.__class__) + and other.path == self.path + and other.label == self.label + and other.errors == self.errors + ) + + def __ne__(self, other: object) -> bool: + return not self == other + + +class IncrementalUpdate(NamedTuple): + """Incremental update""" + + incremental: list[IncrementalResult] + completed: list[CompletedResult] + + class FormattedExecutionResult(TypedDict, total=False): """Formatted execution result""" @@ -147,31 +222,26 @@ class InitialIncrementalExecutionResult: data: dict[str, Any] | None errors: list[GraphQLError] | None - incremental: Sequence[IncrementalResult] | None has_next: bool extensions: dict[str, Any] | None - __slots__ = "data", "errors", "has_next", "incremental", "extensions" + __slots__ = "data", "errors", "has_next", "extensions" def __init__( self, data: dict[str, Any] | None = None, errors: list[GraphQLError] | None = None, - incremental: Sequence[IncrementalResult] | None = None, has_next: bool = False, extensions: dict[str, Any] | None = None, ) -> None: self.data = data self.errors = errors - self.incremental = incremental self.has_next = has_next self.extensions = extensions def __repr__(self) -> str: name = self.__class__.__name__ args: list[str] = [f"data={self.data!r}, errors={self.errors!r}"] - if self.incremental: - args.append(f"incremental[{len(self.incremental)}]") if self.has_next: args.append("has_next") if self.extensions: @@ -184,8 +254,6 @@ def formatted(self) -> FormattedInitialIncrementalExecutionResult: formatted: FormattedInitialIncrementalExecutionResult = {"data": self.data} if self.errors is not None: formatted["errors"] = [error.formatted for error in self.errors] - if self.incremental: - formatted["incremental"] = [result.formatted for result in self.incremental] formatted["hasNext"] = self.has_next if self.extensions is not None: formatted["extensions"] = self.extensions @@ -196,10 +264,6 @@ def __eq__(self, other: object) -> bool: return ( other.get("data") == self.data and other.get("errors") == self.errors - and ( - "incremental" not in other - or other["incremental"] == self.incremental - ) and ("hasNext" not in other or other["hasNext"] == self.has_next) and ( "extensions" not in other or other["extensions"] == self.extensions @@ -208,11 +272,10 @@ def __eq__(self, other: object) -> bool: if isinstance(other, tuple): size = len(other) return ( - 1 < size < 6 + 1 < size < 5 and ( self.data, self.errors, - self.incremental, self.has_next, self.extensions, )[:size] @@ -222,7 +285,6 @@ def __eq__(self, other: object) -> bool: isinstance(other, self.__class__) and other.data == self.data and other.errors == self.errors - and other.incremental == self.incremental and other.has_next == self.has_next and other.extensions == self.extensions ) @@ -244,7 +306,6 @@ class FormattedIncrementalDeferResult(TypedDict, total=False): data: dict[str, Any] | None errors: list[GraphQLFormattedError] path: list[str | int] - label: str extensions: dict[str, Any] @@ -254,23 +315,20 @@ class IncrementalDeferResult: data: dict[str, Any] | None errors: list[GraphQLError] | None path: list[str | int] | None - label: str | None extensions: dict[str, Any] | None - __slots__ = "data", "errors", "path", "label", "extensions" + __slots__ = "data", "errors", "path", "extensions" def __init__( self, data: dict[str, Any] | None = None, errors: list[GraphQLError] | None = None, path: list[str | int] | None = None, - label: str | None = None, extensions: dict[str, Any] | None = None, ) -> None: self.data = data self.errors = errors self.path = path - self.label = label self.extensions = extensions def __repr__(self) -> str: @@ -278,8 +336,6 @@ def __repr__(self) -> str: args: list[str] = [f"data={self.data!r}, errors={self.errors!r}"] if self.path: args.append(f"path={self.path!r}") - if self.label: - args.append(f"label={self.label!r}") if self.extensions: args.append(f"extensions={self.extensions}") return f"{name}({', '.join(args)})" @@ -292,8 +348,6 @@ def formatted(self) -> FormattedIncrementalDeferResult: formatted["errors"] = [error.formatted for error in self.errors] if self.path is not None: formatted["path"] = self.path - if self.label is not None: - formatted["label"] = self.label if self.extensions is not None: formatted["extensions"] = self.extensions return formatted @@ -304,7 +358,6 @@ def __eq__(self, other: object) -> bool: other.get("data") == self.data and other.get("errors") == self.errors and ("path" not in other or other["path"] == self.path) - and ("label" not in other or other["label"] == self.label) and ( "extensions" not in other or other["extensions"] == self.extensions ) @@ -312,18 +365,14 @@ def __eq__(self, other: object) -> bool: if isinstance(other, tuple): size = len(other) return ( - 1 < size < 6 - and (self.data, self.errors, self.path, self.label, self.extensions)[ - :size - ] - == other + 1 < size < 5 + and (self.data, self.errors, self.path, self.extensions)[:size] == other ) return ( isinstance(other, self.__class__) and other.data == self.data and other.errors == self.errors and other.path == self.path - and other.label == self.label and other.extensions == self.extensions ) @@ -337,7 +386,6 @@ class FormattedIncrementalStreamResult(TypedDict, total=False): items: list[Any] | None errors: list[GraphQLFormattedError] path: list[str | int] - label: str extensions: dict[str, Any] @@ -347,7 +395,6 @@ class IncrementalStreamResult: items: list[Any] | None errors: list[GraphQLError] | None path: list[str | int] | None - label: str | None extensions: dict[str, Any] | None __slots__ = "items", "errors", "path", "label", "extensions" @@ -357,13 +404,11 @@ def __init__( items: list[Any] | None = None, errors: list[GraphQLError] | None = None, path: list[str | int] | None = None, - label: str | None = None, extensions: dict[str, Any] | None = None, ) -> None: self.items = items self.errors = errors self.path = path - self.label = label self.extensions = extensions def __repr__(self) -> str: @@ -371,8 +416,6 @@ def __repr__(self) -> str: args: list[str] = [f"items={self.items!r}, errors={self.errors!r}"] if self.path: args.append(f"path={self.path!r}") - if self.label: - args.append(f"label={self.label!r}") if self.extensions: args.append(f"extensions={self.extensions}") return f"{name}({', '.join(args)})" @@ -385,8 +428,6 @@ def formatted(self) -> FormattedIncrementalStreamResult: formatted["errors"] = [error.formatted for error in self.errors] if self.path is not None: formatted["path"] = self.path - if self.label is not None: - formatted["label"] = self.label if self.extensions is not None: formatted["extensions"] = self.extensions return formatted @@ -397,7 +438,6 @@ def __eq__(self, other: object) -> bool: other.get("items") == self.items and other.get("errors") == self.errors and ("path" not in other or other["path"] == self.path) - and ("label" not in other or other["label"] == self.label) and ( "extensions" not in other or other["extensions"] == self.extensions ) @@ -405,10 +445,8 @@ def __eq__(self, other: object) -> bool: if isinstance(other, tuple): size = len(other) return ( - 1 < size < 6 - and (self.items, self.errors, self.path, self.label, self.extensions)[ - :size - ] + 1 < size < 5 + and (self.items, self.errors, self.path, self.extensions)[:size] == other ) return ( @@ -416,7 +454,6 @@ def __eq__(self, other: object) -> bool: and other.items == self.items and other.errors == self.errors and other.path == self.path - and other.label == self.label and other.extensions == self.extensions ) @@ -434,8 +471,9 @@ def __ne__(self, other: object) -> bool: class FormattedSubsequentIncrementalExecutionResult(TypedDict, total=False): """Formatted subsequent incremental execution result""" - incremental: list[FormattedIncrementalResult] hasNext: bool + incremental: list[FormattedIncrementalResult] + completed: list[FormattedCompletedResult] extensions: dict[str, Any] @@ -446,29 +484,34 @@ class SubsequentIncrementalExecutionResult: - ``incremental`` is a list of the results from defer/stream directives. """ - __slots__ = "has_next", "incremental", "extensions" + __slots__ = "has_next", "incremental", "completed", "extensions" - incremental: Sequence[IncrementalResult] | None has_next: bool + incremental: Sequence[IncrementalResult] | None + completed: Sequence[CompletedResult] | None extensions: dict[str, Any] | None def __init__( self, - incremental: Sequence[IncrementalResult] | None = None, has_next: bool = False, + incremental: Sequence[IncrementalResult] | None = None, + completed: Sequence[CompletedResult] | None = None, extensions: dict[str, Any] | None = None, ) -> None: - self.incremental = incremental self.has_next = has_next + self.incremental = incremental + self.completed = completed self.extensions = extensions def __repr__(self) -> str: name = self.__class__.__name__ args: list[str] = [] - if self.incremental: - args.append(f"incremental[{len(self.incremental)}]") if self.has_next: args.append("has_next") + if self.incremental: + args.append(f"incremental[{len(self.incremental)}]") + if self.completed: + args.append(f"completed[{len(self.completed)}]") if self.extensions: args.append(f"extensions={self.extensions}") return f"{name}({', '.join(args)})" @@ -477,9 +520,11 @@ def __repr__(self) -> str: def formatted(self) -> FormattedSubsequentIncrementalExecutionResult: """Get execution result formatted according to the specification.""" formatted: FormattedSubsequentIncrementalExecutionResult = {} + formatted["hasNext"] = self.has_next if self.incremental: formatted["incremental"] = [result.formatted for result in self.incremental] - formatted["hasNext"] = self.has_next + if self.completed: + formatted["completed"] = [result.formatted for result in self.completed] if self.extensions is not None: formatted["extensions"] = self.extensions return formatted @@ -487,8 +532,12 @@ def formatted(self) -> FormattedSubsequentIncrementalExecutionResult: def __eq__(self, other: object) -> bool: if isinstance(other, dict): return ( - ("incremental" not in other or other["incremental"] == self.incremental) - and ("hasNext" in other and other["hasNext"] == self.has_next) + ("hasNext" in other and other["hasNext"] == self.has_next) + and ( + "incremental" not in other + or other["incremental"] == self.incremental + ) + and ("completed" not in other or other["completed"] == self.completed) and ( "extensions" not in other or other["extensions"] == self.extensions ) @@ -496,18 +545,20 @@ def __eq__(self, other: object) -> bool: if isinstance(other, tuple): size = len(other) return ( - 1 < size < 4 + 1 < size < 5 and ( - self.incremental, self.has_next, + self.incremental, + self.completed, self.extensions, )[:size] == other ) return ( isinstance(other, self.__class__) - and other.incremental == self.incremental and other.has_next == self.has_next + and other.incremental == self.incremental + and other.completed == self.completed and other.extensions == self.extensions ) @@ -530,20 +581,20 @@ class IncrementalPublisher: The internal publishing state is managed as follows: - ``_released``: the set of Subsequent Data records that are ready to be sent to the + ``_released``: the set of Subsequent Result records that are ready to be sent to the client, i.e. their parents have completed and they have also completed. - ``_pending``: the set of Subsequent Data records that are definitely pending, i.e. + ``_pending``: the set of Subsequent Result records that are definitely pending, i.e. their parents have completed so that they can no longer be filtered. This includes - all Subsequent Data records in `released`, as well as Subsequent Data records that - have not yet completed. + all Subsequent Result records in `released`, as well as the records that have not + yet completed. Note: Instead of sets we use dicts (with values set to None) which preserve order and thereby achieve more deterministic results. """ - _released: dict[SubsequentDataRecord, None] - _pending: dict[SubsequentDataRecord, None] + _released: dict[SubsequentResultRecord, None] + _pending: dict[SubsequentResultRecord, None] _resolve: Event | None def __init__(self) -> None: @@ -552,60 +603,105 @@ def __init__(self) -> None: self._resolve = None # lazy initialization self._tasks: set[Awaitable] = set() - def prepare_initial_result_record(self) -> InitialResultRecord: - """Prepare a new initial result record.""" - return InitialResultRecord(errors=[], children={}) - - def prepare_new_deferred_fragment_record( - self, - label: str | None, - path: Path | None, - parent_context: IncrementalDataRecord, - ) -> DeferredFragmentRecord: - """Prepare a new deferred fragment record.""" - deferred_fragment_record = DeferredFragmentRecord(label, path) + @staticmethod + def report_new_defer_fragment_record( + deferred_fragment_record: DeferredFragmentRecord, + parent_incremental_result_record: InitialResultRecord + | DeferredFragmentRecord + | StreamItemsRecord, + ) -> None: + """Report a new deferred fragment record.""" + parent_incremental_result_record.children[deferred_fragment_record] = None - parent_context.children[deferred_fragment_record] = None - return deferred_fragment_record + @staticmethod + def report_new_deferred_grouped_filed_set_record( + deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord, + ) -> None: + """Report a new deferred grouped field set record.""" + for ( + deferred_fragment_record + ) in deferred_grouped_field_set_record.deferred_fragment_records: + deferred_fragment_record._pending[deferred_grouped_field_set_record] = None # noqa: SLF001 + deferred_fragment_record.deferred_grouped_field_set_records[ + deferred_grouped_field_set_record + ] = None + + @staticmethod + def report_new_stream_items_record( + stream_items_record: StreamItemsRecord, + parent_incremental_data_record: IncrementalDataRecord, + ) -> None: + """Report a new stream items record.""" + if isinstance(parent_incremental_data_record, DeferredGroupedFieldSetRecord): + for parent in parent_incremental_data_record.deferred_fragment_records: + parent.children[stream_items_record] = None + else: + parent_incremental_data_record.children[stream_items_record] = None - def prepare_new_stream_items_record( + def complete_deferred_grouped_field_set( self, - label: str | None, - path: Path | None, - parent_context: IncrementalDataRecord, - async_iterator: AsyncIterator[Any] | None = None, - ) -> StreamItemsRecord: - """Prepare a new stream items record.""" - stream_items_record = StreamItemsRecord(label, path, async_iterator) - - parent_context.children[stream_items_record] = None - return stream_items_record + deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord, + data: dict[str, Any], + ) -> None: + """Complete the given deferred grouped field set record with the given data.""" + deferred_grouped_field_set_record.data = data + for ( + deferred_fragment_record + ) in deferred_grouped_field_set_record.deferred_fragment_records: + pending = deferred_fragment_record._pending # noqa: SLF001 + del pending[deferred_grouped_field_set_record] + if not pending: + self.complete_deferred_fragment_record(deferred_fragment_record) + + def mark_errored_deferred_grouped_field_set( + self, + deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord, + error: GraphQLError, + ) -> None: + """Mark the given deferred grouped field set record as errored.""" + for ( + deferred_fragment_record + ) in deferred_grouped_field_set_record.deferred_fragment_records: + deferred_fragment_record.errors.append(error) + self.complete_deferred_fragment_record(deferred_fragment_record) def complete_deferred_fragment_record( - self, - deferred_fragment_record: DeferredFragmentRecord, - data: dict[str, Any] | None, + self, deferred_fragment_record: DeferredFragmentRecord ) -> None: """Complete the given deferred fragment record.""" - deferred_fragment_record.data = data - deferred_fragment_record.is_completed = True self._release(deferred_fragment_record) def complete_stream_items_record( self, stream_items_record: StreamItemsRecord, - items: list[str] | None, + items: list[Any], ) -> None: """Complete the given stream items record.""" stream_items_record.items = items stream_items_record.is_completed = True self._release(stream_items_record) + def mark_errored_stream_items_record( + self, stream_items_record: StreamItemsRecord, error: GraphQLError + ) -> None: + """Mark the given stream items record as errored.""" + stream_items_record.errors.append(error) + self.set_is_final_record(stream_items_record) + stream_items_record.is_completed = True + stream_items_record.stream_record.early_return() # !!! TODO + self._release(stream_items_record) + + @staticmethod + def set_is_final_record(stream_items_record: StreamItemsRecord) -> None: + """Mark stream items record as final.""" + stream_items_record.is_final_record = True + def set_is_completed_async_iterator( self, stream_items_record: StreamItemsRecord ) -> None: """Mark async iterator for stream items as completed.""" stream_items_record.is_completed_async_iterator = True + self.set_is_final_record(stream_items_record) def add_field_error( self, incremental_data_record: IncrementalDataRecord, error: GraphQLError @@ -657,29 +753,28 @@ def build_error_response( def filter( self, - null_path: Path, + null_path: Path | None, erroring_incremental_data_record: IncrementalDataRecord, ) -> None: """Filter out the given erroring incremental data record.""" - null_path_list = null_path.as_list() + null_path_list = null_path.as_list() if null_path else [] + + streams: list[StreamRecord] = [] - descendants = self._get_descendants(erroring_incremental_data_record.children) + children = self._get_children(erroring_incremental_data_record) + descendants = self._get_descendants(children) for child in descendants: - if not self._matches_path(child.path, null_path_list): + if not self._nulls_child_subsequent_result_record(child, null_path_list): continue child.filtered = True if isinstance(child, StreamItemsRecord): - async_iterator = child.async_iterator - if async_iterator: - try: - close_async_iterator = async_iterator.aclose() # type:ignore - except AttributeError: # pragma: no cover - pass - else: - self._add_task(close_async_iterator) + streams.append(child.stream_record) + + for stream in streams: + stream.early_return() # !!! TODO async def _subscribe( self, @@ -709,20 +804,16 @@ async def _subscribe( self._resolve = resolve = Event() await resolve.wait() finally: - close_async_iterators = [] - for incremental_data_record in pending: - if isinstance( - incremental_data_record, StreamItemsRecord - ): # pragma: no cover - async_iterator = incremental_data_record.async_iterator - if async_iterator: - try: - close_async_iterator = async_iterator.aclose() # type: ignore - except AttributeError: - pass - else: - close_async_iterators.append(close_async_iterator) - await gather(*close_async_iterators) + streams: list[StreamRecord] = [] + descendants = self._get_descendants(pending) + for subsequent_result_record in descendants: + if isinstance(subsequent_result_record, StreamItemsRecord): + streams.append(subsequent_result_record.stream_record) + promises = [] # TODO + for stream in streams: + if stream.early_return: + promises.append(stream.early_return()) + await gather(*promises) def _trigger(self) -> None: """Trigger the resolve event.""" @@ -731,82 +822,129 @@ def _trigger(self) -> None: resolve.set() self._resolve = Event() - def _introduce(self, item: SubsequentDataRecord) -> None: + def _introduce(self, item: SubsequentResultRecord) -> None: """Introduce a new IncrementalDataRecord.""" self._pending[item] = None - def _release(self, item: SubsequentDataRecord) -> None: + def _release(self, item: SubsequentResultRecord) -> None: """Release the given IncrementalDataRecord.""" if item in self._pending: self._released[item] = None self._trigger() - def _push(self, item: SubsequentDataRecord) -> None: + def _push(self, item: SubsequentResultRecord) -> None: """Push the given IncrementalDataRecord.""" self._released[item] = None self._pending[item] = None self._trigger() def _get_incremental_result( - self, completed_records: Collection[SubsequentDataRecord] + self, completed_records: Collection[SubsequentResultRecord] ) -> SubsequentIncrementalExecutionResult | None: """Get the incremental result with the completed records.""" + update = self._process_pending(completed_records) + incremental, completed = update.incremental, update.completed + + has_next = bool(self._pending) + if not incremental and not completed and has_next: + return None + + return SubsequentIncrementalExecutionResult( + has_next, incremental or None, completed or None + ) + + def _process_pending( + self, + completed_records: Collection[SubsequentResultRecord], + ) -> IncrementalUpdate: + """Process the pending records.""" incremental_results: list[IncrementalResult] = [] - encountered_completed_async_iterator = False - append_result = incremental_results.append - for incremental_data_record in completed_records: - incremental_result: IncrementalResult - for child in incremental_data_record.children: + completed_results: list[CompletedResult] = [] + to_result = self._completed_record_to_result + for subsequent_result_record in completed_records: + for child in subsequent_result_record.children: if child.filtered: continue self._publish(child) - if isinstance(incremental_data_record, StreamItemsRecord): - items = incremental_data_record.items - if incremental_data_record.is_completed_async_iterator: + incremental_result: IncrementalResult + if isinstance(subsequent_result_record, StreamItemsRecord): + if subsequent_result_record.is_final_record: + completed_results.append( + to_result(subsequent_result_record.stream_record) + ) + if subsequent_result_record.is_completed_async_iterator: # async iterable resolver finished but there may be pending payload - encountered_completed_async_iterator = True - continue # pragma: no cover + continue + if subsequent_result_record.stream_record.errors: + continue incremental_result = IncrementalStreamResult( - items, - incremental_data_record.errors - if incremental_data_record.errors - else None, - incremental_data_record.path, - incremental_data_record.label, + subsequent_result_record.items, + subsequent_result_record.errors or None, + subsequent_result_record.stream_record.path, ) + incremental_results.append(incremental_result) else: - data = incremental_data_record.data - incremental_result = IncrementalDeferResult( - data, - incremental_data_record.errors - if incremental_data_record.errors - else None, - incremental_data_record.path, - incremental_data_record.label, - ) - append_result(incremental_result) - - has_next = bool(self._pending) - if incremental_results: - return SubsequentIncrementalExecutionResult( - incremental=incremental_results, has_next=has_next - ) - if encountered_completed_async_iterator and not has_next: - return SubsequentIncrementalExecutionResult(has_next=False) - return None + completed_results.append(to_result(subsequent_result_record)) + if subsequent_result_record.errors: + continue + for ( + deferred_grouped_field_set_record + ) in subsequent_result_record.deferred_grouped_field_set_records: + if not deferred_grouped_field_set_record.sent: + deferred_grouped_field_set_record.sent = True + incremental_result = IncrementalDeferResult( + deferred_grouped_field_set_record.data, + deferred_grouped_field_set_record.errors or None, + deferred_grouped_field_set_record.path, + ) + incremental_results.append(incremental_result) + return IncrementalUpdate(incremental_results, completed_results) + + @staticmethod + def _completed_record_to_result( + completed_record: DeferredFragmentRecord | StreamRecord, + ) -> CompletedResult: + """Convert the completed record to a result.""" + return CompletedResult( + completed_record.path, + completed_record.label or None, + completed_record.errors or None, + ) - def _publish(self, subsequent_result_record: SubsequentDataRecord) -> None: + def _publish(self, subsequent_result_record: SubsequentResultRecord) -> None: """Publish the given incremental data record.""" - if subsequent_result_record.is_completed: + if isinstance(subsequent_result_record, StreamItemsRecord): + if subsequent_result_record.is_completed: # type: ignore + self._push(subsequent_result_record) + else: + self._introduce(subsequent_result_record) + elif subsequent_result_record._pending: # noqa: SLF001 + self._introduce(subsequent_result_record) + else: self._push(subsequent_result_record) + + @staticmethod + def _get_children( + erroring_incremental_data_record: IncrementalDataRecord, + ) -> dict[SubsequentResultRecord, None]: + """Get the children of the given erroring incremental data record.""" + children: dict[SubsequentResultRecord, None] = {} + if isinstance(erroring_incremental_data_record, DeferredGroupedFieldSetRecord): + for ( + erroring_incremental_result_record + ) in erroring_incremental_data_record.deferred_fragment_records: + for child in erroring_incremental_result_record.children: + children[child] = None else: - self._introduce(subsequent_result_record) + for child in erroring_incremental_data_record.children: + children[child] = None + return children def _get_descendants( self, - children: dict[SubsequentDataRecord, None], - descendants: dict[SubsequentDataRecord, None] | None = None, - ) -> dict[SubsequentDataRecord, None]: + children: dict[SubsequentResultRecord, None], + descendants: dict[SubsequentResultRecord, None] | None = None, + ) -> dict[SubsequentResultRecord, None]: """Get the descendants of the given children.""" if descendants is None: descendants = {} @@ -815,6 +953,24 @@ def _get_descendants( self._get_descendants(child.children, descendants) return descendants + def _nulls_child_subsequent_result_record( + self, + subsequent_result_record: SubsequentResultRecord, + null_path: list[str | int], + ) -> bool: + """Check whether the given subsequent result record is nulled.""" + incremental_data_records: ( + list[SubsequentResultRecord] | dict[DeferredGroupedFieldSetRecord, None] + ) = ( + [subsequent_result_record] + if isinstance(subsequent_result_record, StreamItemsRecord) + else subsequent_result_record.deferred_grouped_field_set_records + ) + return any( + self._matches_path(incremental_data_record.path, null_path) + for incremental_data_record in incremental_data_records + ) + def _matches_path( self, test_path: list[str | int], base_path: list[str | int] ) -> bool: @@ -829,79 +985,139 @@ def _add_task(self, awaitable: Awaitable[Any]) -> None: task.add_done_callback(tasks.discard) -class InitialResultRecord(NamedTuple): - """Formatted subsequent incremental execution result""" +class InitialResultRecord: + """Initial result record""" errors: list[GraphQLError] - children: dict[SubsequentDataRecord, None] + children: dict[SubsequentResultRecord, None] + def __init__(self) -> None: + self.errors = [] + self.children = {} -class DeferredFragmentRecord: - """A record collecting data marked with the defer directive""" - errors: list[GraphQLError] - label: str | None +class DeferredGroupedFieldSetRecord: + """Deferred grouped field set record""" + path: list[str | int] + deferred_fragment_records: list[DeferredFragmentRecord] + grouped_field_set: GroupedFieldSet + should_initiate_defer: bool + errors: list[GraphQLError] data: dict[str, Any] | None - children: dict[SubsequentDataRecord, None] - is_completed: bool - filtered: bool + sent: bool - def __init__(self, label: str | None, path: Path | None) -> None: - self.label = label + def __init__( + self, + deferred_fragment_records: list[DeferredFragmentRecord], + grouped_field_set: GroupedFieldSet, + should_initiate_defer: bool, + path: Path | None = None, + ) -> None: self.path = path.as_list() if path else [] + self.deferred_fragment_records = deferred_fragment_records + self.grouped_field_set = grouped_field_set + self.should_initiate_defer = should_initiate_defer self.errors = [] + self.sent = False + + def __repr__(self) -> str: + name = self.__class__.__name__ + args: list[str] = [ + f"deferred_fragment_records={self.deferred_fragment_records!r}", + f"grouped_field_set={self.grouped_field_set!r}", + ] + if self.path: + args.append(f"path={self.path!r}") + return f"{name}({', '.join(args)})" + + +class DeferredFragmentRecord: + """Deferred fragment record""" + + path: list[str | int] + label: str | None + children: dict[SubsequentResultRecord, None] + deferred_grouped_field_set_records: dict[DeferredGroupedFieldSetRecord, None] + errors: list[GraphQLError] + filtered: bool + _pending: dict[DeferredGroupedFieldSetRecord, None] + + def __init__(self, path: Path | None = None, label: str | None = None) -> None: + self.path = path.as_list() if path else [] + self.label = label self.children = {} - self.is_completed = self.filtered = False - self.data = None + self.filtered = False + self.deferred_grouped_field_set_records = {} + self.errors = [] + self._pending = {} def __repr__(self) -> str: name = self.__class__.__name__ - args: list[str] = [f"path={self.path!r}"] + args: list[str] = [] + if self.path: + args.append(f"path={self.path!r}") if self.label: args.append(f"label={self.label!r}") - if self.data is not None: - args.append("data") return f"{name}({', '.join(args)})" +class StreamRecord: + """Stream record""" + + label: str | None + path: list[str | int] + errors: list[GraphQLError] + early_return: Callable[[], Awaitable[Any]] | None + + def __init__( + self, + path: Path, + label: str | None = None, + early_return: Callable[[], Awaitable[Any]] | None = None, + ) -> None: + self.path = path.as_list() + self.label = label + self.errors = [] + self.early_return = early_return + + class StreamItemsRecord: - """A record collecting items marked with the stream directive""" + """Stream items record""" errors: list[GraphQLError] - label: str | None + stream_record: StreamRecord path: list[str | int] - items: list[str] | None - children: dict[SubsequentDataRecord, None] - async_iterator: AsyncIterator[Any] | None + items: list[str] + children: dict[SubsequentResultRecord, None] + is_final_record: bool is_completed_async_iterator: bool is_completed: bool filtered: bool def __init__( self, - label: str | None, - path: Path | None, - async_iterator: AsyncIterator[Any] | None = None, + stream_record: StreamRecord, + path: Path | None = None, ) -> None: - self.label = label + self.stream_record = stream_record self.path = path.as_list() if path else [] - self.async_iterator = async_iterator - self.errors = [] self.children = {} - self.is_completed_async_iterator = self.is_completed = self.filtered = False - self.items = None + self.errors = [] + self.is_completed_async_iterator = self.is_completed = False + self.is_final_record = self.filtered = False + self.items = [] def __repr__(self) -> str: name = self.__class__.__name__ - args: list[str] = [f"path={self.path!r}"] - if self.label: - args.append(f"label={self.label!r}") - if self.items is not None: - args.append("items") + args: list[str] = [f"stream_record={self.stream_record!r}"] + if self.path: + args.append(f"path={self.path!r}") return f"{name}({', '.join(args)})" -SubsequentDataRecord = Union[DeferredFragmentRecord, StreamItemsRecord] +IncrementalDataRecord = Union[ + InitialResultRecord, DeferredGroupedFieldSetRecord, StreamItemsRecord +] -IncrementalDataRecord = Union[InitialResultRecord, SubsequentDataRecord] +SubsequentResultRecord = Union[DeferredFragmentRecord, StreamItemsRecord] diff --git a/tests/benchmarks/test_visit.py b/tests/benchmarks/test_visit.py index 583075bf..4e7a85a2 100644 --- a/tests/benchmarks/test_visit.py +++ b/tests/benchmarks/test_visit.py @@ -23,5 +23,5 @@ def test_visit_all_ast_nodes(benchmark, big_schema_sdl): # noqa: F811 def test_visit_all_ast_nodes_in_parallel(benchmark, big_schema_sdl): # noqa: F811 document_ast = parse(big_schema_sdl) visitor = DummyVisitor() - parallel_visitor = ParallelVisitor([visitor] * 20) + parallel_visitor = ParallelVisitor([visitor] * 25) benchmark(lambda: visit(document_ast, parallel_visitor))