Skip to content

Commit

Permalink
feat: SQLAlchemy v2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
smotornyuk authored and duttonw committed Oct 15, 2024
1 parent 8535a18 commit f2e8f86
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 58 deletions.
31 changes: 18 additions & 13 deletions ckanext/xloader/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def init(config, echo=False):
global ENGINE, _METADATA, JOBS_TABLE, METADATA_TABLE, LOGS_TABLE
db_uri = config.get('ckanext.xloader.jobs_db.uri',
'sqlite:////tmp/xloader_jobs.db')
ENGINE = sqlalchemy.create_engine(db_uri, echo=echo, convert_unicode=True)
_METADATA = sqlalchemy.MetaData(ENGINE)
ENGINE = sqlalchemy.create_engine(db_uri, echo=echo)
_METADATA = sqlalchemy.MetaData()
JOBS_TABLE = _init_jobs_table()
METADATA_TABLE = _init_metadata_table()
LOGS_TABLE = _init_logs_table()
Expand Down Expand Up @@ -111,8 +111,10 @@ def get_job(job_id):
if job_id:
job_id = six.text_type(job_id)

result = ENGINE.execute(
JOBS_TABLE.select().where(JOBS_TABLE.c.job_id == job_id)).first()
with ENGINE.connect() as conn:
result = conn.execute(
JOBS_TABLE.select().where(JOBS_TABLE.c.job_id == job_id)
).first()

if not result:
return None
Expand Down Expand Up @@ -298,10 +300,11 @@ def _update_job(job_id, job_dict):
if "data" in job_dict:
job_dict["data"] = six.text_type(job_dict["data"])

ENGINE.execute(
JOBS_TABLE.update()
.where(JOBS_TABLE.c.job_id == job_id)
.values(**job_dict))
with ENGINE.begin() as conn:
conn.execute(
JOBS_TABLE.update()
.where(JOBS_TABLE.c.job_id == job_id)
.values(**job_dict))


def mark_job_as_completed(job_id, data=None):
Expand Down Expand Up @@ -443,9 +446,10 @@ def _get_metadata(job_id):
# warnings.
job_id = six.text_type(job_id)

results = ENGINE.execute(
METADATA_TABLE.select().where(
METADATA_TABLE.c.job_id == job_id)).fetchall()
with ENGINE.connect() as conn:
results = conn.execute(
METADATA_TABLE.select().where(
METADATA_TABLE.c.job_id == job_id)).fetchall()
metadata = {}
for row in results:
value = row['value']
Expand All @@ -461,8 +465,9 @@ def _get_logs(job_id):
# warnings.
job_id = six.text_type(job_id)

results = ENGINE.execute(
LOGS_TABLE.select().where(LOGS_TABLE.c.job_id == job_id)).fetchall()
with ENGINE.connect() as conn:
results = conn.execute(
LOGS_TABLE.select().where(LOGS_TABLE.c.job_id == job_id)).fetchall()

results = [dict(result) for result in results]

Expand Down
73 changes: 39 additions & 34 deletions ckanext/xloader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from six.moves import zip
from tabulator import config as tabulator_config, EncodingError, Stream, TabulatorException
from unidecode import unidecode
import sqlalchemy as sa

import ckan.plugins as p

