Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent subscription execution without websocket #1165

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Added `execute_get_queries` setting to the `GraphQL` apps that controls execution of the GraphQL "query" operations made with GET requests. Defaults to `False`.
- Added support for the Apollo Federation versions up to 2.6.
- Fixed deprecation warnings in Apollo Tracing extension.
- Added a check to prevent `subscription` operation execution when query is made with POST request.


## 0.22 (2024-01-31)
Expand Down
41 changes: 41 additions & 0 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ async def graphql(

if require_query:
validate_operation_is_query(document, operation_name)
else:
validate_operation_is_not_subscription(document, operation_name)

if callable(root_value):
try:
Expand Down Expand Up @@ -358,6 +360,8 @@ def graphql_sync(

if require_query:
validate_operation_is_query(document, operation_name)
else:
validate_operation_is_not_subscription(document, operation_name)

if callable(root_value):
try:
Expand Down Expand Up @@ -680,3 +684,40 @@ def validate_operation_is_query(
raise GraphQLError(
"'operationName' is required if 'query' defines multiple operations."
)


def validate_operation_is_not_subscription(
document_ast: DocumentNode, operation_name: Optional[str]
):
if operation_name:
validate_named_operation_is_not_subscription(document_ast, operation_name)
else:
validate_anonymous_operation_is_not_subscription(document_ast)


def validate_named_operation_is_not_subscription(
document_ast: DocumentNode, operation_name: str
):
for definition in document_ast.definitions:
if (
isinstance(definition, OperationDefinitionNode)
and definition.name
and definition.name.value == operation_name
and definition.operation.name == "SUBSCRIPTION"
):
raise GraphQLError(
f"Operation '{operation_name}' is a subscription and can only be "
"executed over a WebSocket connection."
)


def validate_anonymous_operation_is_not_subscription(document_ast: DocumentNode):
operations: List[OperationDefinitionNode] = []
for definition in document_ast.definitions:
if isinstance(definition, OperationDefinitionNode):
operations.append(definition)

if len(operations) == 1 and operations[0].operation.name == "SUBSCRIPTION":
raise GraphQLError(
"Subscription operations can only be executed over a WebSocket connection."
)
18 changes: 18 additions & 0 deletions tests/asgi/__snapshots__/test_query_execution.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# serializer version: 1
# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': 'Subscription operations can only be executed over a WebSocket connection.',
}),
]),
})
# ---
# name: test_attempt_execute_complex_query_without_variables_returns_error_json
dict({
'data': None,
Expand Down Expand Up @@ -61,6 +70,15 @@
]),
})
# ---
# name: test_attempt_execute_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.",
}),
]),
})
# ---
# name: test_attempt_execute_subscription_with_invalid_query_returns_error_json
dict({
'locations': list([
Expand Down
20 changes: 20 additions & 0 deletions tests/asgi/test_query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,26 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js
assert snapshot == response.json()


def test_attempt_execute_anonymous_subscription_over_post_returns_error_json(
client, snapshot
):
response = client.post("/", json={"query": "subscription { ping }"})
assert response.status_code == 400
assert snapshot == response.json()


def test_attempt_execute_subscription_over_post_returns_error_json(client, snapshot):
response = client.post(
"/",
json={
"query": "subscription Test { ping }",
"operationName": "Test",
},
)
assert response.status_code == 400
assert snapshot == response.json()


def test_attempt_execute_subscription_with_invalid_query_returns_error_json(
client, snapshot
):
Expand Down
18 changes: 18 additions & 0 deletions tests/wsgi/__snapshots__/test_query_execution.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# serializer version: 1
# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': 'Subscription operations can only be executed over a WebSocket connection.',
}),
]),
})
# ---
# name: test_attempt_execute_complex_query_without_variables_returns_error_json
dict({
'data': None,
Expand Down Expand Up @@ -61,6 +70,15 @@
]),
})
# ---
# name: test_attempt_execute_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.",
}),
]),
})
# ---
# name: test_complex_query_is_executed_for_post_json_request
dict({
'data': dict({
Expand Down
32 changes: 32 additions & 0 deletions tests/wsgi/test_query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,38 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js
assert_json_response_equals_snapshot(result)


def test_attempt_execute_anonymous_subscription_over_post_returns_error_json(
middleware,
start_response,
graphql_query_request_factory,
graphql_response_headers,
assert_json_response_equals_snapshot,
):
request = graphql_query_request_factory(query="subscription { ping }")
result = middleware(request, start_response)
start_response.assert_called_once_with(
HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers
)
assert_json_response_equals_snapshot(result)


def test_attempt_execute_subscription_over_post_returns_error_json(
middleware,
start_response,
graphql_query_request_factory,
graphql_response_headers,
assert_json_response_equals_snapshot,
):
request = graphql_query_request_factory(
query="subscription Test { ping }", operationName="Test"
)
result = middleware(request, start_response)
start_response.assert_called_once_with(
HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers
)
assert_json_response_equals_snapshot(result)


def test_query_is_executed_for_multipart_form_request_with_file(
middleware, snapshot, start_response, graphql_response_headers
):
Expand Down
Loading