Skip to content

Commit

Permalink
fixed GAE support
Browse files Browse the repository at this point in the history
  • Loading branch information
mdipierro committed Dec 9, 2024
1 parent 57bec7d commit f725019
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pydal/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _parse(
else:
#: fields[j] may be None if only 'colnames' was specified in db.executesql()
field = fields[j]
f_itype, ftype = (field and [field._itype, field.type] or [None, None])
f_itype, ftype = field and [field._itype, field.type] or [None, None]
value = self.parse_value(value, f_itype, ftype, blob_decode)
# for aliased fields use the aliased name
if isinstance(field, Expression) and field.op == self.dialect._as:
Expand Down
37 changes: 13 additions & 24 deletions pydal/adapters/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
from .mysql import MySQL
from .postgres import PostgrePsyco

try:
from google.cloud import firestore
except ImportError:
pass
else:
from google.cloud.firestore_v1 import aggregation


class GoogleMigratorMixin(object):
migrator_cls = InDBMigrator

Expand Down Expand Up @@ -150,23 +158,9 @@ class Firestore(NoSQLAdapter):

def _initialize_(self):

import firebase_admin
from firebase_admin import credentials, firestore
super(Firestore, self)._initialize_()
args = parse_qs(self.uri.split("//")[1]) if "//" in self.uri else {}

# for thread safety recycle the app in cache
if "cred" in args:
# not on google app engine
cred = credentials.Certificate(args["cred"][0])
else:
# on google app engine
cred = credentials.ApplicationDefault()
try:
app = firebase_admin.initialize_app(cred, args.get("name", [None])[0])
except ValueError:
app = firebase_admin.get_app()
self._client = firestore.client(app, args.get("id", [None])[0])
self._client = firestore.Client(project=args.get("id", [None])[0])

def find_driver(self):
return
Expand Down Expand Up @@ -208,8 +202,6 @@ def represent(self, obj, field_type, tablename=None):

def apply_filter(self, source, table, query):

from google.cloud.firestore_v1.base_query import FieldFilter, BaseCompositeFilter

if isinstance(query, Query) and query.first is table._id:
if query.op.__name__ == "eq":
return source.document(str(query.second)).get()
Expand All @@ -229,8 +221,6 @@ def apply_filter(self, source, table, query):

def get_docs(self, table, query, orderby=None, limitby=None):

from firebase_admin import firestore

source = self._client.collection(table._tablename)
source = self.apply_filter(source, table, query)

Expand Down Expand Up @@ -299,8 +289,6 @@ def select(self, query, fields, attributes):

def count(self, query, distinct=None, limit=None):
# OK
from google.cloud.firestore_v1 import aggregation

if distinct:
raise RuntimeError("COUNT DISTINCT not supported")
table = self.get_table(query)
Expand Down Expand Up @@ -342,7 +330,6 @@ def update(self, table, query, update_fields):
return counter

def truncate(self, table, mode=""):
# OK
def delete_collection(coll_ref, batch_size):
if batch_size == 0:
return
Expand All @@ -352,7 +339,8 @@ def delete_collection(coll_ref, batch_size):
for doc in docs:
batch.delete(doc)
deleted = deleted + 1
batch.commit()
if deleted:
batch.commit()
if deleted >= batch_size:
return delete_collection(coll_ref, batch_size)

Expand Down Expand Up @@ -390,5 +378,6 @@ def bulk_insert(self, table, items):
rid = Reference(id)
rid._table, rid._recor = table, None
ids.append(rid)
batch.commit()
if items:
batch.commit()
return ids
1 change: 1 addition & 0 deletions pydal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
)
)


# TESTING
class MetaDAL(type):
def __call__(cls, *args, **kwargs):
Expand Down
62 changes: 48 additions & 14 deletions pydal/dialects/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from .base import NoSQLDialect

try:
from firebase_admin import firestore
from google.cloud.firestore_v1.base_query import FieldFilter, Or
from google.cloud import firestore
from google.cloud.firestore_v1.base_query import (
BaseCompositeFilter,
FieldFilter,
Or,
)
except ImportError:
pass

Expand All @@ -13,34 +17,64 @@
class FirestoreDialect(NoSQLDialect):

def _and(self, first, second, query_env={}):
filters = first if isinstance(first, list) else [first]
filters += second if isinstance(second, list) else [second]
return filters
a = self.expand(first, query_env=query_env)
b = self.expand(second, query_env=query_env)
filters = (
a.filters
if isinstance(a, BaseCompositeFilter) and a.operator == "AND"
else [a]
)
filters += (
b.filters
if isinstance(b, BaseCompositeFilter) and a.operator == "AND"
else [b]
)
return BaseCompositeFilter("AND", filters)

def _or(self, first, second, query_env={}):
a = self.expand(first, query_env=query_env)
b = self.expand(second, query_env=query_env)
filters = a.filters if isinstance(a, Or) else [a]
filters += b.filters if isinstance(b, Or) else [b]
return Or(filters)
filters = (
a.filters
if isinstance(a, BaseCompositeFilter) and a.operator == "OR"
else [a]
)
filters += (
b.filters
if isinstance(b, BaseCompositeFilter) and a.operator == "OR"
else [b]
)
return BaseCompositeFilter("OR", filters)