Expand Down Expand Up @@ -118,8 +119,8 @@ def _clear_datastore_resource(resource_id):
'''
engine = get_write_engine()
with engine.begin() as conn:
conn.execute("SET LOCAL lock_timeout = '5s'")
conn.execute('TRUNCATE TABLE "{}"'.format(resource_id))
conn.execute(sa.text("SET LOCAL lock_timeout = '5s'"))
conn.execute(sa.text('TRUNCATE TABLE "{}"'.format(resource_id)))


def load_csv(csv_filepath, resource_id, mimetype='text/csv', logger=None):
Expand Down Expand Up @@ -253,12 +254,17 @@ def load_csv(csv_filepath, resource_id, mimetype='text/csv', logger=None):
except Exception as e:
raise LoaderError('Could not create the database table: {}'
.format(e))
connection = context['connection'] = engine.connect()


# datastore_active is switched on by datastore_create - TODO temporarily
# disable it until the load is complete
_disable_fulltext_trigger(connection, resource_id)
_drop_indexes(context, data_dict, False)

with engine.begin() as conn:
_disable_fulltext_trigger(conn, resource_id)

with engine.begin() as conn:
context['connection'] = conn
_drop_indexes(context, data_dict, False)

logger.info('Copying to database...')

Expand All @@ -276,9 +282,8 @@ def load_csv(csv_filepath, resource_id, mimetype='text/csv', logger=None):
# 4. COPY FROM STDIN - not quite as fast as COPY from a file, but avoids
# the superuser issue. <-- picked

raw_connection = engine.raw_connection()
try:
cur = raw_connection.cursor()
with engine.begin() as conn:
cur = conn.connection.cursor()
try:
with open(csv_filepath, 'rb') as f:
# can't use :param for table name because params are only
Expand Down Expand Up @@ -308,15 +313,14 @@ def load_csv(csv_filepath, resource_id, mimetype='text/csv', logger=None):

finally:
cur.close()
finally:
raw_connection.commit()
finally:
os.remove(csv_filepath) # i.e. the tempfile

logger.info('...copying done')

logger.info('Creating search index...')
_populate_fulltext(connection, resource_id, fields=fields)
with engine.begin() as conn:
_populate_fulltext(conn, resource_id, fields=fields)
logger.info('...search index created')

return fields
Expand Down Expand Up @@ -550,9 +554,9 @@ def fulltext_function_exists(connection):
https://github.com/ckan/ckan/pull/3786
or otherwise it is checked on startup of this plugin.
'''
res = connection.execute('''
res = connection.execute(sa.text('''
select * from pg_proc where proname = 'populate_full_text_trigger';
''')
'''))
return bool(res.rowcount)


Expand All @@ -561,24 +565,25 @@ def fulltext_trigger_exists(connection, resource_id):
This will only be the case if your CKAN is new enough to have:
https://github.com/ckan/ckan/pull/3786
'''
res = connection.execute('''
res = connection.execute(sa.text('''
SELECT pg_trigger.tgname FROM pg_class
JOIN pg_trigger ON pg_class.oid=pg_trigger.tgrelid
WHERE pg_class.relname={table}
AND pg_trigger.tgname='zfulltext';
'''.format(
table=literal_string(resource_id)))
table=literal_string(resource_id))))
return bool(res.rowcount)


def _disable_fulltext_trigger(connection, resource_id):
    """Turn off the datastore table's ``zfulltext`` trigger.

    Called before a bulk load so each copied row does not pay the cost of
    per-row full-text indexing; the trigger is re-enabled afterwards by
    _enable_fulltext_trigger and the index rebuilt by _populate_fulltext.

    :param connection: SQLAlchemy connection (v2 style — statements must be
        wrapped in ``sa.text``)
    :param resource_id: name of the datastore table (quoted via identifier())
    """
    # escape_binds=True doubles any "%" in the table name, because the
    # rendered DDL string goes through connection.execute, which would
    # otherwise read "%" as a bind-parameter marker.
    connection.execute(sa.text(
        'ALTER TABLE {table} DISABLE TRIGGER zfulltext;'
        .format(table=identifier(resource_id, True))))


def _enable_fulltext_trigger(connection, resource_id):
    """Re-enable the ``zfulltext`` trigger disabled for the bulk load.

    :param connection: SQLAlchemy connection (v2 style — statements must be
        wrapped in ``sa.text``)
    :param resource_id: name of the datastore table (quoted via identifier())
    """
    # escape_binds=True doubles any "%" in the table name, because the
    # rendered DDL string goes through connection.execute, which would
    # otherwise read "%" as a bind-parameter marker.
    connection.execute(sa.text(
        'ALTER TABLE {table} ENABLE TRIGGER zfulltext;'
        .format(table=identifier(resource_id, True))))


