Skip to content

Commit

Permalink
test: add typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
DEENUU1 committed Mar 16, 2024
1 parent 8b2a1d3 commit 1878686
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 46 deletions.
8 changes: 7 additions & 1 deletion app/notification/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from .email_notification import EmailNotificationStrategy


def create_notifications(db: Session = Depends(get_db)) -> None:
def create_notifications(db: Session) -> None:

notification_filters = NotificationFilterService(db).get_all_active()
print("Get all active notification filters")

# Iterate through all active notification_filters objects
for filter in notification_filters:
print("Iterate through all active notification_filters objects")

# Get filtered offers
offers = OfferService(db).get_all(
Expand All @@ -30,17 +32,21 @@ def create_notifications(db: Session = Depends(get_db)) -> None:
floor=filter.floor,
query=filter.query,
)
print("Get filtered offers")

notification_input = NotificationInput(
user_id=filter.user_id,
title=f"New offers for {filter.category}",
message=f"There are {len(offers.offers)} offers for {filter.category}",
)
print("Create notification object")

notification_service = NotificationService(db)
notification_object = notification_service.create(notification_input)
notification_id = notification_object.id

# Update notification object with offers
print("Update notification object with offers")
offers_ids = [offer.id for offer in offers.offers]
notification_service.update_offers(notification_id, offers_ids)

Expand Down
17 changes: 6 additions & 11 deletions app/repositories/notification_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,27 @@ def create(self, notification: NotificationInput) -> NotificationOutput:
self.session.add(new_notification)
self.session.commit()
self.session.refresh(new_notification)
return NotificationOutput(
**new_notification.__dict__,
id=new_notification.id,
created_at=new_notification.created_at,
read=new_notification.read,
)
return NotificationOutput(**new_notification.__dict__)

def get_all_by_user_id(self, user_id: UUID4) -> List[NotificationOutput]:
notifications = self.session.query(Notification).filter(Notification.user_id == user_id).all()
return [
NotificationOutput(**notification.model_dump(exclude_none=True)) for notification in notifications
NotificationOutput(**notification.__dict__) for notification in notifications
]

def get_by_id(self, _id: UUID4) -> NotificationOutput:
notification = self.session.query(Notification).filter(Notification.id == id).first()
notification = self.session.query(Notification).filter(Notification.id == _id).first()
return NotificationOutput(**notification.model_dump(exclude_none=True))

def notification_exists_by_id(self, _id: UUID4) -> bool:
notification = self.session.query(Notification).filter(Notification.id == id).first()
notification = self.session.query(Notification).filter(Notification.id == _id).first()
if notification:
return True
else:
return False

def get_notification(self, _id: UUID4) -> Type[Notification]:
notification = self.session.query(Notification).filter(Notification.id == id).first()
notification = self.session.query(Notification).filter(Notification.id == _id).first()
return notification

def mark_as_read(self, notification: Type[Notification]) -> bool:
Expand All @@ -52,7 +47,7 @@ def mark_as_read(self, notification: Type[Notification]) -> bool:
return True

def get_notification_by_id(self, _id: UUID4) -> NotificationOutput:
notification = self.session.query(Notification).filter(Notification.id == id).first()
notification = self.session.query(Notification).filter(Notification.id == _id).first()
return NotificationOutput(**notification.model_dump(exclude_none=True))

def update_offers(self, notification: Type[Notification], offers: List[Type[Offer]]) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions app/repositories/notificationfilter_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def create(self, data: NotificationFilterInput) -> NotificationFilterOutput:
self.session.add(notification)
self.session.commit()
self.session.refresh(notification)
return NotificationFilterOutput(**notification.dict(), id=notification.id)
return NotificationFilterOutput(**notification.__dict__)

def notification_exists_by_id(self, _id: UUID4) -> bool:
notification = self.session.query(NotificationFilter).filter(NotificationFilter.id == _id).first()
Expand All @@ -33,15 +33,15 @@ def get_notification_by_id(self, _id: UUID4) -> Type[NotificationFilter]:

def get_all(self) -> List[NotificationFilterOutput]:
notifications = self.session.query().all()
return [NotificationFilterOutput(**notification) for notification in notifications]
return [NotificationFilterOutput(**notification.__dict__) for notification in notifications]

def get_all_active(self) -> List[NotificationFilterOutput]:
notifications = self.session.query(NotificationFilter).filter(NotificationFilter.active == True).all()
return [NotificationFilterOutput(**notification) for notification in notifications]
return [NotificationFilterOutput(**notification.__dict__) for notification in notifications]

def get_all_by_user(self, user_id: UUID4) -> List[NotificationFilterOutput]:
notifications = self.session.query(NotificationFilter).filter(NotificationFilter.user_id == user_id).all()
return [NotificationFilterOutput(**notification) for notification in notifications]
return [NotificationFilterOutput(**notification.__dict__) for notification in notifications]

