diff --git a/.env.sample b/.env.sample index bf53fdac0..72cfca777 100644 --- a/.env.sample +++ b/.env.sample @@ -34,5 +34,8 @@ TWILIO_PHONE_NUMBER="TWILIO_PHONE_NUMBER" FLUTTERWAVE_SECRET="" PAYSTACK_SECRET="" +STRIPE_SECRET_KEY="" +STRIPE_WEBHOOK_SECRET="" + MAILJET_API_KEY='MAIL JET API KEY' MAILJET_API_SECRET='SECRET KEY' diff --git a/alembic/versions/2e1bdb317917_added_models_to_support_billing_.py b/alembic/versions/2e1bdb317917_added_models_to_support_billing_.py new file mode 100644 index 000000000..c388fcc91 --- /dev/null +++ b/alembic/versions/2e1bdb317917_added_models_to_support_billing_.py @@ -0,0 +1,30 @@ +"""added models to support billing subscription + +Revision ID: 2e1bdb317917 +Revises: 95a09b7f5c2c +Create Date: 2024-08-10 18:53:02.873361 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '2e1bdb317917' +down_revision: Union[str, None] = '95a09b7f5c2c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/alembic/versions/5048f25d8274_added_models_to_support_billing_.py b/alembic/versions/5048f25d8274_added_models_to_support_billing_.py new file mode 100644 index 000000000..68605ef5f --- /dev/null +++ b/alembic/versions/5048f25d8274_added_models_to_support_billing_.py @@ -0,0 +1,44 @@ +"""added models to support billing subscription + +Revision ID: 5048f25d8274 +Revises: ff92a0037698 +Create Date: 2024-08-10 18:15:12.101521 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '5048f25d8274' +down_revision: Union[str, None] = 'ff92a0037698' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_subscriptions', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('plan_id', sa.String(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=True), + sa.Column('start_date', sa.String(), nullable=False), + sa.Column('end_date', sa.String(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.ForeignKeyConstraint(['plan_id'], ['billing_plans.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_subscriptions_id'), 'user_subscriptions', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_user_subscriptions_id'), table_name='user_subscriptions') + op.drop_table('user_subscriptions') + # ### end Alembic commands ### diff --git a/alembic/versions/95a09b7f5c2c_added_models_to_support_billing_.py b/alembic/versions/95a09b7f5c2c_added_models_to_support_billing_.py new file mode 100644 index 000000000..cb1527084 --- /dev/null +++ b/alembic/versions/95a09b7f5c2c_added_models_to_support_billing_.py @@ -0,0 +1,30 @@ +"""added models to support billing subscription + +Revision ID: 95a09b7f5c2c +Revises: 5048f25d8274 +Create Date: 2024-08-10 18:47:09.891441 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '95a09b7f5c2c' +down_revision: Union[str, None] = '5048f25d8274' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/alembic/versions/b7e8db61eb86_updated_billing_model_ensuing_that_plan_.py b/alembic/versions/b7e8db61eb86_updated_billing_model_ensuing_that_plan_.py new file mode 100644 index 000000000..264bdc553 --- /dev/null +++ b/alembic/versions/b7e8db61eb86_updated_billing_model_ensuing_that_plan_.py @@ -0,0 +1,30 @@ +"""updated billing model ensuing that plan name is unique + +Revision ID: b7e8db61eb86 +Revises: 2e1bdb317917 +Create Date: 2024-08-11 09:55:41.558757 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b7e8db61eb86' +down_revision: Union[str, None] = '2e1bdb317917' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint(None, 'billing_plans', ['name']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'billing_plans', type_='unique') + # ### end Alembic commands ### diff --git a/alembic/versions/c9d13f346b57_updated_billing_model_ensuing_that_plan_.py b/alembic/versions/c9d13f346b57_updated_billing_model_ensuing_that_plan_.py new file mode 100644 index 000000000..53fbcae37 --- /dev/null +++ b/alembic/versions/c9d13f346b57_updated_billing_model_ensuing_that_plan_.py @@ -0,0 +1,30 @@ +"""updated billing model ensuing that plan name is unique + +Revision ID: c9d13f346b57 +Revises: b7e8db61eb86 +Create Date: 2024-08-11 10:04:26.675695 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c9d13f346b57' +down_revision: Union[str, None] = 'b7e8db61eb86' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/api/v1/models/billing_plan.py b/api/v1/models/billing_plan.py index 69eba80de..0fdb589f5 100644 --- a/api/v1/models/billing_plan.py +++ b/api/v1/models/billing_plan.py @@ -1,5 +1,5 @@ # app/models/billing_plan.py -from sqlalchemy import Column, String, ARRAY, ForeignKey, Numeric +from sqlalchemy import Column, String, ARRAY, ForeignKey, Numeric, Boolean from sqlalchemy.orm import relationship from api.v1.models.base_model import BaseTableModel @@ -10,7 +10,7 @@ class BillingPlan(BaseTableModel): organization_id = Column( String, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False ) - name = Column(String, nullable=False) + name = Column(String, nullable=False, unique=True) price = Column(Numeric, nullable=False) currency = Column(String, nullable=False) duration = Column(String, nullable=False) @@ -18,3 +18,16 @@ class BillingPlan(BaseTableModel): features = Column(ARRAY(String), nullable=False) organization = relationship("Organization", back_populates="billing_plans") + + +class UserSubscription(BaseTableModel): + __tablename__ = "user_subscriptions" + + user_id = Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + plan_id = Column(String, ForeignKey("billing_plans.id", ondelete="CASCADE"), nullable=False) + active = Column(Boolean, default=True) + start_date = Column(String, nullable=False) + end_date = Column(String, nullable=True) + + user = relationship("User", back_populates="subscriptions") + billing_plan = relationship("BillingPlan") \ No newline at end of file diff --git a/api/v1/models/user.py b/api/v1/models/user.py index b33da7f5f..057fc8198 100644 --- a/api/v1/models/user.py +++ b/api/v1/models/user.py @@ -85,6 +85,10 @@ class User(BaseTableModel): ) product_comments = relationship("ProductComment", back_populates="user", cascade="all, delete-orphan") + subscriptions = relationship( + "UserSubscription", back_populates="user", cascade="all, delete-orphan" + ) + def to_dict(self): obj_dict = super().to_dict() obj_dict.pop("password") diff --git a/api/v1/routes/__init__.py b/api/v1/routes/__init__.py index 8e9ce4ea3..1710771ea 100644 --- a/api/v1/routes/__init__.py +++ b/api/v1/routes/__init__.py @@ -43,6 +43,7 @@ from api.v1.routes.privacy import privacies from api.v1.routes.settings import settings from api.v1.routes.terms_and_conditions import terms_and_conditions +from api.v1.routes.stripe import subscription_ api_version_one = APIRouter(prefix="/api/v1") @@ -89,3 +90,4 @@ api_version_one.include_router(team) api_version_one.include_router(terms_and_conditions) api_version_one.include_router(product_comment) +api_version_one.include_router(subscription_) \ No newline at end of file diff --git a/api/v1/routes/stripe.py b/api/v1/routes/stripe.py new file mode 100644 index 000000000..9ae03ddd4 --- /dev/null +++ b/api/v1/routes/stripe.py @@ -0,0 +1,68 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.orm import Session +import stripe +from api.v1.services.stripe_payment import stripe_payment_request, update_user_plan +import json +from api.v1.schemas.stripe import PlanUpgradeRequest +from api.db.database import get_db +import os +from api.v1.models.user import User +from api.v1.services.user import user_service +from dotenv import load_dotenv, find_dotenv + +load_dotenv(find_dotenv()) + +stripe.api_key = os.getenv('STRIPE_SECRET_KEY') +endpoint_secret = os.getenv('STRIPE_WEBHOOK_SECRET') + +subscription_ = APIRouter(prefix="/payment", tags=["subscribe-plan"]) + +@subscription_.post("/stripe/upgrade-plan") +def stripe_payment( + plan_upgrade_request: PlanUpgradeRequest, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user) + ): + return stripe_payment_request(db, plan_upgrade_request.user_id, request, plan_upgrade_request.plan_name) + +@subscription_.get("/stripe/success") +def success_upgrade(): + return {"message" : "Payment successful"} + +@subscription_.get("/stripe/cancel") +def cancel_upgrade(): + return {"message" : "Payment canceled"} + +@subscription_.post("/webhook") +async def webhook_received( + request: Request, + db: Session = Depends(get_db) + ): + + payload = await request.body() + event = None + + try: + event = stripe.Event.construct_from(json.loads(payload), stripe.api_key) + except ValueError as e: + print("Invalid payload") + raise HTTPException(status_code=400, detail="Invalid payload") + except stripe.error.SignatureVerificationError as e: + print("Invalid signature") + raise HTTPException(status_code=400, detail="Invalid signature") + + if event["type"] == "checkout.session.completed": + payment = event["data"]["object"] + response_details = { + "amount": payment["amount_total"], + "currency": payment["currency"], + "user_id": payment["metadata"]["user_id"], + "user_email": payment["customer_details"]["email"], + "user_name": payment["customer_details"]["name"], + "order_id": payment["id"] + } + # Save to DB + # Send email in background task + await update_user_plan(db, payment["metadata"]["user_id"], payment["metadata"]["plan_name"]) + return {"message": response_details} diff --git a/api/v1/schemas/stripe.py b/api/v1/schemas/stripe.py new file mode 100644 index 000000000..1c445fa69 --- /dev/null +++ b/api/v1/schemas/stripe.py @@ -0,0 +1,30 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field, validator + + +class PaymentInfo(BaseModel): + card_number: str = Field(..., min_length=16, max_length=16) + exp_month: int + exp_year: int + cvc: str = Field(..., min_length=3, max_length=4) + + @validator('card_number') + def card_number_validator(cls, v): + if not v.isdigit() or len(v) != 16: + raise ValueError('Card number must be 16 digits') + return v + + @validator('cvc') + def cvc_validator(cls, v): + if not v.isdigit() or not (3 <= len(v) <= 4): + raise ValueError('CVC must be 3 or 4 digits') + return v + + +class PlanUpgradeRequest(BaseModel): + user_id: str + plan_name: str + payment_info: Optional[PaymentInfo] = None + + diff --git a/api/v1/services/stripe_payment.py b/api/v1/services/stripe_payment.py new file mode 100644 index 000000000..968ec8316 --- /dev/null +++ b/api/v1/services/stripe_payment.py @@ -0,0 +1,126 @@ +from sqlalchemy.orm import Session +from api.v1.models.user import User +from api.v1.models.billing_plan import BillingPlan, UserSubscription +import stripe +from fastapi.encoders import jsonable_encoder +from api.utils.success_response import success_response +import os +from fastapi import HTTPException, status, Request +from datetime import datetime, timedelta + +stripe.api_key = os.getenv('STRIPE_SECRET_KEY') + +def get_plan_by_name(db: Session, plan_name: str): + return db.query(BillingPlan).filter(BillingPlan.name == plan_name).first() + +def stripe_payment_request(db: Session, user_id: str, request: Request, plan_name: str): + + base_url = request.base_url + + success_url = f"{base_url}api/v1/payment/stripe/success" + cancel_url = f"{base_url}api/v1/payment/stripe/cancel" + + user = db.query(User).filter(User.id == user_id).first() + + if not user: + raise HTTPException(status_code=404, detail="User not found") + + plan = get_plan_by_name(db, plan_name) + + if not plan: + raise HTTPException(status_code=404, detail="Plan not found") + + if plan.name != "Free": + try: + # Create a checkout session + checkout_session = stripe.checkout.Session.create( + payment_method_types=['card'], + line_items=[{ + 'price_data': { + 'currency': plan.currency, + 'product_data': { + 'name': plan.name, + }, + 'unit_amount': int(plan.price * 100), # Convert to the smallest unit + }, + 'quantity': 1, + }], + mode='payment', + customer_email=user.email, # Automatically fill in the user's email in the checkout + success_url=success_url, + cancel_url=cancel_url, + metadata={ + 'user_id': user_id, + 'plan_name': plan_name, + }, + ) + + if checkout_session: + data = { + "cancel_url": checkout_session["cancel_url"], + "success_url": checkout_session["success_url"], + "customer_details": checkout_session["customer_details"], + "customer_email": checkout_session["customer_email"], + "created_at": checkout_session["created"], + "expires_at": checkout_session["expires_at"], + "metadata": checkout_session["metadata"], + "payment_method_types": checkout_session["payment_method_types"], + "checkout_url": checkout_session["url"], + "amount_total": checkout_session["amount_total"] + } + + return success_response( + status_code=status.HTTP_201_CREATED, + message=f'payment in progress', + data=data, + ) + + except stripe.error.StripeError as e: + # Handle Stripe error + raise HTTPException(status_code=500, detail=f"Payment failed: {str(e)}") + + else: + raise HTTPException(status_code=400, detail="No payment is required for the Free plan") + + +def convert_duration_to_timedelta(duration: str) -> timedelta: + if duration == "monthly": + return timedelta(days=30) # Approximate month length + elif duration == "yearly": + return timedelta(days=365) # Approximate year length + else: + raise ValueError("Invalid duration") + +async def update_user_plan(db: Session, user_id: str, plan_name: str): + user = db.query(User).filter(User.id == user_id).first() + plan = get_plan_by_name(db, plan_name) + + if not user: + raise HTTPException(status_code=404, detail="User not found") + + if not plan: + raise HTTPException(status_code=404, detail="Plan not found") + + # Convert duration from string to timedelta + try: + duration = convert_duration_to_timedelta(plan.duration) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Update the user's subscription in the database + user_subscription = db.query(UserSubscription).filter(UserSubscription.user_id == user_id).first() + + if user_subscription: + user_subscription.plan_id = plan.id + user_subscription.start_date = datetime.utcnow() + user_subscription.end_date = datetime.utcnow() + duration + else: + new_subscription = UserSubscription( + user_id=user_id, + plan_id=plan.id, + start_date=datetime.utcnow(), + end_date=datetime.utcnow() + duration + ) + db.add(new_subscription) + + db.commit() \ No newline at end of file