Skip to content

Commit

Permalink
Add support for binding parameters by index
Browse files Browse the repository at this point in the history
Signed-off-by: Salil Chandra <[email protected]>
  • Loading branch information
chands10 committed Jan 7, 2025
1 parent 518dd75 commit ec9ddc5
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 14 deletions.
17 changes: 13 additions & 4 deletions comdb2/_ccdb2.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,24 @@ cdef class Handle(object):
param_guards = []
try:
if parameters is not None:
for key, val in parameters.items():
ckey = _string_as_bytes(key)
bind_by_index = isinstance(parameters, (list, tuple))
items = enumerate(parameters, 1) if bind_by_index \
else parameters.items()
for key, val in items:
ckey = _string_as_bytes(key) if not bind_by_index else key
cval = _ParameterValue(val, key)
param_guards.append(ckey)
param_guards.append(cval)
if cval.list_size == -1:
rc = lib.cdb2_bind_param(self.hndl, <char*>ckey,
cval.type, cval.data, cval.size)
if bind_by_index:
rc = lib.cdb2_bind_index(self.hndl, ckey,
cval.type, cval.data, cval.size)
else:
rc = lib.cdb2_bind_param(self.hndl, <char*>ckey,
cval.type, cval.data, cval.size)
else:
if bind_by_index:
raise ValueError("Binding arrays by index is currently unsupported. Bind arrays by name.")
# Bind Array if cval is an array
rc = lib.cdb2_bind_array(self.hndl, <char*>ckey, cval.type, cval.data, cval.list_size, cval.size)
_errchk(rc, self.hndl)
Expand Down
1 change: 1 addition & 0 deletions comdb2/_cdb2api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ cdef extern from "cdb2api.h" nogil:
void* cdb2_column_value(cdb2_hndl_tp* hndl, int col) except +
const char* cdb2_errstr(cdb2_hndl_tp* hndl) except +
int cdb2_bind_param(cdb2_hndl_tp *hndl, const char *name, int type, const void *varaddr, int length) except +
int cdb2_bind_index(cdb2_hndl_tp *hndl, int index, int type, const void *varaddr, int length) except +
int cdb2_bind_array(cdb2_hndl_tp *hndl, const char *name, cdb2_coltype, const void *varaddr, size_t count, size_t typelen) except +
int cdb2_clearbindings(cdb2_hndl_tp *hndl) except +
int cdb2_clear_ack(cdb2_hndl_tp *hndl) except +
20 changes: 18 additions & 2 deletions comdb2/cdb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@
examples we make use of the `list` constructor to turn the iterable returned
by `Handle.execute` into a list of result rows.
You can also bind by index by providing a sequence
instead of by name with placeholders specified using ``?`` in sequence.
Note that binding an array by index is currently not supported. For example:
>>> query = "select 25 between ? and ?"
>>> print(list(hndl.execute(query, [20, 42])))
[[1]]
Types
-----
Expand Down Expand Up @@ -359,7 +367,7 @@ def row_factory(
def execute(
self,
sql: str | bytes,
parameters: Mapping[str, ParameterValue] | None = None,
parameters: Mapping[str, ParameterValue] | Sequence[ParameterValue] | None = None,
*,
column_types: Sequence[ColumnType] | None = None,
) -> Handle:
Expand All @@ -369,7 +377,7 @@ def execute(
This should always be the preferred method of parameterizing the SQL
query, as it prevents SQL injection vulnerabilities and is faster.
Placeholders for named parameters must be in Comdb2's native format,
``@param_name``.
``@param_name``, or with ``?`` for positional parameters.
If ``column_types`` is provided and non-empty, it must be a sequence of
members of the `ColumnType` enumeration. The database will coerce the
Expand All @@ -383,6 +391,8 @@ def execute(
sql (str): The SQL string to execute.
parameters (Mapping[str, Any]): An optional mapping from parameter
names to the values to be bound for them.
(Sequence[Any]): Can also use sequence with ``?`` with parameters executed in sequence.
Note that binding arrays by index is currently not supported. These must be bound by name.
column_types (Sequence[int]): An optional sequence of types (values
of the `ColumnType` enumeration) which the columns of the
result set will be coerced to.
Expand All @@ -403,6 +413,12 @@ def execute(
... print(row)
[1, 2]
[2, 4]
>>> for row in hndl.execute("select 1, 2 UNION ALL select ?, ?",
... [2, 4]):
... print(row)
[1, 2]
[2, 4]
"""
if parameters is None:
parameters = {}
Expand Down
41 changes: 34 additions & 7 deletions comdb2/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,23 @@
When we run the same query with parameter ``b`` bound to ``23``, a ``0`` is
returned instead, because ``20 <= 25 <= 23`` is false.
Alternatively, you can also bind by index instead of by name, by providing a list/tuple
with placeholders specified using ``?`` in the same order as the elements in the list/tuple.
Note that binding an array by index is currently not supported, these must be bound by name. For example:
>>> query = "select 25 between ? and ?"
>>> print(conn.cursor().execute(query, [20, 42]).fetchall())
[[1]]
In this example, we execute the query with the first ``?`` bound to 20 and the second
``?`` bound to 42. Thus, a ``1`` is returned like in the previous example.
Note:
Because parameters are bound using ``%(name)s``, other ``%`` signs in
Because parameters by name are bound using ``%(name)s``, other ``%`` signs in
a query must be escaped. For example, ``WHERE name like 'M%'`` becomes
``WHERE name LIKE 'M%%'``.
However, this does not apply when binding parameters by index. ``%`` does not
need to be escaped in this case, and only in this case.
Types
-----
Expand Down Expand Up @@ -314,6 +327,9 @@
Because SQL strings for this module use the ``pyformat`` placeholder style,
any literal ``%`` characters in a query must be escaped by doubling them.
``WHERE name like 'M%'`` becomes ``WHERE name LIKE 'M%%'``.
This module also has support for ``qmark`` if binding by index.
``%`` does not need to be escaped if binding by position.
"""

_FIRST_WORD_OF_STMT = re.compile(
Expand Down Expand Up @@ -996,7 +1012,7 @@ def callproc(self, procname: str, parameters: Sequence[ParameterValue]) -> Seque
def execute(
self,
sql: str,
parameters: Mapping[str, ParameterValue] | None = None,
parameters: Mapping[str, ParameterValue] | Sequence[ParameterValue] | None = None,
*,
column_types: Sequence[ColumnType] | None = None,
) -> Cursor:
Expand All @@ -1005,6 +1021,8 @@ def execute(
The ``sql`` string must be provided as a Python format string, with
parameter placeholders represented as ``%(name)s`` and all other ``%``
signs escaped as ``%%``.
HOWEVER, if binding by index (parameter placeholders represented as ``?``),
``%`` does not need to be escaped. This is the only time it does not need to be escaped.
Note:
Using placeholders should always be the preferred method of
Expand All @@ -1029,6 +1047,8 @@ def execute(
sql (str): The SQL string to execute, as a Python format string.
parameters (Mapping[str, Any]): An optional mapping from parameter
names to the values to be bound for them.
(Sequence[Any]): Can also use sequence with ``?`` with parameters executed in sequence.
Note that binding arrays by index is currently not supported. These must be bound by name.
column_types (Sequence[int]): An optional sequence of types (values
of the `ColumnType` enumeration) which the columns of the
result set will be coerced to.
Expand All @@ -1049,6 +1069,11 @@ def execute(
... {'x': 2, 'y': 4})
>>> cursor.fetchall()
[[1, 2], [2, 4]]
>>> cursor.execute("select 1, 2 UNION ALL select ?, ?",
... [2, 4]])
>>> cursor.fetchall()
[[1, 2], [2, 4]]
"""
self._check_closed()
self._description = None
Expand All @@ -1068,7 +1093,7 @@ def execute(
return self

def executemany(
self, sql: str, seq_of_parameters: Sequence[Mapping[str, ParameterValue]]
self, sql: str, seq_of_parameters: Sequence[Mapping[str, ParameterValue]] | Sequence[Sequence[ParameterValue]]
) -> None:
"""Execute the same SQL statement repeatedly with different parameters.
Expand All @@ -1078,10 +1103,10 @@ def executemany(
Args:
sql (str): The SQL string to execute, as a Python format string of
the format expected by `execute`.
seq_of_parameters (Sequence[Mapping[str, Any]]): A sequence of
seq_of_parameters (Sequence[Mapping[str, Any]] | Sequence[Sequence[Any]]): A sequence of
mappings from parameter names to the values to be bound for
them. The ``sql`` statement will be run once per element in
this sequence.
them or a sequence of a sequence of parameter values if binding by index.
The ``sql`` statement will be run once per element in this sequence.
"""
self._check_closed()
for parameters in seq_of_parameters:
Expand All @@ -1105,7 +1130,9 @@ def _execute(self, operation, sql, parameters=None, *, column_types=None):
try:
# If variable interpolation fails, then translate the exception to
# an InterfaceError to signal that it's a client-side problem.
sql = sql % {name: "@" + name for name in parameters}
# If binding by index then no need to modify sql
if not isinstance(parameters, (list, tuple)):
sql = sql % {name: "@" + name for name in parameters}
except KeyError as keyerr:
msg = "No value provided for parameter %s" % keyerr
raise InterfaceError(msg) from keyerr
Expand Down
16 changes: 15 additions & 1 deletion tests/test_cdb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ def test_binding_parameters():
hndl.execute("insert into simple(key, val) values(@k, @v)", dict(k=3, v=4))
assert hndl.get_effects()[0] == 1

hndl.execute("insert into simple(key, val) values(?, ?)", [5, 6])
assert hndl.get_effects()[0] == 1

rows = list(hndl.execute("select key, val from simple order by key"))
assert rows == [[1, 2], [3, 4]]
assert rows == [[1, 2], [3, 4], [5, 6]]


def test_commit_failures():
Expand Down Expand Up @@ -304,6 +307,17 @@ def test_parameter_name_in_binding_errors_noexception():
)


def test_binding_array_by_index():
hndl = cdb2.Handle("mattdb", "dev")

with pytest.raises(Exception) as exc:
hndl.execute("select * from carray(?)", [[1, 2, 3]])

assert exc.value.args[0] == (
"Binding arrays by index is currently unsupported. Bind arrays by name."
)


def test_specifying_column_types():
# GIVEN
hndl = cdb2.Handle("mattdb", "dev")
Expand Down
63 changes: 63 additions & 0 deletions tests/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DataError,
Datetime,
DatetimeUs,
Error,
ForeignKeyConstraintError,
IntegrityError,
InterfaceError,
Expand Down Expand Up @@ -263,6 +264,11 @@ def test_unescaped_percent():
with pytest.raises(InterfaceError):
cursor.execute("select 1%2")

# no escaped percent with binding sequence
cursor.execute("select ? % ?", [5, 3]) # Should work
with pytest.raises(Error):
cursor.execute("select ? %% ?", [5, 3])


def test_reading_and_writing_datetimes():
conn = connect("mattdb", "dev")
Expand Down Expand Up @@ -377,6 +383,22 @@ def test_all_datatypes_as_parameters():
assert row == list(v for k, v in params)
assert cursor.fetchone() is None

cursor.execute("delete from all_datatypes")
conn.commit()

cursor.execute(
"insert into all_datatypes(" + ", ".join(COLUMN_LIST) + ")"
" values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
list(v for k, v in params),
)

conn.commit()

cursor.execute("select * from all_datatypes")
row2 = cursor.fetchone()
assert row2 == row
assert cursor.fetchone() is None


def test_naive_datetime_as_parameter():
conn = connect("mattdb", "dev")
Expand Down Expand Up @@ -414,6 +436,22 @@ def test_naive_datetime_as_parameter():
assert row == [Datetime(2009, 2, 13, 18, 31, 30, 234000, pytz.UTC)]
assert cursor.fetchone() is None

cursor.execute("delete from all_datatypes")
conn.commit()

cursor.execute(
"insert into all_datatypes(" + ", ".join(COLUMN_LIST) + ")"
" values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
list(v for k, v in params),
)

conn.commit()

cursor.execute("select datetime_col from all_datatypes")
row2 = cursor.fetchone()
assert row2 == row
assert cursor.fetchone() is None


def test_datetime_with_non_olson_tzname():
conn = connect("mattdb", "dev")
Expand All @@ -431,6 +469,12 @@ def test_datetime_with_non_olson_tzname():
assert row[1].tzname() == "UTC"
assert row[1] == edt_dt

row = cursor.execute("select ?, ?", [est_dt, edt_dt]).fetchone()
assert row[0].tzname() == "UTC"
assert row[0] == est_dt
assert row[1].tzname() == "UTC"
assert row[1] == edt_dt


def test_rounding_datetime_to_nearest_millisecond():
conn = connect("mattdb", "dev")
Expand All @@ -443,10 +487,16 @@ def test_rounding_datetime_to_nearest_millisecond():
cursor.execute("select @date", {"date": curr_microsecond})
assert cursor.fetchall() == [[prev_millisecond]]

cursor.execute("select ?", [curr_microsecond])
assert cursor.fetchall() == [[prev_millisecond]]

curr_microsecond += datetime.timedelta(microseconds=1)
cursor.execute("select @date", {"date": curr_microsecond})
assert cursor.fetchall() == [[next_millisecond]]

cursor.execute("select ?", [curr_microsecond])
assert cursor.fetchall() == [[next_millisecond]]


def test_cursor_description():
conn = connect("mattdb", "dev")
Expand Down Expand Up @@ -501,6 +551,9 @@ def test_binding_number_that_overflows_long_long():
with pytest.raises(DataError):
cursor.execute("select @i", dict(i=2**64 + 1))

with pytest.raises(DataError):
cursor.execute("select ?", [2**64 + 1])


def test_retrieving_null():
conn = connect("mattdb", "dev")
Expand Down Expand Up @@ -969,6 +1022,16 @@ def test_parameter_binding_invalid_arrays(values, exc_msg):
cursor.execute("select * from carray(%(values)s)", dict(values=values))


def test_parameter_binding_arrays_by_index():
# GIVEN
conn = connect("mattdb", "dev")
cursor = conn.cursor()

# WHEN/THEN
with pytest.raises(ValueError) as exc:
cursor.execute("select * from carray(?)", [[1, 2, 3]])


def test_specifying_column_types():
# GIVEN
conn = connect("mattdb", "dev")
Expand Down

0 comments on commit ec9ddc5

Please sign in to comment.