From 6845ad57e8a964d4c5f657e67dd6a21322a6721d Mon Sep 17 00:00:00 2001 From: Mish Savelyev <1564970+sausage-todd@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:38:26 +0200 Subject: [PATCH] Wrap all sqlalchemy usages of session into a block (#1706) --- .../crowd/backend/repository/repository.py | 140 +++++++++--------- 1 file changed, 72 insertions(+), 68 deletions(-) diff --git a/backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository/repository.py b/backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository/repository.py index cc7bfc76b2..e49b65f1dc 100644 --- a/backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository/repository.py +++ b/backend/src/serverless/microservices/python/crowd-backend/crowd/backend/repository/repository.py @@ -70,8 +70,7 @@ def __init__(self, tenant_id="", db_url=False, test=False, send=True): ) Base.metadata.create_all(self.engine, checkfirst=True) - Session = sessionmaker(bind=self.engine) - self.session = Session() + self.Session = sessionmaker(bind=self.engine) self.tenant_id = tenant_id self.send = send @@ -106,25 +105,25 @@ def find_in_table(self, table, query, many=False): dict: document """ - search_query = self.session.query(table) - for attr, value in query.items(): - - # Check if query is nested - nested_count = attr.count(".") - # If nested - if nested_count > 0: - attributes = attr.split(".") - nested_attributes = tuple(attributes[1:]) - # Define nested expression - expr = getattr(table, attributes[0])[nested_attributes] - # Execute search_query - search_query = search_query.filter(expr == json.dumps(value)) - else: - search_query = search_query.filter(getattr(table, attr) == value) + with self.Session() as session: + search_query = session.query(table) + for attr, value in query.items(): + # Check if query is nested + nested_count = attr.count(".") + # If nested + if nested_count > 0: + attributes = attr.split(".") + nested_attributes = tuple(attributes[1:]) + # Define nested expression + expr = getattr(table, attributes[0])[nested_attributes] + # Execute search_query + search_query = search_query.filter(expr == json.dumps(value)) + else: + search_query = search_query.filter(getattr(table, attr) == value) - if many: - return search_query.all() - return search_query.first() + if many: + return search_query.all() + return search_query.first() def find_by_id(self, table, id): """ @@ -138,7 +137,8 @@ def find_by_id(self, table, id): dict: the document """ - return self.session.query(table).get(id) + with self.Session() as session: + return session.query(table).get(id) def find_all_usernames(self): with self.engine.connect() as con: @@ -146,7 +146,8 @@ def find_all_usernames(self): f"""select m."id", mw."username", m."displayName", m."emails" from "members" m inner join "memberActivityAggregatesMVs" mw on m.id = mw.id - where m."tenantId" = '{self.tenant_id}'""").fetchall() + where m."tenantId" = '{self.tenant_id}'""" + ).fetchall() def find_all( self, table, ignore_tenant: "bool" = False, query: "dict" = None, order: "dict" = None @@ -173,29 +174,30 @@ def find_all( **{dbk.TENANT: uuid.UUID(self.tenant_id)}, } - search_query = self.session.query(table) - for attr, value in query.items(): - # Check if query is nested - nested_count = attr.count(".") - # If nested - if nested_count > 0: - attributes = attr.split(".") - nested_attributes = tuple(attributes[1:]) - # Define nested expression - expr = getattr(table, attributes[0])[nested_attributes] - # Execute search_query - search_query = search_query.filter(expr == json.dumps(value)) - else: - search_query = search_query.filter(getattr(table, attr) == value) - - if order: - for key, value in order.items(): - if value: - search_query = search_query.order_by(asc(key)) + with self.Session() as session: + search_query = session.query(table) + for attr, value in query.items(): + # Check if query is nested + nested_count = attr.count(".") + # If nested + if nested_count > 0: + attributes = attr.split(".") + nested_attributes = tuple(attributes[1:]) + # Define nested expression + expr = getattr(table, attributes[0])[nested_attributes] + # Execute search_query + search_query = search_query.filter(expr == json.dumps(value)) else: - search_query = search_query.order_by(desc(key)) + search_query = search_query.filter(getattr(table, attr) == value) - return search_query.all() + if order: + for key, value in order.items(): + if value: + search_query = search_query.order_by(asc(key)) + else: + search_query = search_query.order_by(desc(key)) + + return search_query.all() def find_activities(self, search_filters=None): if not search_filters: @@ -208,22 +210,23 @@ def count(self, table, search_filters=None): search_filters[dbk.TENANT] = uuid.UUID(self.tenant_id) - search_query = self.session.query(table) - for attr, value in search_filters.items(): - # Check if query is nested - nested_count = attr.count(".") - # If nested - if nested_count > 0: - attributes = attr.split(".") - nested_attributes = tuple(attributes[1:]) - # Define nested expression - expr = getattr(table, attributes[0])[nested_attributes] - # Execute query - search_query = search_query.filter(expr == json.dumps(value)) - else: - search_query = search_query.filter(getattr(table, attr) == value) + with self.Session() as session: + search_query = session.query(table) + for attr, value in search_filters.items(): + # Check if query is nested + nested_count = attr.count(".") + # If nested + if nested_count > 0: + attributes = attr.split(".") + nested_attributes = tuple(attributes[1:]) + # Define nested expression + expr = getattr(table, attributes[0])[nested_attributes] + # Execute query + search_query = search_query.filter(expr == json.dumps(value)) + else: + search_query = search_query.filter(getattr(table, attr) == value) - return search_query.count() + return search_query.count() def find_available_microservices(self, service): """ @@ -253,16 +256,17 @@ def find_new_members(self, microservice, query: "dict" = None) -> "list[dict]": **{dbk.TENANT: uuid.UUID(self.tenant_id)}, } - search_query = self.session.query(Member) + with self.Session() as session: + search_query = session.query(Member) - # Filter with query - for attr, value in query.items(): - search_query = search_query.filter(getattr(Member, attr) == value) + # Filter with query + for attr, value in query.items(): + search_query = search_query.filter(getattr(Member, attr) == value) - # Find members that are new - # We use a security padding of 5 minutes - search_query = search_query.filter( - Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5)) - ).order_by(Member.createdAt.desc()) + # Find members that are new + # We use a security padding of 5 minutes + search_query = search_query.filter( + Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5)) + ).order_by(Member.createdAt.desc()) - return search_query.all() + return search_query.all()