diff --git a/flask_graphql/graphqlview.py b/flask_graphql/graphqlview.py index 40f074b..1cfc317 100644 --- a/flask_graphql/graphqlview.py +++ b/flask_graphql/graphqlview.py @@ -33,6 +33,8 @@ class GraphQLView(View): graphiql_template = None middleware = None batch = False + json_encoder = None + json_decoder = None methods = ['GET', 'POST', 'PUT', 'DELETE'] @@ -142,10 +144,10 @@ def get_response(self, request, data, show_graphiql=False): def json_encode(self, request, d, show_graphiql=False): pretty = self.pretty or show_graphiql or request.args.get('pretty') if not pretty: - return json.dumps(d, separators=(',', ':')) + return json.dumps(d, separators=(',', ':'), cls=self.json_encoder) - return json.dumps(d, sort_keys=True, - indent=2, separators=(',', ': ')) + return json.dumps(d, sort_keys=True, indent=2, + separators=(',', ': '), cls=self.json_encoder) # noinspection PyBroadException def parse_body(self, request): @@ -155,7 +157,7 @@ def parse_body(self, request): elif content_type == 'application/json': try: - request_json = json.loads(request.data.decode('utf8')) + request_json = json.loads(request.data.decode('utf8'), cls=self.json_decoder) if self.batch: assert isinstance(request_json, list) else: diff --git a/tests/encoder.py b/tests/encoder.py new file mode 100644 index 0000000..57b0854 --- /dev/null +++ b/tests/encoder.py @@ -0,0 +1,12 @@ +from json import JSONEncoder, JSONDecoder +from json.decoder import WHITESPACE + + +class TestJSONEncoder(JSONEncoder): + def encode(self, o): + return 'TESTSTRING' + + +class TestJSONDecoder(JSONDecoder): + def decode(self, s, _w=WHITESPACE.match): + return {'query': '{test}'} diff --git a/tests/test_graphqlview.py b/tests/test_graphqlview.py index 1efcfbc..013c97a 100644 --- a/tests/test_graphqlview.py +++ b/tests/test_graphqlview.py @@ -12,6 +12,7 @@ from urllib.parse import urlencode from .app import create_app +from .encoder import TestJSONEncoder, TestJSONDecoder from flask import url_for @@ -500,8 +501,8 @@ def test_batch_supports_post_json_query_with_json_variables(client): 'payload': { 'data': {'test': "Hello Dolly"} }, 'status': 200, }] - - + + @pytest.mark.parametrize('app', [create_app(batch=True)]) def test_batch_allows_post_with_operation_name(client): response = client.post( @@ -532,3 +533,23 @@ def test_batch_allows_post_with_operation_name(client): }, 'status': 200, }] + + +@pytest.mark.parametrize('app', [create_app(json_encoder=TestJSONEncoder)]) +def test_custom_encoder(client): + response = client.get(url_string(query='{test}')) + + # TestJSONEncoder just encodes everything to 'TESTSTRING' + assert response.data.decode() == 'TESTSTRING' + + +@pytest.mark.parametrize('app', [create_app(json_decoder=TestJSONDecoder)]) +def test_custom_decoder(client): + # The submitted data here of 'TEST' is clearly not valid JSON. The TestJSONDecoder will + # decode this into valid JSON with a valid gql query. + response = client.post(url_string(), data='TEST', content_type='application/json') + + assert response.status_code == 200 + assert response_json(response) == { + 'data': {'test': "Hello World"} + }