Skip to content

Commit

Permalink
refactor: introduce complete_list_item_value
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Feb 15, 2024
1 parent 3d3393f commit c698ab5
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 100 deletions.
193 changes: 93 additions & 100 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,11 +1177,10 @@ async def complete_async_iterator_value(
"""
errors = async_payload_record.errors if async_payload_record else self.errors
stream = self.get_stream_values(field_nodes, path)
is_awaitable = self.is_awaitable
complete_list_item_value = self.complete_list_item_value
awaitable_indices: List[int] = []
append_awaitable = awaitable_indices.append
completed_results: List[Any] = []
append_result = completed_results.append
index = 0
while True:
if (
Expand Down Expand Up @@ -1213,46 +1212,23 @@ async def complete_async_iterator_value(
value = await anext(iterator)
except StopAsyncIteration:
break
try:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
value,
async_payload_record,
)
if is_awaitable(completed_item):
# noinspection PyShadowingNames
async def catch_error(
completed_item: Awaitable[Any], item_path: Path
) -> Any:
try:
return await completed_item
except Exception as raw_error:
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
self.filter_subsequent_payloads(
item_path, async_payload_record
)
handle_field_error(error, item_type, errors)
return None

append_result(catch_error(completed_item, item_path))
append_awaitable(index)
else:
append_result(completed_item)
except Exception as raw_error:
append_result(None)
error = located_error(raw_error, field_nodes, item_path.as_list())
self.filter_subsequent_payloads(item_path, async_payload_record)
handle_field_error(error, item_type, errors)
except Exception as raw_error:
append_result(None)
error = located_error(raw_error, field_nodes, item_path.as_list())
handle_field_error(error, item_type, errors)
completed_results.append(None)
break
if complete_list_item_value(
value,
completed_results,
errors,
item_type,
field_nodes,
info,
item_path,
async_payload_record,
):
append_awaitable(index)

index += 1

if not awaitable_indices:
Expand Down Expand Up @@ -1307,12 +1283,11 @@ def complete_list_value(
# This is specified as a simple map, however we're optimizing the path where
# the list contains no coroutine objects by avoiding creating another coroutine
# object.
is_awaitable = self.is_awaitable
complete_list_item_value = self.complete_list_item_value
awaitable_indices: List[int] = []
append_awaitable = awaitable_indices.append
previous_async_payload_record = async_payload_record
completed_results: List[Any] = []
append_result = completed_results.append
for index, item in enumerate(result):
# No need to modify the info object containing the path, since from here on
# it is not ever accessed by resolver functions.
Expand All @@ -1335,67 +1310,17 @@ def complete_list_value(
)
continue

completed_item: AwaitableOrValue[Any]

if is_awaitable(item):
# noinspection PyShadowingNames
async def await_completed(item: Any, item_path: Path) -> Any:
try:
completed = self.complete_value(
item_type,
field_nodes,
info,
item_path,
await item,
async_payload_record,
)
if is_awaitable(completed):
return await completed
except Exception as raw_error:
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
return None
return completed

completed_item = await_completed(item, item_path)
else:
try:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)
if is_awaitable(completed_item):
# noinspection PyShadowingNames
async def await_completed(item: Any, item_path: Path) -> Any:
try:
return await item
except Exception as raw_error:
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(
item_path, async_payload_record
)
return None

completed_item = await_completed(completed_item, item_path)
except Exception as raw_error:
error = located_error(raw_error, field_nodes, item_path.as_list())
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
completed_item = None

if is_awaitable(completed_item):
if complete_list_item_value(
item,
completed_results,
errors,
item_type,
field_nodes,
info,
item_path,
async_payload_record,
):
append_awaitable(index)
append_result(completed_item)

if not awaitable_indices:
return completed_results
Expand All @@ -1418,6 +1343,74 @@ async def get_completed_results() -> List[Any]:

return get_completed_results()

def complete_list_item_value(
self,
item: Any,
complete_results: List[Any],
errors: List[GraphQLError],
item_type: GraphQLOutputType,
field_nodes: List[FieldNode],
info: GraphQLResolveInfo,
item_path: Path,
async_payload_record: Optional[AsyncPayloadRecord],
) -> bool:
"""Complete a list item value by adding it to the completed results.
Returns True if the value is awaitable.
"""
is_awaitable = self.is_awaitable
try:
if is_awaitable(item):
completed_item: Any

async def await_completed() -> Any:
completed = self.complete_value(
item_type,
field_nodes,
info,
item_path,
await item,
async_payload_record,
)
return await completed if is_awaitable(completed) else completed

completed_item = await_completed()
else:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)

if is_awaitable(completed_item):
# noinspection PyShadowingNames
async def catch_error() -> Any:
try:
return await completed_item
except Exception as raw_error:
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
return None

complete_results.append(catch_error())
return True

complete_results.append(completed_item)

except Exception as raw_error:
error = located_error(raw_error, field_nodes, item_path.as_list())
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
complete_results.append(None)

return False

@staticmethod
def complete_leaf_value(return_type: GraphQLLeafType, result: Any) -> Any:
"""Complete a leaf value.
Expand Down
1 change: 1 addition & 0 deletions tests/execution/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,7 @@ async def friend_list(_info):
}

@pytest.mark.asyncio()
@pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
async def does_not_filter_payloads_when_null_error_is_in_a_different_path():
document = parse(
"""
Expand Down

0 comments on commit c698ab5

Please sign in to comment.