Skip to content

Commit

Permalink
feat: updated rate limiting for enhance security
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeSoft007 committed Aug 23, 2024
1 parent 91a5413 commit f027eca
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
30 changes: 25 additions & 5 deletions api/v1/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from datetime import timedelta
from slowapi import Limiter
from slowapi.util import get_remote_address

from fastapi import (BackgroundTasks, Depends,
status, APIRouter,
Response, Request)
Expand Down Expand Up @@ -27,9 +30,12 @@

auth = APIRouter(prefix="/auth", tags=["Authentication"])

# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)

@auth.post("/register", status_code=status.HTTP_201_CREATED, response_model=auth_response)
def register(background_tasks: BackgroundTasks, response: Response, user_schema: UserCreate, db: Session = Depends(get_db)):
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def register(request: Request, background_tasks: BackgroundTasks, response: Response, user_schema: UserCreate, db: Session = Depends(get_db)):
'''Endpoint for a user to register their account'''

# Create user account
Expand Down Expand Up @@ -88,7 +94,8 @@ def register(background_tasks: BackgroundTasks, response: Response, user_schema:


@auth.post(path="/register-super-admin", status_code=status.HTTP_201_CREATED, response_model=auth_response)
def register_as_super_admin(user: UserCreate, db: Session = Depends(get_db)):
@limiter.limit("5/minute") # Limit to 5 requests per minute per IP
def register_as_super_admin(request: Request, user: UserCreate, db: Session = Depends(get_db)):
"""Endpoint for super admin creation"""

user = user_service.create_admin(db=db, schema=user)
Expand Down Expand Up @@ -131,7 +138,8 @@ def register_as_super_admin(user: UserCreate, db: Session = Depends(get_db)):


@auth.post("/login", status_code=status.HTTP_200_OK, response_model=auth_response)
def login(login_request: LoginRequest, db: Session = Depends(get_db)):
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def login(request: Request, login_request: LoginRequest, db: Session = Depends(get_db)):
"""Endpoint to log in a user"""

# Authenticate the user
Expand Down Expand Up @@ -171,7 +179,9 @@ def login(login_request: LoginRequest, db: Session = Depends(get_db)):


@auth.post("/logout", status_code=status.HTTP_200_OK)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def logout(
request: Request,
response: Response,
db: Session = Depends(get_db),
current_user: User = Depends(user_service.get_current_user),
Expand All @@ -187,6 +197,7 @@ def logout(


@auth.post("/refresh-access-token", status_code=status.HTTP_200_OK)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def refresh_access_token(
request: Request, response: Response, db: Session = Depends(get_db)
):
Expand Down Expand Up @@ -220,7 +231,8 @@ def refresh_access_token(


@auth.post("/request-token", status_code=status.HTTP_200_OK)
async def request_signin_token(background_tasks: BackgroundTasks,
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
async def request_signin_token(request: Request, background_tasks: BackgroundTasks,
email_schema: EmailRequest, db: Session = Depends(get_db)
):
"""Generate and send a 6-digit sign-in token to the user's email"""
Expand Down Expand Up @@ -253,7 +265,9 @@ async def request_signin_token(background_tasks: BackgroundTasks,


@auth.post("/verify-token", status_code=status.HTTP_200_OK, response_model=auth_response)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
async def verify_signin_token(
request: Request,
token_schema: TokenRequest, db: Session = Depends(get_db)
):
"""Verify the 6-digit sign-in token and log in the user"""
Expand Down Expand Up @@ -294,6 +308,7 @@ async def verify_signin_token(

# TODO: Fix magic link authentication
@auth.post("/magic-link", status_code=status.HTTP_200_OK)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def request_magic_link(
request: MagicLinkRequest, background_tasks: BackgroundTasks,
response: Response, db: Session = Depends(get_db)
Expand All @@ -319,7 +334,8 @@ def request_magic_link(


@auth.post("/magic-link/verify")
async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)):
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
async def verify_magic_link(request: Request, token_schema: Token, db: Session = Depends(get_db)):
user, access_token = AuthService.verify_magic_token(token_schema.token, db)
user_organizations = organisation_service.retrieve_user_organizations(user, db)

Expand Down Expand Up @@ -352,7 +368,9 @@ async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)):


@auth.put("/password", status_code=200)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
async def change_password(
request: Request,
schema: ChangePasswordSchema,
db: Session = Depends(get_db),
user: User = Depends(user_service.get_current_user),
Expand All @@ -369,7 +387,9 @@ async def change_password(
@auth.get("/@me",
status_code=status.HTTP_200_OK,
response_model=AuthMeResponse)
@limiter.limit("10/minute") # Limit to 10 requests per minute per IP
def get_current_user_details(
request: Request,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(user_service.get_current_user)],
):
Expand Down
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import uvicorn, os
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -34,6 +36,11 @@ async def lifespan(app: FastAPI):
version="1.0.0",
)


# Initialize the rate limiter
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# Set up email templates and css static files
email_templates = Jinja2Templates(directory='api/core/dependencies/email/templates')

Expand Down

0 comments on commit f027eca

Please sign in to comment.