def delete(self, notification: Type[NotificationFilter]) -> bool:
self.session.delete(notification)
Expand Down
15 changes: 9 additions & 6 deletions app/routers/v1/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@

from auth.auth import get_current_user
from config.database import get_db
from schemas.notification_filter import NotificationFilterInput, NotificationFilterOutput, \
from schemas.notification_filter import (
NotificationFilterInput,
NotificationFilterOutput,
NotificationFilterUpdateStatus
)
from schemas.user import UserInDB
from services.notificationfilter_service import NotificationFilterService

router = APIRouter(
prefix="/notification",
tags=["notification"]
tags=["notif"]
)


@router.post("", status_code=201, response_model=NotificationFilterOutput)
@router.post("/filter", status_code=201, response_model=NotificationFilterOutput)
def create(
notification: NotificationFilterInput,
db: Session = Depends(get_db),
Expand All @@ -28,7 +31,7 @@ def create(
return _service.create(notification)


@router.put("/{_id}", status_code=200)
@router.put("/filter/{_id}", status_code=200)
def update_status(
status: NotificationFilterUpdateStatus,
_id: UUID4,
Expand All @@ -39,7 +42,7 @@ def update_status(
return _service.update_status(_id, status.status)


@router.delete("/{_id}", status_code=204)
@router.delete("/filter/{_id}", status_code=204)
def delete(
_id: UUID4,
db: Session = Depends(get_db),
Expand All @@ -49,7 +52,7 @@ def delete(
return _service.delete(_id)


@router.get("", status_code=200, response_model=List[NotificationFilterOutput])
@router.get("/filter", status_code=200, response_model=List[NotificationFilterOutput])
def get_all_by_user(
db: Session = Depends(get_db),
current_user: UserInDB = Depends(get_current_user)
Expand Down
11 changes: 9 additions & 2 deletions app/routers/v1/root.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from notification.tasks import create_notifications
from config.database import get_db

router = APIRouter(
prefix="/health",
Expand All @@ -7,6 +11,9 @@


@router.get("", status_code=200)
def health():
def health(session: Session = Depends(get_db)):
""" Check if the service is running correctly """

create_notifications(session)

return {"status": "ok"}
2 changes: 0 additions & 2 deletions app/services/notification_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,11 @@ def update_offers(self, notification_id: UUID4, offers_id: List[Optional[UUID4]]
HTTPException(status_code=404, detail="Notification not found")

notification = self.repository.get_notification(notification_id)

offers = []
for _id in offers_id:
if _id is not None:
offer = self.offer_repository.get_offer_by_id(_id)
offers.append(offer)

return self.repository.update_offers(notification, offers)

def get_unread_user_count(self, user_id: UUID4) -> int:
Expand Down
2 changes: 1 addition & 1 deletion app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def test_init_db():
def test_init_db() -> None:
offer.Offer.metadata.create_all(bind=engine)
location.Region.metadata.create_all(bind=engine)
location.City.metadata.create_all(bind=engine)
Expand Down
8 changes: 4 additions & 4 deletions app/tests/test_city.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def region(test_get_db):
return region


def test_success_create_city_object(test_get_db, region):
def test_success_create_city_object(test_get_db, region) -> None:
repository = CityRepository(test_get_db)
city = repository.create(CityInput(name="Uć", region_id=region.id))

assert city.name == "Uć"
assert city.region_id == region.id


def test_success_get_all_city_objects(test_get_db, region):
def test_success_get_all_city_objects(test_get_db, region) -> None:
repository = CityRepository(test_get_db)
repository.create(CityInput(name="Uć", region_id=region.id))
repository.create(CityInput(name="Łódź", region_id=region.id))
Expand All @@ -37,7 +37,7 @@ def test_success_get_all_city_objects(test_get_db, region):
assert get_all[1].name == "Łódź"


def test_success_get_all_city_objects_by_region(test_get_db, region):
def test_success_get_all_city_objects_by_region(test_get_db, region) -> None:
repository = CityRepository(test_get_db)
repository.create(CityInput(name="Uć", region_id=region.id))
repository.create(CityInput(name="Łódź", region_id=region.id))
Expand All @@ -48,7 +48,7 @@ def test_success_get_all_city_objects_by_region(test_get_db, region):
assert get_all[1].name == "Łódź"


def test_success_city_exists_by_name(test_get_db, region):
def test_success_city_exists_by_name(test_get_db, region) -> None:
repository = CityRepository(test_get_db)
repository.create(CityInput(name="Uć", region_id=region.id))

Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_favourite.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def offer(test_get_db, city):
return offer


def test_success_create_favourite_object(test_get_db, offer, user):
def test_success_create_favourite_object(test_get_db, offer, user) -> None:
repository = FavouriteRepository(test_get_db)
favourite = repository.create(FavouriteInput(user_id=user.id, offer_id=offer.id))
assert favourite.user_id == user.id
assert favourite.offer_id == offer.id


def test_get_all_by_user(test_get_db, offer, user):
def test_get_all_by_user(test_get_db, offer, user) -> None:
repository = FavouriteRepository(test_get_db)
repository.create(FavouriteInput(user_id=user.id, offer_id=offer.id))

Expand Down
6 changes: 3 additions & 3 deletions app/tests/test_offer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def offer(test_get_db, city):
return data


def test_success_create_offer_scraper_object(test_get_db, offer, city):
def test_success_create_offer_scraper_object(test_get_db, offer, city) -> None:
repository = OfferRepository(test_get_db)
offer = repository.create_scraper(offer, city.id)
assert offer.title == "test offer"
Expand All @@ -72,13 +72,13 @@ def test_success_create_offer_scraper_object(test_get_db, offer, city):
assert offer.photos[0].url == "https://google.com/img123"


def test_success_offer_exists_by_url(test_get_db, offer, city):
def test_success_offer_exists_by_url(test_get_db, offer, city) -> None:
repository = OfferRepository(test_get_db)
offer = repository.create_scraper(offer, city.id)
assert repository.offer_exists_by_url(offer.details_url)


def test_success_offer_exists_by_id(test_get_db, offer, city):
def test_success_offer_exists_by_id(test_get_db, offer, city) -> None:
repository = OfferRepository(test_get_db)
offer = repository.create_scraper(offer, city.id)
assert repository.offer_exists_by_id(offer.id)
10 changes: 5 additions & 5 deletions app/tests/test_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def test_success_create_region_object(test_get_db):
def test_success_create_region_object(test_get_db) -> None:
repository = RegionRepository(test_get_db)
created_region = repository.create(RegionInput(name="Test Region"))
assert created_region.name == "Test Region"


def test_success_get_all_regions_objects(test_get_db):
def test_success_get_all_regions_objects(test_get_db) -> None:
repository = RegionRepository(test_get_db)
repository.create(RegionInput(name="Test Region"))
regions = repository.get_all()
assert len(regions) == 1


def test_success_get_region_by_id(test_get_db):
def test_success_get_region_by_id(test_get_db) -> None:
repository = RegionRepository(test_get_db)
created_region = repository.create(RegionInput(name="Test Region"))
fetched_region = repository.get_by_id(created_region.id)
assert fetched_region.name == "Test Region"


def test_success_update_region(test_get_db):
def test_success_update_region(test_get_db) -> None:
repository = RegionRepository(test_get_db)
created_region = repository.create(RegionInput(name="Test Region"))
region = repository.get_by_id(created_region.id)
Expand All @@ -37,7 +37,7 @@ def test_success_update_region(test_get_db):
assert updated_region.name == "Updated Region"


def test_success_delete_region(test_get_db):
def test_success_delete_region(test_get_db) -> None:
repository = RegionRepository(test_get_db)
created_region = repository.create(RegionInput(name="Test Region"))
region = repository.get_by_id(created_region.id)
Expand Down
10 changes: 5 additions & 5 deletions app/tests/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,24 @@ def user(test_get_db):
return user


def test_success_create_user_object(test_get_db, user):
def test_success_create_user_object(test_get_db, user) -> None:
assert user.email == "[email protected]"
assert user.username == "test_user"
assert user.is_active
assert user.is_superuser == False


def test_success_user_exists_by_email(test_get_db, user):
def test_success_user_exists_by_email(test_get_db, user) -> None:
repository = UserRepository(test_get_db)
assert repository.user_exists_by_email(user.email)


def test_success_user_exists_by_username(test_get_db, user):
def test_success_user_exists_by_username(test_get_db, user) -> None:
repository = UserRepository(test_get_db)
assert repository.user_exists_by_username(user.username)


def test_success_get_user_by_email(test_get_db, user):
def test_success_get_user_by_email(test_get_db, user) -> None:
repository = UserRepository(test_get_db)
user = repository.get_user_by_email(user.email)
assert user.email == "[email protected]"
Expand All @@ -51,7 +51,7 @@ def test_success_get_user_by_email(test_get_db, user):
assert user.is_superuser == False


def test_success_get_user_by_username(test_get_db, user):
def test_success_get_user_by_username(test_get_db, user) -> None:
repository = UserRepository(test_get_db)
user = repository.get_user_by_username(user.username)
assert user.email == "[email protected]"
Expand Down

0 comments on commit 1878686

Please sign in to comment.