From 5feaeebd133a0c2bad1313a31d2a838658c5f8a5 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 15 Feb 2024 22:17:23 +0100 Subject: [PATCH] polish: do not repeat is_awaitable check Replicates graphql/graphql-js@7fd1ddb9eeaba378a6445543be179b35d6c1ee55 --- src/graphql/execution/execute.py | 152 +++++++++++++++++-------------- tests/execution/test_executor.py | 2 + tests/execution/test_stream.py | 49 ++++++++++ 3 files changed, 133 insertions(+), 70 deletions(-) diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 23907903..54bd0ec1 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -1359,11 +1359,11 @@ def complete_list_item_value( Returns True if the value is awaitable. """ is_awaitable = self.is_awaitable - try: - if is_awaitable(item): - completed_item: Any - async def await_completed() -> Any: + if is_awaitable(item): + # noinspection PyShadowingNames + async def await_completed() -> Any: + try: completed = self.complete_value( item_type, field_nodes, @@ -1373,21 +1373,28 @@ async def await_completed() -> Any: async_payload_record, ) return await completed if is_awaitable(completed) else completed + except Exception as raw_error: + error = located_error(raw_error, field_nodes, item_path.as_list()) + handle_field_error(error, item_type, errors) + self.filter_subsequent_payloads(item_path, async_payload_record) + return None - completed_item = await_completed() - else: - completed_item = self.complete_value( - item_type, - field_nodes, - info, - item_path, - item, - async_payload_record, - ) + complete_results.append(await_completed()) + return True + + try: + completed_item = self.complete_value( + item_type, + field_nodes, + info, + item_path, + item, + async_payload_record, + ) if is_awaitable(completed_item): # noinspection PyShadowingNames - async def catch_error() -> Any: + async def await_completed() -> Any: try: return await completed_item except Exception as raw_error: @@ -1398,7 +1405,7 @@ async def catch_error() -> Any: self.filter_subsequent_payloads(item_path, async_payload_record) return None - complete_results.append(catch_error()) + complete_results.append(await_completed()) return True complete_results.append(completed_item) @@ -1728,15 +1735,17 @@ def execute_stream_field( parent_context: Optional[AsyncPayloadRecord] = None, ) -> AsyncPayloadRecord: """Execute stream field.""" + is_awaitable = self.is_awaitable async_payload_record = StreamRecord( label, item_path, None, parent_context, self ) completed_item: Any - try: - try: - if self.is_awaitable(item): - async def await_completed_item() -> Any: + if is_awaitable(item): + # noinspection PyShadowingNames + async def await_completed_items() -> Optional[List[Any]]: + try: + try: completed = self.complete_value( item_type, field_nodes, @@ -1745,76 +1754,79 @@ async def await_completed_item() -> Any: await item, async_payload_record, ) - return ( + return [ await completed if self.is_awaitable(completed) else completed + ] + except Exception as raw_error: + error = located_error( + raw_error, field_nodes, item_path.as_list() + ) + handle_field_error( + error, item_type, async_payload_record.errors ) + self.filter_subsequent_payloads(item_path, async_payload_record) + return [None] + except GraphQLError as error: + async_payload_record.errors.append(error) + self.filter_subsequent_payloads(path, async_payload_record) + return None - completed_item = await_completed_item() + async_payload_record.add_items(await_completed_items()) + return async_payload_record - else: - completed_item = self.complete_value( - item_type, - field_nodes, - info, - item_path, - item, - async_payload_record, - ) + try: + try: + completed_item = self.complete_value( + item_type, + field_nodes, + info, + item_path, + item, + async_payload_record, + ) - if self.is_awaitable(completed_item): + completed_items: Any - async def await_completed_item() -> Any: + if is_awaitable(completed_item): + # noinspection PyShadowingNames + async def await_completed_items() -> Optional[List[Any]]: # noinspection PyShadowingNames try: - return await completed_item - except Exception as raw_error: - # noinspection PyShadowingNames - error = located_error( - raw_error, field_nodes, item_path.as_list() - ) - handle_field_error( - error, item_type, async_payload_record.errors - ) - self.filter_subsequent_payloads( - item_path, async_payload_record - ) + try: + return [await completed_item] + except Exception as raw_error: # pragma: no cover + # noinspection PyShadowingNames + error = located_error( + raw_error, field_nodes, item_path.as_list() + ) + handle_field_error( + error, item_type, async_payload_record.errors + ) + self.filter_subsequent_payloads( + item_path, async_payload_record + ) + return [None] + except GraphQLError as error: # pragma: no cover + async_payload_record.errors.append(error) + self.filter_subsequent_payloads(path, async_payload_record) return None - complete_item = await_completed_item() - + completed_items = await_completed_items() else: - complete_item = completed_item + completed_items = [completed_item] + except Exception as raw_error: error = located_error(raw_error, field_nodes, item_path.as_list()) handle_field_error(error, item_type, async_payload_record.errors) - self.filter_subsequent_payloads( # pragma: no cover - item_path, async_payload_record - ) - complete_item = None # pragma: no cover + self.filter_subsequent_payloads(item_path, async_payload_record) + completed_items = [None] except GraphQLError as error: async_payload_record.errors.append(error) self.filter_subsequent_payloads(item_path, async_payload_record) - async_payload_record.add_items(None) - return async_payload_record - - completed_items: AwaitableOrValue[Optional[List[Any]]] - if self.is_awaitable(complete_item): - - async def await_completed_items() -> Optional[List[Any]]: - # noinspection PyShadowingNames - try: - return [await complete_item] # type: ignore - except GraphQLError as error: - async_payload_record.errors.append(error) - self.filter_subsequent_payloads(path, async_payload_record) - return None - - completed_items = await_completed_items() - else: - completed_items = [complete_item] + completed_items = None async_payload_record.add_items(completed_items) return async_payload_record diff --git a/tests/execution/test_executor.py b/tests/execution/test_executor.py index be9e8965..b70ed483 100644 --- a/tests/execution/test_executor.py +++ b/tests/execution/test_executor.py @@ -514,6 +514,7 @@ async def asyncReturnErrorWithExtensions(self, _info): ], ) + @pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning") def handles_sync_errors_combined_with_async_ones(): is_async_resolver_finished = False @@ -560,6 +561,7 @@ async def async_resolver(_obj, _info): ], ) + @pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning") def full_response_path_is_included_for_non_nullable_fields(): def resolve_ok(*_args): return {} diff --git a/tests/execution/test_stream.py b/tests/execution/test_stream.py index f8bedc62..ccfd1f93 100644 --- a/tests/execution/test_stream.py +++ b/tests/execution/test_stream.py @@ -536,6 +536,54 @@ async def await_friend(f): }, ] + @pytest.mark.asyncio() + async def can_stream_a_field_that_returns_a_list_with_nested_async_fields(): + document = parse( + """ + query { + friendList @stream(initialCount: 2) { + name + id + } + } + """ + ) + + async def get_name(f): + return f.name + + async def get_id(f): + return f.id + + result = await complete( + document, + { + "friendList": lambda _info: [ + {"name": get_name(f), "id": get_id(f)} for f in friends + ] + }, + ) + assert result == [ + { + "data": { + "friendList": [ + {"name": "Luke", "id": "1"}, + {"name": "Han", "id": "2"}, + ] + }, + "hasNext": True, + }, + { + "incremental": [ + { + "items": [{"name": "Leia", "id": "3"}], + "path": ["friendList", 2], + } + ], + "hasNext": False, + }, + ] + @pytest.mark.asyncio() async def handles_error_in_list_of_awaitables_before_initial_count_reached(): document = parse( @@ -1292,6 +1340,7 @@ async def friend_list(_info): } @pytest.mark.asyncio() + @pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning") async def does_not_filter_payloads_when_null_error_is_in_a_different_path(): document = parse( """