From 9f67f3f6c4375231f86cfb936f42d4ce9a6055f7 Mon Sep 17 00:00:00 2001
From: Christoph Zwerschke
Date: Thu, 22 Sep 2022 23:17:38 +0200
Subject: [PATCH] parser: limit maximum number of tokens

Replicates graphql/graphql-js@f0a0a4dadffe41dae541ab297f95997435b27c57
---
 src/graphql/language/parser.py | 55 +++++++++++++++++++++++++++-------
 tests/language/test_parser.py  | 17 +++++++++++
 2 files changed, 61 insertions(+), 11 deletions(-)

diff --git a/src/graphql/language/parser.py b/src/graphql/language/parser.py
index 389913a5..97af0746 100644
--- a/src/graphql/language/parser.py
+++ b/src/graphql/language/parser.py
@@ -73,6 +73,7 @@
 def parse(
     source: SourceType,
     no_location: bool = False,
+    max_tokens: Optional[int] = None,
     allow_legacy_fragment_variables: bool = False,
 ) -> DocumentNode:
     """Given a GraphQL source, parse it into a Document.
@@ -80,8 +81,14 @@ def parse(
     Throws GraphQLError if a syntax error is encountered.
 
     By default, the parser creates AST nodes that know the location in the source that
-    they correspond to. The ``no_location`` option disables that behavior for
-    performance or testing.
+    they correspond to. Setting the ``no_location`` parameter to True disables that
+    behavior for performance or testing.
+
+    Parser CPU and memory usage is linear in the number of tokens in a document;
+    however, in extreme cases it becomes quadratic due to memory exhaustion.
+    Parsing happens before validation, so even invalid queries can burn lots of
+    CPU time and memory. To prevent this, you can set a maximum number of tokens
+    allowed within a document using the ``max_tokens`` parameter.
 
     Legacy feature (will be removed in v3.3):
 
@@ -100,6 +107,7 @@ def parse(
     parser = Parser(
         source,
         no_location=no_location,
+        max_tokens=max_tokens,
         allow_legacy_fragment_variables=allow_legacy_fragment_variables,
     )
     return parser.parse_document()
@@ -108,6 +116,7 @@
 def parse_value(
     source: SourceType,
     no_location: bool = False,
+    max_tokens: Optional[int] = None,
     allow_legacy_fragment_variables: bool = False,
 ) -> ValueNode:
     """Parse the AST for a given string containing a GraphQL value.
