diff --git a/.env.sample b/.env.sample new file mode 100644 index 000000000..300e463bd --- /dev/null +++ b/.env.sample @@ -0,0 +1,14 @@ +PYTHON_ENV=dev +DB_TYPE=mysql +DB_NAME=dbname +DB_USER=user +DB_PASSWORD=password +DB_HOST=127.0.0.1 +DB_PORT=3306 +MYSQL_DRIVER=pymysql +DB_URL="mysql+pymysql://user:password@127.0.0.1:3306/dbname" +SECRET_KEY = "" +ALGORITHM = HS256 +ACCESS_TOKEN_EXPIRE_MINUTES = 10 +JWT_REFRESH_EXPIRY=5 +APP_URL= \ No newline at end of file diff --git a/README.md b/README.md index 3f43ea45b..27935829a 100644 --- a/README.md +++ b/README.md @@ -1 +1,22 @@ -# hng_boilerplate_python_web \ No newline at end of file +# FASTAPI +FastAPI boilerplate + +## Setup + +1. Create a virtual environment. + ```sh + python3 -m venv .venv + ``` +2. Activate virtual environment. +```sh + source /path/to/venv/bin/activate` +``` +3. Install project dependencies `pip install -r requirements.txt` +4. Create a .env file by copying the .env.sample file +`cp .env.sample .env` + +5. Start server. + ```sh + python main.py +``` + diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..98e4f9c44 --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..b78cb58bd --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,89 @@ + +import os +from alembic import context +from decouple import config +from logging.config import fileConfig +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from api.db.database import Base +from api.v1.models.auth import User, BlackListToken + +#from db.database import DATABASE_URL +DATABASE_URL=config('DB_URL') + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url", DATABASE_URL) + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + alembic_config = config.get_section(config.config_ini_section) + alembic_config['sqlalchemy.url'] = DATABASE_URL + + connectable = engine_from_config( + alembic_config, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..fbc4b07dc --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/api/core/__init__.py b/api/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/base/__init__.py b/api/core/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/base/services.py b/api/core/base/services.py new file mode 100644 index 000000000..9636c474b --- /dev/null +++ b/api/core/base/services.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + +class Service(ABC): + @abstractmethod + def create(self): + pass + + @abstractmethod + def fetch(self): + pass + + @abstractmethod + def fetch_all(self): + pass + + @abstractmethod + def update(self): + pass + + @abstractmethod + def delete(self): + pass \ No newline at end of file diff --git a/api/core/dependencies/__init__.py b/api/core/dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/dependencies/user.py b/api/core/dependencies/user.py new file mode 100644 index 000000000..9f757f3c2 --- /dev/null +++ b/api/core/dependencies/user.py @@ -0,0 +1,36 @@ +from typing import Union +from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends, Cookie, HTTPException, status +from sqlalchemy.orm import Session +from sqlalchemy.sql import and_ +from jose import JWTError, jwt + +from api.v1.schemas import auth as user_schema +from api.v1.services.auth import User +from api.v1.models.auth import User as UserModel + +from api.db.database import get_db +from api.core import responses + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") + +def is_authenticated( + access_token: Union[str, user_schema.Token] = Depends(oauth2_scheme), + db: Session = Depends(get_db), +) -> Union[user_schema.User, JWTError]: + + if not access_token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=responses.INVALID_CREDENTIALS) + + userService = User() + + access_token_info = userService.verify_access_token(access_token, db) + + if type(access_token_info) is JWTError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=responses.INVALID_CREDENTIALS) + + user = userService.fetch(id=access_token_info.id,db=db) + + return user + + diff --git a/api/core/responses.py b/api/core/responses.py new file mode 100644 index 000000000..851eaf241 --- /dev/null +++ b/api/core/responses.py @@ -0,0 +1,7 @@ +EMAIL_IN_USE = "This email is already in use." +NOT_FOUND = "Not found!" +ID_OR_UNIQUE_ID_REQUIRED = "ID or Unique ID required!" +INVALID_CREDENTIALS = "Invalid Credentials!" +COULD_NOT_VALIDATE_CRED = "Could not validate credentials." +SUCCESS = "SUCCESS" +EXPIRED="Token expired." diff --git a/api/db/__init__.py b/api/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/db/database.py b/api/db/database.py new file mode 100644 index 000000000..472829eca --- /dev/null +++ b/api/db/database.py @@ -0,0 +1,50 @@ +# database.py +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from decouple import config + + +def get_db_engine(): + + DB_TYPE = config("DB_TYPE") + DB_NAME = config("DB_NAME") + DB_USER = config("DB_USER") + DB_PASSWORD = config("DB_PASSWORD") + DB_HOST = config("DB_HOST") + DB_PORT = config("DB_PORT") + MYSQL_DRIVER = config("MYSQL_DRIVER") + DATABASE_URL = "" + + if DB_TYPE == "mysql": + DATABASE_URL = f'mysql+{MYSQL_DRIVER}://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}' + elif DB_TYPE == "postgresql": + DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + else: + DATABASE_URL = "sqlite:///./database.db" + + if DB_TYPE == "sqlite": + db_engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) + else: + db_engine = create_engine(DATABASE_URL, pool_size=32, max_overflow=64) + + return db_engine + +db_engine = get_db_engine() + + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=db_engine) + +Base = declarative_base() + + +def create_database(): + return Base.metadata.create_all(bind=db_engine) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/api/db/mongo.py b/api/db/mongo.py new file mode 100644 index 000000000..4003c47a3 --- /dev/null +++ b/api/db/mongo.py @@ -0,0 +1,19 @@ +from pymongo import MongoClient +from api.utils import settings +from pymongo.mongo_client import MongoClient +from motor.motor_asyncio import AsyncIOMotorClient + + +def create_nosql_db(): + + client = MongoClient(settings.MONGO_URI) + + try: + client.admin.command("ping") + print("MongoDB Connection Established...") + except Exception as e: + print(e) + + +client = MongoClient(settings.MONGO_URI) +db = client.get_database(settings.MONGO_DB_NAME) diff --git a/api/utils/__init__.py b/api/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/utils/dict.py b/api/utils/dict.py new file mode 100644 index 000000000..45d1d291b --- /dev/null +++ b/api/utils/dict.py @@ -0,0 +1,9 @@ + +def clone_object(obj: dict, unwanted_fields=[]): + new_obj = {} + for k in list(obj.keys()): + if k in unwanted_fields: + continue + + new_obj[k] = obj[k] + return new_obj diff --git a/api/utils/exceptions.py b/api/utils/exceptions.py new file mode 100644 index 000000000..e4744ddb1 --- /dev/null +++ b/api/utils/exceptions.py @@ -0,0 +1,7 @@ +from fastapi import HTTPException + +class CustomException(HTTPException): + + @staticmethod + def PermissionError(): + raise HTTPException(status_code=403, detail="User has no permission to perform this action") \ No newline at end of file diff --git a/api/utils/mailer.py b/api/utils/mailer.py new file mode 100644 index 000000000..27675839b --- /dev/null +++ b/api/utils/mailer.py @@ -0,0 +1,7 @@ +import resend +from decouple import config + +resend.api_key = config('RESEND_API_KEY') + +class Mailer(resend.Emails): + pass \ No newline at end of file diff --git a/api/utils/paginator.py b/api/utils/paginator.py new file mode 100644 index 000000000..f27af595f --- /dev/null +++ b/api/utils/paginator.py @@ -0,0 +1,47 @@ +from sqlalchemy.orm import Session + +def total_row_count(model, organization_id, db: Session): + return db.query(model).filter(model.organization_id == organization_id).filter( + model.is_deleted == False).count() + +def off_set(page: int, size: int): + return (page-1)*size + + +def size_validator(size:int): + if size < 0 or size > 100: + return "page size must be between 0 and 100" + return size + + +def page_urls(page: int, size: int, count: int, endpoint: str): + paging = {} + if (size + off_set(page, size)) >= count: + paging['next'] = None + if page > 1: + paging['previous'] = f"{endpoint}?page={page-1}&size={size}" + else: + paging['previous'] = None + else: + paging['next'] = f"{endpoint}?page={page+1}&size={size}" + if page > 1: + paging['previous'] = f"{endpoint}?page={page-1}&size={size}" + else: + paging['previous'] = None + + return paging + + +def build_paginated_response( + page: int, size: int, total: int, pointers: dict, items +) -> dict: + response = { + "page": page, + "size": size, + "total": total, + "previous_page": pointers["previous"], + "next_page": pointers["next"], + "items": items, + } + + return response \ No newline at end of file diff --git a/api/utils/settings.py b/api/utils/settings.py new file mode 100644 index 000000000..0fb3c8532 --- /dev/null +++ b/api/utils/settings.py @@ -0,0 +1,4 @@ +from decouple import config + +MONGO_URI = config("MONGO_URI") +MONGO_DB_NAME = config("MONGO_DB_NAME") \ No newline at end of file diff --git a/api/utils/sql.py b/api/utils/sql.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/utils/string.py b/api/utils/string.py new file mode 100644 index 000000000..d5daa19b7 --- /dev/null +++ b/api/utils/string.py @@ -0,0 +1,16 @@ +from fastapi import status +from fastapi.exceptions import HTTPException + +def is_empty_string(string: str) -> bool: + if len(string.strip()) == 0: + return True + + return False + +class EmptyStringException(HTTPException): + def __init__(self, detail: str) -> None: + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=detail, + headers=None, + ) \ No newline at end of file diff --git a/api/utils/utils.py b/api/utils/utils.py new file mode 100644 index 000000000..0b1a27285 --- /dev/null +++ b/api/utils/utils.py @@ -0,0 +1,13 @@ +def build_paginated_response( + page: int, size: int, total: int, pointers: dict, items +) -> dict: + response = { + "page": page, + "size": size, + "total": total, + "previous_page": pointers["previous"], + "next_page": pointers["next"], + "items": items, + } + + return response \ No newline at end of file diff --git a/api/v1/__init__.py b/api/v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/v1/models/__init__.py b/api/v1/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/v1/models/auth.py b/api/v1/models/auth.py new file mode 100644 index 000000000..b85749d7b --- /dev/null +++ b/api/v1/models/auth.py @@ -0,0 +1,28 @@ +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, BIGINT, Text +from sqlalchemy.orm import relationship +from datetime import datetime, date +from api.db.database import Base +import passlib.hash as _hash + + +class User(Base): + __tablename__ = "users" + id = Column(BIGINT, primary_key=True, autoincrement=True, index=True) + unique_id = Column(String(255), nullable=True) + first_name = Column(String(255), nullable=False) + last_name = Column(String(255), nullable=False) + email = Column(String(500), unique=True, index=True, nullable=False) + password = Column(String(500), nullable=False) + is_active = Column(Boolean, default=True) + date_created = Column(DateTime,default=datetime.utcnow) + last_updated = Column(DateTime,default=datetime.utcnow) + is_deleted = Column(Boolean, default=False) + +class BlackListToken(Base): + __tablename__ = "blacklist_tokens" + id = Column(BIGINT, primary_key=True, autoincrement=True, index=True) + created_by = Column(BIGINT, ForeignKey('users.id'), index=True) + token = Column(String(255), index=True) + date_created = Column(DateTime, default= datetime.utcnow) + + diff --git a/api/v1/routes/__init__.py b/api/v1/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py new file mode 100644 index 000000000..640880051 --- /dev/null +++ b/api/v1/routes/auth.py @@ -0,0 +1,253 @@ +from fastapi import Depends, Cookie, HTTPException, APIRouter, Depends, status, Response, Request, BackgroundTasks +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session +from datetime import timedelta, datetime +from typing import Union +from decouple import config +from api.v1.schemas import auth as user_schema +from api.db.database import get_db +from api.v1.services.auth import User +from api.v1.models.auth import User as UserModel +from api.core import responses +from api.core.dependencies.user import is_authenticated + +ACCESS_TOKEN_EXPIRE_MINUTES = int(config('ACCESS_TOKEN_EXPIRE_MINUTES')) +JWT_REFRESH_EXPIRY = int(config('JWT_REFRESH_EXPIRY')) +IS_REFRESH_TOKEN_SECURE = True if config('PYTHON_ENV') == "production" else False + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") + +app = APIRouter(tags=["Auth"]) + + +@app.post("/signup", status_code=status.HTTP_201_CREATED) +async def signup( + response: Response, + user:user_schema.CreateUser, + db:Session = Depends(get_db) +): + + """ + Endpoint to create a user + + Returns: Created User. + """ + userService = User() + created_user = userService.create(user=user, db=db) + + + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = userService.create_access_token( + data={"id": created_user.id}, db=db, expires_delta=access_token_expires + ) + + refresh_token = userService.create_refresh_token(data={"id": created_user.id}, db=db) + + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=JWT_REFRESH_EXPIRY, + secure=True, + httponly=True, + samesite="strict", + ) + + return {"message": responses.SUCCESS, + "data": user_schema.ShowUser.model_validate(created_user), + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer" + } + + +@app.post("/login", status_code=status.HTTP_200_OK) +async def login_for_access_token( + response: Response, + data: user_schema.Login, + background_task: BackgroundTasks, + db: Session = Depends(get_db) +): + """ + LOGIN + + Returns: Logged in User and access token. + """ + + userService = User() + user = userService.authenticate_user(email=data.email, password=data.password, db=db) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = userService.create_access_token( + data={"id": user.id}, db=db, expires_delta=access_token_expires + ) + + refresh_token = userService.create_refresh_token(data={"id": user.id}, db=db) + + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=JWT_REFRESH_EXPIRY, + secure=True, + httponly=True, + samesite="strict", + path="/" + ) + + return { + "data": user_schema.ShowUser.model_validate(user), + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer" + } + + + +@app.get("/user", status_code=status.HTTP_200_OK) +async def get_user( + user: user_schema.User = Depends(is_authenticated), + db: Session = Depends(get_db), +): + """ + Returns an authenticated user information + """ + print(user) + return user + + + +@app.get("/refresh-access-token", status_code=status.HTTP_200_OK) +async def refresh_access_token( + response: Response, + refresh_token: Union[str, None] = Cookie(default=None), + db: Session = Depends(get_db), +): + """Refreshes an access_token with the issued refresh_token + Parameters + ---------- + refresh_token : str, None + The refresh token sent in the cookie by the client (default is None) + + Raises + ------ + UnauthorizedError + If the refresh token is None. + """ + print(refresh_token) + credentials_exception =HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if refresh_token is None: + raise HTTPException( + detail="Log in to authenticate user", + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + valid_refresh_token = User.verify_refresh_token( + refresh_token, db + ) + + if valid_refresh_token.email is None: + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=JWT_REFRESH_EXPIRY, + secure=IS_REFRESH_TOKEN_SECURE, + httponly=True, + samesite="strict", + ) + + print("refresh failed") + else: + user = ( + db.query(UserModel) + .filter(UserModel.id == valid_refresh_token.id) + .first() + ) + + access_token = User.create_access_token( + {"user_id": valid_refresh_token.id}, db + ) + + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=JWT_REFRESH_EXPIRY, + secure=IS_REFRESH_TOKEN_SECURE, + httponly=True, + samesite="strict", + ) + + # Access token expires in 15 mins, + return {"user": user_schema.ShowUser.model_validate(user), "access_token": access_token, "expires_in": 900} + + + +@app.post("/logout", status_code=status.HTTP_200_OK) +async def logout_user( + request: Request, + response: Response, + user: user_schema.User = Depends(is_authenticated), + db: Session = Depends(get_db), +): + """ + This endpoint logs out an authenticated user. + + Returns message: User logged out successfully. + """ + + userService = User() + access_token = request.headers.get('Authorization') + + logout = userService.logout(token=access_token, user=user, db=db) + + response.set_cookie( + key="refresh_token", + max_age="0", + secure=True, + httponly=True, + samesite="strict", + ) + + return {"message": "User logged out successfully."} + + +@app.delete("/users/{user_id}") +async def delete_user( + user_id: int, + user: user_schema.User = Depends(is_authenticated), + db: Session = Depends(get_db), +): + + """ + This endpoint deletes a user from the db. (Soft delete) + + Returns message: User deleted successfully. + """ + userService = User() + deleted_user = userService.delete(db=db, id=user_id) + + return {"message": "User deleted successfully."} + + +@app.post("/users/roles", status_code=status.HTTP_200_OK) +async def create_user_roles( + user: user_schema.ShowUser = Depends(is_authenticated), + db:Session = Depends(get_db) +): + + """ + Endpoint to create custom roles for users mixing permissions. + + Returns created role + + """ + pass \ No newline at end of file diff --git a/api/v1/schemas/__init__.py b/api/v1/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/v1/schemas/auth.py b/api/v1/schemas/auth.py new file mode 100644 index 000000000..c2cfc8e28 --- /dev/null +++ b/api/v1/schemas/auth.py @@ -0,0 +1,67 @@ +from uuid import uuid4 +from pydantic import BaseModel, model_validator, EmailStr +from typing import Optional +from fastapi import HTTPException, status +from datetime import datetime +from api.db.database import SessionLocal +from api.v1.models.auth import User +from api.core import responses + +""" +TODO: PASSWORD COMPLEXITY VALIDATION ON CREATEUSER SCHEMA + UNIQUE_ID SHOULD NOT ALREADY EXIST + + +""" +class Login(BaseModel): + email: EmailStr + password: str + +class Token(BaseModel): + access_token: str + token_type: str + +class TokenData(BaseModel): + id:int + email:str + +class UserBase(BaseModel): + first_name: str + last_name: str + email: str + unique_id: Optional[str] = None + is_active: bool = True + date_created: Optional[datetime] = datetime.utcnow() + last_updated: Optional[datetime] = datetime.utcnow() + + class Config: + from_attributes = True + + +class CreateUser(UserBase): + password: str + + class Config: + from_attributes = True + + + #validate email not in use + @model_validator(mode='before') + @classmethod + def validate_email(cls, values): + email = values.get("email") + + with SessionLocal() as db: + user_email = db.query(User).filter(User.email == email).first() + if user_email: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail= responses.EMAIL_IN_USE) + + + return values + +class ShowUser(UserBase): + id: int + is_deleted: Optional[bool] + + class Config: + from_attributes=True \ No newline at end of file diff --git a/api/v1/services/__init__.py b/api/v1/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/v1/services/auth.py b/api/v1/services/auth.py new file mode 100644 index 000000000..16255be29 --- /dev/null +++ b/api/v1/services/auth.py @@ -0,0 +1,236 @@ +import passlib.hash as _hash +from passlib.context import CryptContext +from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends +from jose import JWTError, jwt +from sqlalchemy.orm import Session +from sqlalchemy.sql import or_ +from fastapi import HTTPException, status +from datetime import datetime,timedelta +from typing import Annotated, Union +from uuid import uuid4 +from decouple import config + +from api.v1.models.auth import User as UserModel, BlackListToken +from api.v1.schemas import auth as user_schema +from api.core import responses +from api.core.base.services import Service + + +SECRET_KEY = config('SECRET_KEY') +ALGORITHM = config("ALGORITHM") +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +JWT_REFRESH_EXPIRY = int(config("JWT_REFRESH_EXPIRY")) + + +class User(Service): + + def __init__(self) -> None: + pass + + def create(self, user: user_schema.CreateUser, db: Session): + created_user = UserModel(unique_id=user.unique_id, + first_name=user.first_name, + last_name=user.last_name, + email=user.email, + date_created=user.date_created, + last_updated=user.last_updated, + is_active=user.is_active, + password=self.hash_password(user.password)) + db.add(created_user) + db.commit() + + return created_user + + @staticmethod + def fetch(db: Session, id: int = None, unique_id: str = None) -> user_schema.User: + if id is None and unique_id is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=responses.ID_OR_UNIQUE_ID_REQUIRED) + + user = db.query(UserModel).filter(UserModel.id==id).filter(UserModel.is_deleted==False).first() + if user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=responses.NOT_FOUND) + + return user + + @staticmethod + def fetch_all(): + pass + + @staticmethod + def fetch_by_email(email: str, db: Session) -> user_schema.User: + user = db.query(UserModel).filter(UserModel.email == email, UserModel.is_deleted==False).first() + + return user + + def update(self): + pass + + @classmethod + def delete(cls,db: Session, id: int=None, unique_id: str=None) -> user_schema.User: + user = cls.fetch(id=id, unique_id=unique_id, db=db) + user.is_deleted = True + db.commit() + return user + + @classmethod + async def get_current_user(cls, token: Annotated[str, Depends(oauth2_scheme)], db:Session) -> user_schema.User: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=responses.COULD_NOT_VALIDATE_CRED, + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + id: str = payload.get("id") + if id is None: + raise credentials_exception + token_data = id + except JWTError: + raise credentials_exception + user = cls.fetch(id=token_data,db=db) + if user is None: + raise credentials_exception + return user + + @classmethod + def authenticate_user(cls, db: Session, password: str,email: str) -> user_schema.User: + user = cls.fetch_by_email(email=email, db=db) + if not user: + return False + if not cls.verify_password(password, user.password): + return False + return user + + @staticmethod + def verify_password(password, hashed_password): + return pwd_context.verify(password, hashed_password) + + @staticmethod + def hash_password(password) -> str: + return pwd_context.hash(password) + + @staticmethod + def create_access_token(data: dict, db: Session, expires_delta: timedelta = None) -> str: + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=30) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + db.commit() + + return encoded_jwt + + @staticmethod + def create_refresh_token(data: dict, db: Session) -> str: + to_encode = data.copy() + + expire = datetime.utcnow() + timedelta(seconds=int(JWT_REFRESH_EXPIRY)) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + return encoded_jwt + + @classmethod + def verify_access_token(cls, token: str, db: Session) -> user_schema.TokenData: + try: + invalid_token = cls.check_token_blacklist(db=db, token=token) + if invalid_token == True: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}) + + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + id: int = payload.get("id") + + if id is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}) + + user = cls.fetch(db=db,id=id) + + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}) + + + token_data = user_schema.TokenData(email=user.email, id=id) + + return token_data + + except JWTError as error: + print(error, 'error') + return JWTError(HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"})) + + @classmethod + def verify_refresh_token(cls, refresh_token: str, db: Session) -> user_schema.TokenData: + try: + if not refresh_token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=responses.EXPIRED) + + payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) + id: str = payload.get("id") + + if id is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}) + + user = cls.fetch(id=id, db=db) + + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=responses.INVALID_CREDENTIALS) + + + token_data = user_schema.TokenData(email=user.email, id=id) + + except JWTError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,detail=responses.INVALID_CREDENTIALS, + headers={"WWW-Authenticate": "Bearer"}) + + return token_data + + @staticmethod + def check_token_blacklist(token: str, db:Session)-> bool: + fetched_token = db.query(BlackListToken).filter(BlackListToken.token == token).first() + + if fetched_token: + return True + else: + return False + + @staticmethod + def logout(token: str, user: user_schema.ShowUser, db:Session) -> str: + blacklist_token = BlackListToken( + token=token.split(' ')[1], + created_by=user.id + ) + + db.add(blacklist_token) + db.commit() + + return token + + + + + + + + + + + + + + + + + + + + + diff --git a/main.py b/main.py new file mode 100644 index 000000000..7d2e42597 --- /dev/null +++ b/main.py @@ -0,0 +1,54 @@ +import uvicorn +from contextlib import asynccontextmanager +from typing import Union +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.requests import Request +from api.db.database import create_database +from api.db.mongo import create_nosql_db +from api.v1.routes.auth import app as auth + + +@asynccontextmanager +async def lifespan(app: FastAPI): + create_database() + create_nosql_db() + yield + ## write shutdown logic below yield + + +app = FastAPI(lifespan=lifespan) + + +create_nosql_db() + + +origins = [ + "http://localhost:3000", + "http://localhost:3001", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +app.include_router(auth, tags=["Auth"]) +# app.include_router(users, tags=["Users"]) + + + +@app.get("/", tags=["Home"]) +async def get_root(request: Request) -> dict: + return { + "message": "Welcome to API", + "URL": "", + } + + +if __name__ == "__main__": + uvicorn.run("main:app", port=7001, reload=True) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..853b6833e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +fastapi +SQLAlchemy +alembic +pydantic +python-decouple +pytest +python-jose +python-multipart +python-dotenv +bcrypt +passlib +pymysql diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..fb0a78572 --- /dev/null +++ b/setup.py @@ -0,0 +1,2 @@ +from setuptools import setup, find_packages +setup(name='api', packages=find_packages()) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/database.py b/tests/database.py new file mode 100644 index 000000000..458a640b0 --- /dev/null +++ b/tests/database.py @@ -0,0 +1,51 @@ +from fastapi.testclient import TestClient +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base +from ..main import app +from decouple import config + + +from api.db.database import get_db, Base + + +DB_TYPE = config("DB_TYPE") +DB_NAME = config("DB_NAME") +DB_USER = config("DB_USER") +DB_PASSWORD = config("DB_PASSWORD") +DB_HOST = config("DB_HOST") +DB_PORT = config("DB_PORT") +MYSQL_DRIVER = config("MYSQL_DRIVER") +DATABASE_URL = "" + +SQLALCHEMY_DATABASE_URL = f'{DB_TYPE}+{MYSQL_DRIVER}://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}_test' + + +engine = create_engine(SQLALCHEMY_DATABASE_URL) + +TestingSessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture() +def session(): + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + + +@pytest.fixture() +def client(session): + def override_get_db(): + + try: + yield session + finally: + session.close() + app.dependency_overrides[get_db] = override_get_db + yield TestClient(app) \ No newline at end of file diff --git a/tests/v1/__init__.py b/tests/v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/v1/conftest.py b/tests/v1/conftest.py new file mode 100644 index 000000000..5268a6a48 --- /dev/null +++ b/tests/v1/conftest.py @@ -0,0 +1,76 @@ +from fastapi.testclient import TestClient +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base +from main import app +from decouple import config +import json + + +from api.db.database import get_db +from api.db.database import Base +from api.v1.services.auth import User +from api.v1.models.auth import User as UserModel + + +DB_TYPE = config("DB_TYPE") +DB_NAME = config("DB_NAME") +DB_USER = config("DB_USER") +DB_PASSWORD = config("DB_PASSWORD") +DB_HOST = config("DB_HOST") +DB_PORT = config("DB_PORT") +MYSQL_DRIVER = config("MYSQL_DRIVER") +DATABASE_URL = "" + +SQLALCHEMY_DATABASE_URL = f'{DB_TYPE}+{MYSQL_DRIVER}://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}_test' + +engine = create_engine(SQLALCHEMY_DATABASE_URL) + +TestingSessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture() +def session(): + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + + +@pytest.fixture() +def client(session): + def override_get_db(): + + try: + yield session + finally: + session.close() + app.dependency_overrides[get_db] = override_get_db + yield TestClient(app) + + +@pytest.fixture +def test_user(client): + payload = { + "first_name": "dean", + "last_name": "smith", + "email": "dean@gmail.com", + "password": "password123", + "unique_id": "1005" +} + res = client.post("/signup/", data=json.dumps(payload)) + + assert res.status_code == 201 + + new_user = res.json()['data'] + new_user['access_token'] = res.json()['access_token'] + new_user['email'] = payload['email'] + return new_user + + + diff --git a/tests/v1/test_users.py b/tests/v1/test_users.py new file mode 100644 index 000000000..68fcb4415 --- /dev/null +++ b/tests/v1/test_users.py @@ -0,0 +1,56 @@ +import pytest +from jose import jwt +from decouple import config +from api.v1.schemas.auth import User, ShowUser, Token, TokenData +import json + + +SECRET_KEY = config('SECRET_KEY') +ALGORITHM = config("ALGORITHM") + + +payload = { + "first_name": "test", + "last_name": "user", + "email": "test@gmail.com", + "password": "password123", + "unique_id": "1005" +} + +def test_create_user(client): + + res = client.post( + "/signup/", data=(json.dumps(payload))) + + new_user = ShowUser(**res.json()['data']) + assert new_user.email==payload['email'] + assert res.status_code == 201 + + +def test_login_user(test_user, client): + res = client.post( + "/login", data=json.dumps({"email": test_user['email'], "password": "password123"})) + + logged_in_user = ShowUser(**res.json()['data']) + payload = jwt.decode(res.json()['access_token'], + SECRET_KEY, algorithms=[ALGORITHM]) + id = payload.get("id") + token_type= res.json()['token_type'] + assert id == test_user['id'] + assert logged_in_user.email == test_user['email'] + assert token_type == "bearer" + assert res.status_code == 200 + + +@pytest.mark.parametrize("email, password, status_code", [ + ('wrongemail@gmail.com', 'password123', 401), + ('dean@gmail.com', 'wrongpassword', 401), + ('wrongemail@gmail.com', 'wrongpassword', 401), + (None, 'password123', 422), + ('test@gmail.com', None, 422) +]) +def test_incorrect_login(test_user, client, email, password, status_code): + res = client.post( + "/login", data=json.dumps({"email": email, "password": password})) + + assert res.status_code == status_code \ No newline at end of file