def _populate_fulltext(connection, resource_id, fields):
Expand All @@ -591,23 +596,20 @@ def _populate_fulltext(connection, resource_id, fields):
fields: list of dicts giving the each column's 'id' (name) and 'type'
(text/numeric/timestamp)
'''
sql = \
u'''
UPDATE {table}
SET _full_text = to_tsvector({cols});
'''.format(
# coalesce copes with blank cells
table=identifier(resource_id),
cols=" || ' ' || ".join(
stmt = sa.update(sa.table(resource_id, sa.column("_full_text"))).values(
_full_text=sa.text("to_tsvector({})".format(
" || ' ' || ".join(
'coalesce({}, \'\')'.format(
identifier(field['id'])
+ ('::text' if field['type'] != 'text' else '')
)
for field in fields
if not field['id'].startswith('_')
)
)
connection.execute(sql)
))
)

connection.execute(stmt)


def calculate_record_count(resource_id, logger):
Expand All @@ -619,15 +621,18 @@ def calculate_record_count(resource_id, logger):
logger.info('Calculating record count (running ANALYZE on the table)')
engine = get_write_engine()
conn = engine.connect()
conn.execute("ANALYZE \"{resource_id}\";"
.format(resource_id=resource_id))
conn.execute(sa.text("ANALYZE \"{resource_id}\";"
.format(resource_id=resource_id)))


def identifier(s, escape_binds=False):
    """Return *s* quoted as a SQL identifier.

    Embedded double quotes are doubled per the SQL standard, and NUL
    bytes (which PostgreSQL rejects inside identifiers) are stripped.

    :param s: raw identifier text, e.g. a table or column name
    :param escape_binds: when True, also double any "%" so the result can
        be embedded in a raw SQL string passed to connection.execute,
        which would otherwise treat "%" as a bind-parameter substitution
        marker
    :returns: the double-quoted identifier string
    """
    escaped = s.replace(u'"', u'""').replace(u'\0', '')
    if escape_binds:
        escaped = escaped.replace('%', '%%')

    return u'"' + escaped + u'"'


def literal_string(s):
Expand Down
22 changes: 11 additions & 11 deletions ckanext/xloader/tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import pytest
import six
import sqlalchemy as sa
import sqlalchemy.orm as orm
import datetime
import logging
Expand Down Expand Up @@ -47,17 +48,16 @@ def _get_records(
c = Session.connection()
if exclude_full_text_column:
cols = self._get_column_names(Session, table_name)
cols = ", ".join(
loader.identifier(col) for col in cols if col != "_full_text"
)
cols = [
sa.column(col) for col in cols if col != "_full_text"
]
else:
cols = "*"
sql = 'SELECT {cols} FROM "{table_name}"'.format(
cols=cols, table_name=table_name
)
cols = [sa.text("*")]
stmt = sa.select(*cols).select_from(sa.table(table_name))

if limit is not None:
sql += " LIMIT {}".format(limit)
results = c.execute(sql)
stmt = stmt.limit(limit)
results = c.execute(stmt)
return results.fetchall()

def _get_column_names(self, Session, table_name):
Expand All @@ -71,7 +71,7 @@ def _get_column_names(self, Session, table_name):
ORDER BY ordinal_position;
""".format(table_name)
)
results = c.execute(sql)
results = c.execute(sa.text(sql))
records = results.fetchall()
return [r[0] for r in records]

Expand All @@ -85,7 +85,7 @@ def _get_column_types(self, Session, table_name):
ORDER BY ordinal_position;
""".format(table_name)
)
results = c.execute(sql)
results = c.execute(sa.text(sql))
records = results.fetchall()
return [r[0] for r in records]

Expand Down

0 comments on commit f2e8f86

Please sign in to comment.