parser: limit maximum number of tokens
Cito committed Sep 22, 2022
1 parent 52daf76 commit 9f67f3f
Showing 2 changed files with 61 additions and 11 deletions.
55 changes: 44 additions & 11 deletions src/graphql/language/parser.py
@@ -73,15 +73,22 @@
def parse(
source: SourceType,
no_location: bool = False,
max_tokens: Optional[int] = None,
allow_legacy_fragment_variables: bool = False,
) -> DocumentNode:
"""Given a GraphQL source, parse it into a Document.
Throws GraphQLError if a syntax error is encountered.
By default, the parser creates AST nodes that know the location in the source that
they correspond to. The ``no_location`` option disables that behavior for
performance or testing.
they correspond to. Setting the ``no_location`` parameter to True disables that
behavior for performance or testing.
Parser CPU and memory usage is linear in the number of tokens in a document;
however, in extreme cases it becomes quadratic due to memory exhaustion.
Parsing happens before validation, so even invalid queries can burn lots of
CPU time and memory. To prevent this, you can set a maximum number of tokens
allowed within a document using the ``max_tokens`` parameter.
Legacy feature (will be removed in v3.3):
@@ -100,6 +107,7 @@ def parse(
parser = Parser(
source,
no_location=no_location,
max_tokens=max_tokens,
allow_legacy_fragment_variables=allow_legacy_fragment_variables,
)
return parser.parse_document()
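
A minimal usage sketch of the new parameter (not part of the diff; the wrapper name and the limit value are illustrative only):

    from graphql import GraphQLSyntaxError, parse

    MAX_TOKENS = 1000  # tune to the largest query your clients legitimately send

    def parse_untrusted(query: str):
        """Parse a query from an untrusted source with bounded parser work."""
        try:
            return parse(query, max_tokens=MAX_TOKENS)
        except GraphQLSyntaxError as error:
            # Raised for ordinary syntax errors as well as for documents
            # that exceed the token limit ("... Parsing aborted.").
            raise ValueError(f"Rejected query: {error.message}") from None
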
@@ -108,6 +116,7 @@ def parse(
def parse_value(
source: SourceType,
no_location: bool = False,
max_tokens: Optional[int] = None,
allow_legacy_fragment_variables: bool = False,
) -> ValueNode:
"""Parse the AST for a given string containing a GraphQL value.
@@ -123,6 +132,7 @@ def parse_value(
parser = Parser(
source,
no_location=no_location,
max_tokens=max_tokens,
allow_legacy_fragment_variables=allow_legacy_fragment_variables,
)
parser.expect_token(TokenKind.SOF)
@@ -134,6 +144,7 @@ def parse_value(
def parse_const_value(
source: SourceType,
no_location: bool = False,
max_tokens: Optional[int] = None,
allow_legacy_fragment_variables: bool = False,
) -> ConstValueNode:
"""Parse the AST for a given string containing a GraphQL constant value.
@@ -144,6 +155,7 @@ def parse_const_value(
parser = Parser(
source,
no_location=no_location,
max_tokens=max_tokens,
allow_legacy_fragment_variables=allow_legacy_fragment_variables,
)
parser.expect_token(TokenKind.SOF)
@@ -155,6 +167,7 @@ def parse_const_value(
def parse_type(
source: SourceType,
no_location: bool = False,
max_tokens: Optional[int] = None,
allow_legacy_fragment_variables: bool = False,
) -> TypeNode:
"""Parse the AST for a given string containing a GraphQL Type.
@@ -170,6 +183,7 @@ def parse_type(
parser = Parser(
source,
no_location=no_location,
max_tokens=max_tokens,
allow_legacy_fragment_variables=allow_legacy_fragment_variables,
)
parser.expect_token(TokenKind.SOF)
@@ -191,13 +205,16 @@ class Parser:
"""

_lexer: Lexer
_no_Location: bool
_no_location: bool
_max_tokens: Optional[int]
_allow_legacy_fragment_variables: bool
_token_counter: int

def __init__(
self,
source: SourceType,
no_location: bool = False,
max_tokens: Optional[int] = None,
allow_legacy_fragment_variables: bool = False,
):
source = (
@@ -206,7 +223,9 @@ def __init__(

self._lexer = Lexer(source)
self._no_location = no_location
self._max_tokens = max_tokens
self._allow_legacy_fragment_variables = allow_legacy_fragment_variables
self._token_counter = 0

def parse_name(self) -> NameNode:
"""Convert a name lex token into a name parse node."""
@@ -477,7 +496,7 @@ def parse_value_literal(self, is_const: bool) -> ValueNode:

def parse_string_literal(self, _is_const: bool = False) -> StringValueNode:
token = self._lexer.token
self._lexer.advance()
self.advance_lexer()
return StringValueNode(
value=token.value,
block=token.kind == TokenKind.BLOCK_STRING,
@@ -514,18 +533,18 @@ def parse_object(self, is_const: bool) -> ObjectValueNode:

def parse_int(self, _is_const: bool = False) -> IntValueNode:
token = self._lexer.token
self._lexer.advance()
self.advance_lexer()
return IntValueNode(value=token.value, loc=self.loc(token))

def parse_float(self, _is_const: bool = False) -> FloatValueNode:
token = self._lexer.token
self._lexer.advance()
self.advance_lexer()
return FloatValueNode(value=token.value, loc=self.loc(token))

def parse_named_values(self, _is_const: bool = False) -> ValueNode:
token = self._lexer.token
value = token.value
self._lexer.advance()
self.advance_lexer()
if value == "true":
return BooleanValueNode(value=True, loc=self.loc(token))
if value == "false":
@@ -1020,7 +1039,7 @@ def expect_token(self, kind: TokenKind) -> Token:
"""
token = self._lexer.token
if token.kind == kind:
self._lexer.advance()
self.advance_lexer()
return token

raise GraphQLSyntaxError(
@@ -1037,7 +1056,7 @@ def expect_optional_token(self, kind: TokenKind) -> bool:
"""
token = self._lexer.token
if token.kind == kind:
self._lexer.advance()
self.advance_lexer()
return True

return False
@@ -1050,7 +1069,7 @@ def expect_keyword(self, value: str) -> None:
"""
token = self._lexer.token
if token.kind == TokenKind.NAME and token.value == value:
self._lexer.advance()
self.advance_lexer()
else:
raise GraphQLSyntaxError(
self._lexer.source,
@@ -1066,7 +1085,7 @@ def expect_optional_keyword(self, value: str) -> bool:
"""
token = self._lexer.token
if token.kind == TokenKind.NAME and token.value == value:
self._lexer.advance()
self.advance_lexer()
return True

return False
@@ -1154,6 +1173,20 @@ def delimited_many(
break
return nodes

def advance_lexer(self) -> None:
max_tokens = self._max_tokens
token = self._lexer.advance()

if max_tokens is not None and token.kind != TokenKind.EOF:
self._token_counter += 1
if self._token_counter > max_tokens:
raise GraphQLSyntaxError(
self._lexer.source,
token.start,
f"Document contains more than {max_tokens} tokens."
" Parsing aborted.",
)


def get_token_desc(token: Token) -> str:
"""Describe a token as a string for debugging."""
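
To see the new guard in action (a sketch against the public API; the query and limit are chosen only for illustration):

    from graphql import GraphQLSyntaxError, parse

    try:
        parse("{ a b c d e }", max_tokens=5)  # this document lexes to 7 tokens
    except GraphQLSyntaxError as error:
        print(error.message)
        # Syntax Error: Document contains more than 5 tokens. Parsing aborted.

Note that advance_lexer() does not count the EOF token, so the limit applies only to the document's real tokens.
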
17 changes: 17 additions & 0 deletions tests/language/test_parser.py
@@ -91,6 +91,23 @@ def parse_provides_useful_error_when_using_source():
"""
)

def limits_maximum_number_of_tokens():
assert parse("{ foo }", max_tokens=3)
with raises(
GraphQLSyntaxError,
match="Syntax Error: Document contains more than 2 tokens."
" Parsing aborted.",
):
assert parse("{ foo }", max_tokens=2)

assert parse('{ foo(bar: "baz") }', max_tokens=8)
with raises(
GraphQLSyntaxError,
match="Syntax Error: Document contains more than 7 tokens."
" Parsing aborted.",
):
assert parse('{ foo(bar: "baz") }', max_tokens=7)

def parses_variable_inline_values():
parse("{ field(complex: { a: { b: [ $var ] } }) }")

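The limits in the new test line up with the lexer's token stream. A quick way to count the tokens the parser would consume (a sketch using only the library's public Lexer API; SOF and EOF are excluded, matching advance_lexer() above):

    from graphql.language import Lexer, Source, TokenKind

    def count_tokens(document: str) -> int:
        lexer = Lexer(Source(document))
        count = 0
        while lexer.advance().kind != TokenKind.EOF:
            count += 1
        return count

    print(count_tokens("{ foo }"))              # 3: "{", "foo", "}"
    print(count_tokens('{ foo(bar: "baz") }'))  # 8: adds "(", "bar", ":", '"baz"', ")"

Hence max_tokens=3 and max_tokens=8 are exactly enough, while 2 and 7 trip the limit.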