def eq(self, first, second=None, query_env={}):
return FieldFilter(first.name, "==", second)
return FieldFilter(
first.name, "==", self.expand(second, first.type, query_env=query_env)
)

def ne(self, first, second=None, query_env={}):
return FieldFilter(first.name, "!=", second)
return FieldFilter(
first.name, "!=", self.expand(second, first.type, query_env=query_env)
)

def lt(self, first, second=None, query_env={}):
return FieldFilter(first.name, "<", second)
return FieldFilter(
first.name, "<", self.expand(second, first.type, query_env=query_env)
)

def lte(self, first, second=None, query_env={}):
return FieldFilter(first.name, "<=", second)
return FieldFilter(
first.name, "<=", self.expand(second, first.type, query_env=query_env)
)

def gt(self, first, second=None, query_env={}):
return FieldFilter(first.name, ">", second)
return FieldFilter(
first.name, ">", self.expand(second, first.type, query_env=query_env)
)

def gte(self, first, second=None, query_env={}):
return FieldFilter(first.name, ">=", second)
return FieldFilter(
first.name, ">=", self.expand(second, first.type, query_env=query_env)
)

def invert(self, first, query_env={}):
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion pydal/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
DRIVERS = {}

try:
from firebase_admin import credentials, firestore
from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

DRIVERS["firestore"] = firestore
Expand Down
2 changes: 1 addition & 1 deletion pydal/parsers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _json(self, value):

@for_type("blob")
def _json(self, value):
return base64.b64encode(value)
return value

@before_parse("reference")
def reference_extras(self, field_type):
Expand Down
2 changes: 1 addition & 1 deletion pydal/representers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _json(self, value):

@for_type("blob")
def _blob(self, value):
return base64.b64decode(value)
return value

@for_type("reference")
def _reference(self, value):
Expand Down
1 change: 1 addition & 0 deletions pydal/restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import re
import traceback

from .utils import utcnow

__version__ = "0.1"
Expand Down
1 change: 1 addition & 0 deletions pydal/tools/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pydal import DAL, Field
from pydal.validators import IS_IN_SET

from ..utils import utcnow


Expand Down
5 changes: 4 additions & 1 deletion pydal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
:license: BSD, see LICENSE for more details.
"""

import datetime
import re
import warnings
import datetime


class RemovedInNextVersionWarning(DeprecationWarning):
pass


warnings.simplefilter("always", RemovedInNextVersionWarning)


def utcnow():
"""returns the current time in utc"""
return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


def warn_of_deprecation(old_name, new_name, prefix=None, stack=2):
msg = "%(old)s is deprecated, use %(new)s instead."
if prefix:
Expand Down
8 changes: 5 additions & 3 deletions pydal/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,13 +803,15 @@ def validate(self, value, record_id=None):
db = self.dbset.db
table = db[tablename]
field = table[fieldname]
dbset = self.dbset(field == value, ignore_common_filters=self.ignore_common_filters)
dbset = self.dbset(
field == value, ignore_common_filters=self.ignore_common_filters
)

# make sure exclude the record_id
id = record_id or self.record_id
if isinstance(id, dict):
id = table(**id)
record = dbset.select(table._id, limitby=(0,1)).first()
record = dbset.select(table._id, limitby=(0, 1)).first()
if record and record[table._id.name] != id:
raise ValidationError(self.translator(self.error_message))
return value
Expand Down Expand Up @@ -1352,7 +1354,7 @@ def __init__(self, error_message="Invalid emails: %s"):

def validate(self, value, record_id=None):
bad_emails = []
f = IS_EMAIL()
f = IS_EMAIL()
if isinstance(value, str):
emails = re.findall(self.REGEX_NOT_EMAIL_SPLITTER, value)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def test_IS_NOT_IN_DB(self):
db = DAL("sqlite:memory")
db.define_table("person", Field("name"), Field("nickname"))
db.person.insert(name="george")
db.person.insert(name="costanza", nickname="T Bone")
costanza_id = db.person.insert(name="costanza", nickname="T Bone")
rtn = IS_NOT_IN_DB(db, "person.name", error_message="oops")("george")
self.assertEqual(rtn, ("george", "oops"))
rtn = IS_NOT_IN_DB(
Expand All @@ -440,7 +440,7 @@ def test_IS_NOT_IN_DB(self):
rtn = IS_NOT_IN_DB(db, db.person, error_message="oops")(1)
self.assertEqual(rtn, (1, "oops"))
vldtr = IS_NOT_IN_DB(db, "person.name", error_message="oops")
vldtr.set_self_id({"name": "costanza", "nickname": "T Bone"})
vldtr.set_self_id(costanza_id)
rtn = vldtr("george")
self.assertEqual(rtn, ("george", "oops"))
rtn = vldtr("costanza")
Expand Down

0 comments on commit f725019

Please sign in to comment.