diff --git a/rest_typed_views/decorators.py b/rest_typed_views/decorators.py index 4d9cd7d..77b5e02 100644 --- a/rest_typed_views/decorators.py +++ b/rest_typed_views/decorators.py @@ -1,7 +1,7 @@ import inspect -from functools import wraps from typing import Any, Dict, List +from rest_framework.views import APIView from rest_framework.decorators import action, api_view from rest_framework.exceptions import ValidationError from rest_framework.fields import empty @@ -20,6 +20,32 @@ from .params import BodyParam, CurrentUserParam, PassThruParam, PathParam, QueryParam, HeaderParam + +def wraps_drf(view): + def _wraps_drf(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper.__name__ = view.__name__ + wrapper.__module__ = view.__module__ + wrapper.renderer_classes = getattr( + view, "renderer_classes", APIView.renderer_classes + ) + wrapper.parser_classes = getattr(view, "parser_classes", APIView.parser_classes) + wrapper.authentication_classes = getattr( + view, "authentication_classes", APIView.authentication_classes + ) + wrapper.throttle_classes = getattr( + view, "throttle_classes", APIView.throttle_classes + ) + wrapper.permission_classes = getattr( + view, "permission_classes", APIView.permission_classes + ) + return wrapper + + return _wraps_drf + + def build_explicit_param( param: inspect.Parameter, request: Request, settings: ParamSettings, path_args: dict ): @@ -111,6 +137,7 @@ def wrap_validate_and_render(view): prevalidate(view) @api_view(methods) + @wraps_drf(view) def wrapper(*original_args, **original_kwargs): original_args = list(original_args) request = find_request(original_args) @@ -129,7 +156,7 @@ def wrap_validate_and_render(view): prevalidate(view, for_method=True) @action(**action_kwargs) - @wraps(view) + @wraps_drf(view) def wrapper(*original_args, **original_kwargs): original_args = list(original_args) request = find_request(original_args) diff --git a/test_project/testapp/views.py b/test_project/testapp/views.py index 5f98b8a..6a83916 100644 --- a/test_project/testapp/views.py +++ b/test_project/testapp/views.py @@ -1,7 +1,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import Enum -from typing import List +from typing import List, Optional import marshmallow import typesystem diff --git a/test_project/urls.py b/test_project/urls.py index 10dc9f1..b71887b 100644 --- a/test_project/urls.py +++ b/test_project/urls.py @@ -18,6 +18,7 @@ urlpatterns = [ url(r"^logs/(?P[0-9])/", get_logs, name="get-log-entry"), url(r"^users/", create_user, name="create-user"), + url(r"^test/", test_view, name="test-view"), url(r"^bookings/", create_booking, name="create-booking"), url(r"^test/", test_view, name="test-view"), url(r"^band-members/", create_band_member, name="create-band-member"),