Skip to content

Commit

Permalink
Validation: Allow to limit maximum number of validation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Sep 14, 2019
1 parent 712ef84 commit 44ac1d9
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The current version 3.0.0a2 of GraphQL-core is up-to-date
with GraphQL.js version 14.4.2.

All parts of the API are covered by an extensive test suite
of currently 1979 unit tests.
of currently 1981 unit tests.


## Documentation
Expand Down
32 changes: 29 additions & 3 deletions src/graphql/validation/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
__all__ = ["assert_valid_sdl", "assert_valid_sdl_extension", "validate", "validate_sdl"]


class ValidationAbortedError(RuntimeError):
"""Error when a validation has been aborted (error limit reached)."""


def validate(
schema: GraphQLSchema,
document_ast: DocumentNode,
rules: Sequence[RuleType] = None,
type_info: TypeInfo = None,
max_errors: int = None,
) -> List[GraphQLError]:
"""Implements the "Validation" section of the spec.
Expand Down Expand Up @@ -45,13 +50,34 @@ def validate(
rules = specified_rules
elif not isinstance(rules, (list, tuple)):
raise TypeError("Rules must be passed as a list/tuple.")
context = ValidationContext(schema, document_ast, type_info)
if max_errors is not None and not isinstance(max_errors, int):
raise TypeError("The maximum number of errors must be passed as an int.")

errors: List[GraphQLError] = []

def on_error(error: GraphQLError) -> None:
if max_errors is not None and len(errors) >= max_errors:
errors.append(
GraphQLError(
"Too many validation errors, error limit reached."
" Validation aborted."
)
)
raise ValidationAbortedError
errors.append(error)

context = ValidationContext(schema, document_ast, type_info, on_error)

# This uses a specialized visitor which runs multiple visitors in parallel,
# while maintaining the visitor skip and break API.
visitors = [rule(context) for rule in rules]

# Visit the whole document with each instance of all provided rules.
visit(document_ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors)))
return context.errors
try:
visit(document_ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors)))
except ValidationAbortedError:
pass
return errors


def validate_sdl(
Expand Down
27 changes: 21 additions & 6 deletions src/graphql/validation/validation_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, NamedTuple, Optional, Set, Union, cast
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast

from ..error import GraphQLError
from ..language import (
Expand Down Expand Up @@ -62,10 +62,14 @@ class ASTValidationContext:
"""

document: DocumentNode
on_error: Optional[Callable[[GraphQLError], None]]
errors: List[GraphQLError]

def __init__(self, ast: DocumentNode) -> None:
def __init__(
self, ast: DocumentNode, on_error: Callable[[GraphQLError], None] = None
) -> None:
self.document = ast
self.on_error = on_error
self.errors = []
self._fragments: Optional[Dict[str, FragmentDefinitionNode]] = None
self._fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]] = {}
Expand All @@ -75,6 +79,8 @@ def __init__(self, ast: DocumentNode) -> None:

def report_error(self, error: GraphQLError):
self.errors.append(error)
if self.on_error:
self.on_error(error)

def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
fragments = self._fragments
Expand Down Expand Up @@ -146,8 +152,13 @@ class SDLValidationContext(ASTValidationContext):

schema: Optional[GraphQLSchema]

def __init__(self, ast: DocumentNode, schema: GraphQLSchema = None) -> None:
super().__init__(ast)
def __init__(
self,
ast: DocumentNode,
schema: GraphQLSchema = None,
on_error: Callable[[GraphQLError], None] = None,
) -> None:
super().__init__(ast, on_error)
self.schema = schema


Expand All @@ -162,9 +173,13 @@ class ValidationContext(ASTValidationContext):
schema: GraphQLSchema

def __init__(
self, schema: GraphQLSchema, ast: DocumentNode, type_info: TypeInfo
self,
schema: GraphQLSchema,
ast: DocumentNode,
type_info: TypeInfo,
on_error: Callable[[GraphQLError], None] = None,
) -> None:
super().__init__(ast)
super().__init__(ast, on_error)
self.schema = schema
self._type_info = type_info
self._variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]] = {}
Expand Down
39 changes: 39 additions & 0 deletions tests/validation/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,42 @@ def validates_using_a_custom_type_info():
"Cannot query field 'isHousetrained' on type 'Dog'."
" Did you mean 'isHousetrained'?",
]


def describe_validate_limit_maximum_number_of_validation_errors():
query = """
{
firstUnknownField
secondUnknownField
thirdUnknownField
}
"""
doc = parse(query, no_location=True)

def _validate_document(max_errors=None):
return validate(test_schema, doc, max_errors=max_errors)

def _invalid_field_error(field_name: str):
return {
"message": f"Cannot query field '{field_name}' on type 'QueryRoot'.",
"locations": [],
}

def when_max_errors_is_equal_to_number_of_errors():
errors = _validate_document(max_errors=3)
assert errors == [
_invalid_field_error("firstUnknownField"),
_invalid_field_error("secondUnknownField"),
_invalid_field_error("thirdUnknownField"),
]

def when_max_errors_is_less_than_number_of_errors():
errors = _validate_document(max_errors=2)
assert errors == [
_invalid_field_error("firstUnknownField"),
_invalid_field_error("secondUnknownField"),
{
"message": "Too many validation errors, error limit reached."
" Validation aborted."
},
]

0 comments on commit 44ac1d9

Please sign in to comment.