@@ -123,6 +132,7 @@ def parse_value(
     parser = Parser(
         source,
         no_location=no_location,
+        max_tokens=max_tokens,
         allow_legacy_fragment_variables=allow_legacy_fragment_variables,
     )
     parser.expect_token(TokenKind.SOF)
@@ -134,6 +144,7 @@
 def parse_const_value(
     source: SourceType,
     no_location: bool = False,
+    max_tokens: Optional[int] = None,
     allow_legacy_fragment_variables: bool = False,
 ) -> ConstValueNode:
     """Parse the AST for a given string containing a GraphQL constant value.
@@ -144,6 +155,7 @@ def parse_const_value(
     parser = Parser(
         source,
         no_location=no_location,
+        max_tokens=max_tokens,
         allow_legacy_fragment_variables=allow_legacy_fragment_variables,
     )
     parser.expect_token(TokenKind.SOF)
@@ -155,6 +167,7 @@
 def parse_type(
     source: SourceType,
     no_location: bool = False,
+    max_tokens: Optional[int] = None,
     allow_legacy_fragment_variables: bool = False,
 ) -> TypeNode:
     """Parse the AST for a given string containing a GraphQL Type.
@@ -170,6 +183,7 @@ def parse_type(
     parser = Parser(
         source,
         no_location=no_location,
+        max_tokens=max_tokens,
         allow_legacy_fragment_variables=allow_legacy_fragment_variables,
     )
     parser.expect_token(TokenKind.SOF)
@@ -191,13 +205,16 @@ class Parser:
     """
 
     _lexer: Lexer
-    _no_Location: bool
+    _no_location: bool
+    _max_tokens: Optional[int]
     _allow_legacy_fragment_variables: bool
+    _token_counter: int
 
     def __init__(
         self,
         source: SourceType,
         no_location: bool = False,
+        max_tokens: Optional[int] = None,
         allow_legacy_fragment_variables: bool = False,
     ):
         source = (
@@ -206,7 +223,9 @@ def __init__(
         )
         self._lexer = Lexer(source)
         self._no_location = no_location
+        self._max_tokens = max_tokens
         self._allow_legacy_fragment_variables = allow_legacy_fragment_variables
+        self._token_counter = 0
 
     def parse_name(self) -> NameNode:
         """Convert a name lex token into a name parse node."""
@@ -477,7 +496,7 @@ def parse_value_literal(self, is_const: bool) -> ValueNode:
 
     def parse_string_literal(self, _is_const: bool = False) -> StringValueNode:
         token = self._lexer.token
-        self._lexer.advance()
+        self.advance_lexer()
         return StringValueNode(
             value=token.value,
             block=token.kind == TokenKind.BLOCK_STRING,
@@ -514,18 +533,18 @@ def parse_object(self, is_const: bool) -> ObjectValueNode:
 
     def parse_int(self, _is_const: bool = False) -> IntValueNode:
         token = self._lexer.token
-        self._lexer.advance()
+        self.advance_lexer()
         return IntValueNode(value=token.value, loc=self.loc(token))
 
     def parse_float(self, _is_const: bool = False) -> FloatValueNode:
         token = self._lexer.token
-        self._lexer.advance()
+        self.advance_lexer()
         return FloatValueNode(value=token.value, loc=self.loc(token))
 
     def parse_named_values(self, _is_const: bool = False) -> ValueNode:
         token = self._lexer.token
         value = token.value
-        self._lexer.advance()
+        self.advance_lexer()
         if value == "true":
             return BooleanValueNode(value=True, loc=self.loc(token))
         if value == "false":
@@ -1020,7 +1039,7 @@ def expect_token(self, kind: TokenKind) -> Token:
         """
         token = self._lexer.token
         if token.kind == kind:
-            self._lexer.advance()
+            self.advance_lexer()
             return token
 
         raise GraphQLSyntaxError(
@@ -1037,7 +1056,7 @@ def expect_optional_token(self, kind: TokenKind) -> bool:
         """
         token = self._lexer.token
         if token.kind == kind:
-            self._lexer.advance()
+            self.advance_lexer()
             return True
 
         return False
@@ -1050,7 +1069,7 @@ def expect_keyword(self, value: str) -> None:
         """
         token = self._lexer.token
         if token.kind == TokenKind.NAME and token.value == value:
-            self._lexer.advance()
+            self.advance_lexer()
         else:
             raise GraphQLSyntaxError(
                 self._lexer.source,
@@ -1066,7 +1085,7 @@ def expect_optional_keyword(self, value: str) -> bool:
         """
         token = self._lexer.token
         if token.kind == TokenKind.NAME and token.value == value:
-            self._lexer.advance()
+            self.advance_lexer()
             return True
 
         return False
@@ -1154,6 +1173,20 @@ def delimited_many(
                 break
         return nodes
 
+    def advance_lexer(self) -> None:
+        max_tokens = self._max_tokens
+        token = self._lexer.advance()
+
+        if max_tokens is not None and token.kind != TokenKind.EOF:
+            self._token_counter += 1
+            if self._token_counter > max_tokens:
+                raise GraphQLSyntaxError(
+                    self._lexer.source,
+                    token.start,
+                    f"Document contains more than {max_tokens} tokens."
+ " Parsing aborted.", + ) + def get_token_desc(token: Token) -> str: """Describe a token as a string for debugging.""" diff --git a/tests/language/test_parser.py b/tests/language/test_parser.py index 027a605b..650e4c65 100644 --- a/tests/language/test_parser.py +++ b/tests/language/test_parser.py @@ -91,6 +91,23 @@ def parse_provides_useful_error_when_using_source(): """ ) + def limits_maximum_number_of_tokens(): + assert parse("{ foo }", max_tokens=3) + with raises( + GraphQLSyntaxError, + match="Syntax Error: Document contains more than 2 tokens." + " Parsing aborted.", + ): + assert parse("{ foo }", max_tokens=2) + + assert parse('{ foo(bar: "baz") }', max_tokens=8) + with raises( + GraphQLSyntaxError, + match="Syntax Error: Document contains more than 7 tokens." + " Parsing aborted.", + ): + assert parse('{ foo(bar: "baz") }', max_tokens=7) + def parses_variable_inline_values(): parse("{ field(complex: { a: { b: [ $var ] } }) }")