Skip to content

Commit

Permalink
polish: do not repeat is_awaitable check
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Feb 15, 2024
1 parent 7987576 commit 5feaeeb
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 70 deletions.
152 changes: 82 additions & 70 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,11 +1359,11 @@ def complete_list_item_value(
Returns True if the value is awaitable.
"""
is_awaitable = self.is_awaitable
try:
if is_awaitable(item):
completed_item: Any

async def await_completed() -> Any:
if is_awaitable(item):
# noinspection PyShadowingNames
async def await_completed() -> Any:
try:
completed = self.complete_value(
item_type,
field_nodes,
Expand All @@ -1373,21 +1373,28 @@ async def await_completed() -> Any:
async_payload_record,
)
return await completed if is_awaitable(completed) else completed
except Exception as raw_error:
error = located_error(raw_error, field_nodes, item_path.as_list())
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
return None

completed_item = await_completed()
else:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)
complete_results.append(await_completed())
return True

try:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)

if is_awaitable(completed_item):
# noinspection PyShadowingNames
async def catch_error() -> Any:
async def await_completed() -> Any:
try:
return await completed_item
except Exception as raw_error:
Expand All @@ -1398,7 +1405,7 @@ async def catch_error() -> Any:
self.filter_subsequent_payloads(item_path, async_payload_record)
return None

complete_results.append(catch_error())
complete_results.append(await_completed())
return True

complete_results.append(completed_item)
Expand Down Expand Up @@ -1728,15 +1735,17 @@ def execute_stream_field(
parent_context: Optional[AsyncPayloadRecord] = None,
) -> AsyncPayloadRecord:
"""Execute stream field."""
is_awaitable = self.is_awaitable
async_payload_record = StreamRecord(
label, item_path, None, parent_context, self
)
completed_item: Any
try:
try:
if self.is_awaitable(item):

async def await_completed_item() -> Any:
if is_awaitable(item):
# noinspection PyShadowingNames
async def await_completed_items() -> Optional[List[Any]]:
try:
try:
completed = self.complete_value(
item_type,
field_nodes,
Expand All @@ -1745,76 +1754,79 @@ async def await_completed_item() -> Any:
await item,
async_payload_record,
)
return (
return [
await completed
if self.is_awaitable(completed)
else completed
]
except Exception as raw_error:
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(
error, item_type, async_payload_record.errors
)
self.filter_subsequent_payloads(item_path, async_payload_record)
return [None]
except GraphQLError as error:
async_payload_record.errors.append(error)
self.filter_subsequent_payloads(path, async_payload_record)
return None

completed_item = await_completed_item()
async_payload_record.add_items(await_completed_items())
return async_payload_record

else:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)
try:
try:
completed_item = self.complete_value(
item_type,
field_nodes,
info,
item_path,
item,
async_payload_record,
)

if self.is_awaitable(completed_item):
completed_items: Any

async def await_completed_item() -> Any:
if is_awaitable(completed_item):
# noinspection PyShadowingNames
async def await_completed_items() -> Optional[List[Any]]:
# noinspection PyShadowingNames
try:
return await completed_item
except Exception as raw_error:
# noinspection PyShadowingNames
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(
error, item_type, async_payload_record.errors
)
self.filter_subsequent_payloads(
item_path, async_payload_record
)
try:
return [await completed_item]
except Exception as raw_error: # pragma: no cover
# noinspection PyShadowingNames
error = located_error(
raw_error, field_nodes, item_path.as_list()
)
handle_field_error(
error, item_type, async_payload_record.errors
)
self.filter_subsequent_payloads(
item_path, async_payload_record
)
return [None]
except GraphQLError as error: # pragma: no cover
async_payload_record.errors.append(error)
self.filter_subsequent_payloads(path, async_payload_record)
return None

complete_item = await_completed_item()

completed_items = await_completed_items()
else:
complete_item = completed_item
completed_items = [completed_item]

except Exception as raw_error:
error = located_error(raw_error, field_nodes, item_path.as_list())
handle_field_error(error, item_type, async_payload_record.errors)
self.filter_subsequent_payloads( # pragma: no cover
item_path, async_payload_record
)
complete_item = None # pragma: no cover
self.filter_subsequent_payloads(item_path, async_payload_record)
completed_items = [None]

except GraphQLError as error:
async_payload_record.errors.append(error)
self.filter_subsequent_payloads(item_path, async_payload_record)
async_payload_record.add_items(None)
return async_payload_record

completed_items: AwaitableOrValue[Optional[List[Any]]]
if self.is_awaitable(complete_item):

async def await_completed_items() -> Optional[List[Any]]:
# noinspection PyShadowingNames
try:
return [await complete_item] # type: ignore
except GraphQLError as error:
async_payload_record.errors.append(error)
self.filter_subsequent_payloads(path, async_payload_record)
return None

completed_items = await_completed_items()
else:
completed_items = [complete_item]
completed_items = None

async_payload_record.add_items(completed_items)
return async_payload_record
Expand Down
2 changes: 2 additions & 0 deletions tests/execution/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ async def asyncReturnErrorWithExtensions(self, _info):
],
)

@pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
def handles_sync_errors_combined_with_async_ones():
is_async_resolver_finished = False

Expand Down Expand Up @@ -560,6 +561,7 @@ async def async_resolver(_obj, _info):
],
)

@pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
def full_response_path_is_included_for_non_nullable_fields():
def resolve_ok(*_args):
return {}
Expand Down
49 changes: 49 additions & 0 deletions tests/execution/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,54 @@ async def await_friend(f):
},
]

@pytest.mark.asyncio()
async def can_stream_a_field_that_returns_a_list_with_nested_async_fields():
document = parse(
"""
query {
friendList @stream(initialCount: 2) {
name
id
}
}
"""
)

async def get_name(f):
return f.name

async def get_id(f):
return f.id

result = await complete(
document,
{
"friendList": lambda _info: [
{"name": get_name(f), "id": get_id(f)} for f in friends
]
},
)
assert result == [
{
"data": {
"friendList": [
{"name": "Luke", "id": "1"},
{"name": "Han", "id": "2"},
]
},
"hasNext": True,
},
{
"incremental": [
{
"items": [{"name": "Leia", "id": "3"}],
"path": ["friendList", 2],
}
],
"hasNext": False,
},
]

@pytest.mark.asyncio()
async def handles_error_in_list_of_awaitables_before_initial_count_reached():
document = parse(
Expand Down Expand Up @@ -1292,6 +1340,7 @@ async def friend_list(_info):
}

@pytest.mark.asyncio()
@pytest.mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
async def does_not_filter_payloads_when_null_error_is_in_a_different_path():
document = parse(
"""
Expand Down

0 comments on commit 5feaeeb

Please sign in to comment.