Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: apply api middleware correctly #129

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nitric/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,9 @@ async def chained_middleware(ctx: C, nxt: Optional[Middleware[C]] = None) -> C:

return chained_middleware

middleware_chain = functools.reduce(reduce_chain, reversed(middlewares)) # type: ignore
middleware_chain = functools.reduce(reduce_chain, reversed(middlewares), last_middleware) # type: ignore
# type ignored because mypy appears to misidentify the correct return type
return await middleware_chain(ctx, last_middleware) # type: ignore
return await middleware_chain(ctx) # type: ignore

return composed

Expand Down
10 changes: 10 additions & 0 deletions nitric/resources/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def _route(self, match: str, opts: Optional[RouteOptions] = None) -> Route:
if opts is None:
opts = RouteOptions()

if self.middleware is not None:
opts.middleware = self.middleware + opts.middleware

r = Route(self, match, opts)
self.routes.append(r)
return r
Expand Down Expand Up @@ -339,6 +342,13 @@ def method(
self, methods: List[HttpMethod], *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None
) -> None:
"""Register middleware for multiple HTTP Methods."""

# ensure route/api middlewares are added
middleware = (
*self.middleware,
*middleware
)

Method(self, methods, *middleware, opts=opts if opts else MethodOptions())

def get(self, *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None) -> None:
Expand Down
20 changes: 19 additions & 1 deletion tests/resources/test_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# from nitric.faas import HttpMethod, MethodOptions, ApiWorkerOptions
from nitric.resources import api, ApiOptions, JwtSecurityDefinition
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule, HttpMiddleware
from nitric.proto.resources.v1 import (
ApiOpenIdConnectionDefinition,
ApiSecurityDefinitionResource,
Expand All @@ -40,6 +40,7 @@
from nitric.proto.apis.v1 import ApiDetailsResponse, ApiDetailsRequest, ApiWorkerScopes

from nitric.context import (
HttpContext,
HttpMethod,
)

Expand Down Expand Up @@ -221,6 +222,23 @@ def test_api_route(self):
assert test_route.middleware == []
assert test_route.api.name == test_api.name

def test_api_route_middleware(self):
mock_declare = AsyncMock()
mock_response = Object()
mock_declare.return_value = mock_response

async def middleware_test(ctx: HttpContext, nxt: HttpMiddleware):
return nxt(ctx)

with patch("nitric.proto.resources.v1.ResourcesStub.declare", mock_declare):
test_api = api("test-api-route-middleware", ApiOptions(path="/api/v2/", middleware=[middleware_test]))

test_route = test_api._route("/test")

assert len(test_api.middleware) == 1
assert len(test_route.middleware) == 1


def test_define_route(self):
mock_declare = AsyncMock()
mock_response = Object()
Expand Down
Loading