diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index 3de88aa63..9122595ee 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -1,4 +1,7 @@ from datetime import timedelta +from slowapi import Limiter +from slowapi.util import get_remote_address + from fastapi import (BackgroundTasks, Depends, status, APIRouter, Response, Request) @@ -27,9 +30,12 @@ auth = APIRouter(prefix="/auth", tags=["Authentication"]) +# Initialize rate limiter +limiter = Limiter(key_func=get_remote_address) @auth.post("/register", status_code=status.HTTP_201_CREATED, response_model=auth_response) -def register(background_tasks: BackgroundTasks, response: Response, user_schema: UserCreate, db: Session = Depends(get_db)): +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP +def register(request: Request, background_tasks: BackgroundTasks, response: Response, user_schema: UserCreate, db: Session = Depends(get_db)): '''Endpoint for a user to register their account''' # Create user account @@ -88,7 +94,8 @@ def register(background_tasks: BackgroundTasks, response: Response, user_schema: @auth.post(path="/register-super-admin", status_code=status.HTTP_201_CREATED, response_model=auth_response) -def register_as_super_admin(user: UserCreate, db: Session = Depends(get_db)): +@limiter.limit("5/minute") # Limit to 5 requests per minute per IP +def register_as_super_admin(request: Request, user: UserCreate, db: Session = Depends(get_db)): """Endpoint for super admin creation""" user = user_service.create_admin(db=db, schema=user) @@ -131,7 +138,8 @@ def register_as_super_admin(user: UserCreate, db: Session = Depends(get_db)): @auth.post("/login", status_code=status.HTTP_200_OK, response_model=auth_response) -def login(login_request: LoginRequest, db: Session = Depends(get_db)): +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP +def login(request: Request, login_request: LoginRequest, db: Session = Depends(get_db)): """Endpoint to log in a user""" # Authenticate the user @@ -171,7 +179,9 @@ def login(login_request: LoginRequest, db: Session = Depends(get_db)): @auth.post("/logout", status_code=status.HTTP_200_OK) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP def logout( + request: Request, response: Response, db: Session = Depends(get_db), current_user: User = Depends(user_service.get_current_user), @@ -187,6 +197,7 @@ def logout( @auth.post("/refresh-access-token", status_code=status.HTTP_200_OK) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP def refresh_access_token( request: Request, response: Response, db: Session = Depends(get_db) ): @@ -220,7 +231,8 @@ def refresh_access_token( @auth.post("/request-token", status_code=status.HTTP_200_OK) -async def request_signin_token(background_tasks: BackgroundTasks, +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP +async def request_signin_token(request: Request, background_tasks: BackgroundTasks, email_schema: EmailRequest, db: Session = Depends(get_db) ): """Generate and send a 6-digit sign-in token to the user's email""" @@ -253,7 +265,9 @@ async def request_signin_token(background_tasks: BackgroundTasks, @auth.post("/verify-token", status_code=status.HTTP_200_OK, response_model=auth_response) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP async def verify_signin_token( + request: Request, token_schema: TokenRequest, db: Session = Depends(get_db) ): """Verify the 6-digit sign-in token and log in the user""" @@ -294,6 +308,7 @@ async def verify_signin_token( # TODO: Fix magic link authentication @auth.post("/magic-link", status_code=status.HTTP_200_OK) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP def request_magic_link( request: MagicLinkRequest, background_tasks: BackgroundTasks, response: Response, db: Session = Depends(get_db) @@ -319,7 +334,8 @@ def request_magic_link( @auth.post("/magic-link/verify") -async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)): +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP +async def verify_magic_link(request: Request, token_schema: Token, db: Session = Depends(get_db)): user, access_token = AuthService.verify_magic_token(token_schema.token, db) user_organizations = organisation_service.retrieve_user_organizations(user, db) @@ -352,7 +368,9 @@ async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)): @auth.put("/password", status_code=200) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP async def change_password( + request: Request, schema: ChangePasswordSchema, db: Session = Depends(get_db), user: User = Depends(user_service.get_current_user), @@ -369,7 +387,9 @@ async def change_password( @auth.get("/@me", status_code=status.HTTP_200_OK, response_model=AuthMeResponse) +@limiter.limit("10/minute") # Limit to 10 requests per minute per IP def get_current_user_details( + request: Request, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(user_service.get_current_user)], ): diff --git a/main.py b/main.py index d80d994b2..26134c6a1 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,8 @@ import uvicorn, os from sqlalchemy.exc import IntegrityError from fastapi import HTTPException, Request +from slowapi import Limiter +from slowapi.util import get_remote_address from fastapi.templating import Jinja2Templates from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -34,6 +36,11 @@ async def lifespan(app: FastAPI): version="1.0.0", ) + +# Initialize the rate limiter +limiter = Limiter(key_func=get_remote_address) +app.state.limiter = limiter + # Set up email templates and css static files email_templates = Jinja2Templates(directory='api/core/dependencies/email/templates')