From 3066d8ac5e2e3ed6897087528de901d49039134f Mon Sep 17 00:00:00 2001 From: Salil Chandra Date: Wed, 18 Dec 2024 14:51:27 -0500 Subject: [PATCH] Add support for binding parameters by index Signed-off-by: Salil Chandra --- comdb2/_ccdb2.pyx | 17 +++++++--- comdb2/_cdb2api.pxd | 1 + comdb2/cdb2.py | 20 ++++++++++-- comdb2/dbapi2.py | 41 ++++++++++++++++++++---- tests/test_cdb2.py | 16 +++++++++- tests/test_dbapi2.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 14 deletions(-) diff --git a/comdb2/_ccdb2.pyx b/comdb2/_ccdb2.pyx index 4b8e8fc..d10d249 100644 --- a/comdb2/_ccdb2.pyx +++ b/comdb2/_ccdb2.pyx @@ -401,15 +401,24 @@ cdef class Handle(object): param_guards = [] try: if parameters is not None: - for key, val in parameters.items(): - ckey = _string_as_bytes(key) + bind_by_index = isinstance(parameters, (list, tuple)) + items = enumerate(parameters, 1) if bind_by_index \ + else parameters.items() + for key, val in items: + ckey = _string_as_bytes(key) if not bind_by_index else key cval = _ParameterValue(val, key) param_guards.append(ckey) param_guards.append(cval) if cval.list_size == -1: - rc = lib.cdb2_bind_param(self.hndl, ckey, - cval.type, cval.data, cval.size) + if bind_by_index: + rc = lib.cdb2_bind_index(self.hndl, ckey, + cval.type, cval.data, cval.size) + else: + rc = lib.cdb2_bind_param(self.hndl, ckey, + cval.type, cval.data, cval.size) else: + if bind_by_index: + raise ValueError("Binding arrays by index is currently unsupported. 
Bind arrays by name.") # Bind Array if cval is an array rc = lib.cdb2_bind_array(self.hndl, ckey, cval.type, cval.data, cval.list_size, cval.size) _errchk(rc, self.hndl) diff --git a/comdb2/_cdb2api.pxd b/comdb2/_cdb2api.pxd index e061680..4b4e93f 100644 --- a/comdb2/_cdb2api.pxd +++ b/comdb2/_cdb2api.pxd @@ -76,6 +76,7 @@ cdef extern from "cdb2api.h" nogil: void* cdb2_column_value(cdb2_hndl_tp* hndl, int col) except + const char* cdb2_errstr(cdb2_hndl_tp* hndl) except + int cdb2_bind_param(cdb2_hndl_tp *hndl, const char *name, int type, const void *varaddr, int length) except + + int cdb2_bind_index(cdb2_hndl_tp *hndl, int index, int type, const void *varaddr, int length) except + int cdb2_bind_array(cdb2_hndl_tp *hndl, const char *name, cdb2_coltype, const void *varaddr, size_t count, size_t typelen) except + int cdb2_clearbindings(cdb2_hndl_tp *hndl) except + int cdb2_clear_ack(cdb2_hndl_tp *hndl) except + diff --git a/comdb2/cdb2.py b/comdb2/cdb2.py index 01f8c7f..7e07cb5 100644 --- a/comdb2/cdb2.py +++ b/comdb2/cdb2.py @@ -99,6 +99,14 @@ examples we make use of the `list` constructor to turn the iterable returned by `Handle.execute` into a list of result rows. +You can also bind by index by providing a sequence +instead of by name with placeholders specified using ``?`` in sequence. +Note that binding an array by index is currently not supported. For example: + + >>> query = "select 25 between ? and ?" + >>> print(list(hndl.execute(query, [20, 42]))) + [[1]] + Types ----- @@ -359,7 +367,7 @@ def row_factory( def execute( self, sql: str | bytes, - parameters: Mapping[str, ParameterValue] | None = None, + parameters: Mapping[str, ParameterValue] | Sequence[ParameterValue] | None = None, *, column_types: Sequence[ColumnType] | None = None, ) -> Handle: @@ -369,7 +377,7 @@ def execute( This should always be the preferred method of parameterizing the SQL query, as it prevents SQL injection vulnerabilities and is faster. 
Placeholders for named parameters must be in Comdb2's native format, - ``@param_name``. + ``@param_name``; positional parameters use ``?`` placeholders. If ``column_types`` is provided and non-empty, it must be a sequence of members of the `ColumnType` enumeration. The database will coerce the @@ -383,6 +391,8 @@ def execute( sql (str): The SQL string to execute. parameters (Mapping[str, Any]): An optional mapping from parameter names to the values to be bound for them. + (Sequence[Any]): Alternatively, a sequence of values to bind, in order, to ``?`` placeholders. + Note that binding arrays by index is currently not supported; arrays must be bound by name. column_types (Sequence[int]): An optional sequence of types (values of the `ColumnType` enumeration) which the columns of the result set will be coerced to. @@ -403,6 +413,12 @@ def execute( ... print(row) [1, 2] [2, 4] + + >>> for row in hndl.execute("select 1, 2 UNION ALL select ?, ?", + ... [2, 4]): + ... print(row) + [1, 2] + [2, 4] """ if parameters is None: parameters = {} diff --git a/comdb2/dbapi2.py b/comdb2/dbapi2.py index 581b685..3047830 100644 --- a/comdb2/dbapi2.py +++ b/comdb2/dbapi2.py @@ -120,10 +120,23 @@ When we run the same query with parameter ``b`` bound to ``23``, a ``0`` is returned instead, because ``20 <= 25 <= 23`` is false. +Alternatively, you can bind by index instead of by name, by providing a list/tuple +with placeholders specified using ``?`` in the same order as the elements in the list/tuple. +Note that binding an array by index is currently not supported; arrays must be bound by name. For example: + + >>> query = "select 25 between ? and ?" + >>> print(conn.cursor().execute(query, [20, 42]).fetchall()) + [[1]] + +In this example, we execute the query with the first ``?`` bound to 20 and the second +``?`` bound to 42. Thus, a ``1`` is returned like in the previous example.
+ Note: - Because parameters are bound using ``%(name)s``, other ``%`` signs in + Because parameters bound by name use ``%(name)s``, other ``%`` signs in a query must be escaped. For example, ``WHERE name like 'M%'`` becomes ``WHERE name LIKE 'M%%'``. + However, this does not apply when binding parameters by index. ``%`` does not + need to be escaped in this case, and only in this case. Types ----- @@ -314,6 +327,9 @@ Because SQL strings for this module use the ``pyformat`` placeholder style, any literal ``%`` characters in a query must be escaped by doubling them. ``WHERE name like 'M%'`` becomes ``WHERE name LIKE 'M%%'``. + +This module also supports the ``qmark`` placeholder style (``?``) when binding by index. +``%`` does not need to be escaped when binding by index. """ _FIRST_WORD_OF_STMT = re.compile( @@ -996,7 +1012,7 @@ def callproc(self, procname: str, parameters: Sequence[ParameterValue]) -> Seque def execute( self, sql: str, - parameters: Mapping[str, ParameterValue] | None = None, + parameters: Mapping[str, ParameterValue] | Sequence[ParameterValue] | None = None, *, column_types: Sequence[ColumnType] | None = None, ) -> Cursor: @@ -1005,6 +1021,8 @@ def execute( The ``sql`` string must be provided as a Python format string, with parameter placeholders represented as ``%(name)s`` and all other ``%`` signs escaped as ``%%``. + HOWEVER, if binding by index (parameter placeholders represented as ``?``), + ``%`` does not need to be escaped. This is the only time it does not need to be escaped. Note: Using placeholders should always be the preferred method of @@ -1029,6 +1047,8 @@ def execute( sql (str): The SQL string to execute, as a Python format string. parameters (Mapping[str, Any]): An optional mapping from parameter names to the values to be bound for them. + (Sequence[Any]): Alternatively, a sequence of values to bind, in order, to ``?`` placeholders. + Note that binding arrays by index is currently not supported; arrays must be bound by name.
column_types (Sequence[int]): An optional sequence of types (values of the `ColumnType` enumeration) which the columns of the result set will be coerced to. @@ -1049,6 +1069,11 @@ def execute( ... {'x': 2, 'y': 4}) >>> cursor.fetchall() [[1, 2], [2, 4]] + + >>> cursor.execute("select 1, 2 UNION ALL select ?, ?", + ... [2, 4]) + >>> cursor.fetchall() + [[1, 2], [2, 4]] """ self._check_closed() self._description = None @@ -1068,7 +1093,7 @@ def execute( return self def executemany( - self, sql: str, seq_of_parameters: Sequence[Mapping[str, ParameterValue]] + self, sql: str, seq_of_parameters: Sequence[Mapping[str, ParameterValue]] | Sequence[Sequence[ParameterValue]] ) -> None: """Execute the same SQL statement repeatedly with different parameters. @@ -1078,10 +1103,10 @@ def executemany( Args: sql (str): The SQL string to execute, as a Python format string of the format expected by `execute`. - seq_of_parameters (Sequence[Mapping[str, Any]]): A sequence of + seq_of_parameters (Sequence[Mapping[str, Any]] | Sequence[Sequence[Any]]): A sequence of mappings from parameter names to the values to be bound for - them. The ``sql`` statement will be run once per element in - this sequence. + them or a sequence of sequences of parameter values if binding by index. + The ``sql`` statement will be run once per element in this sequence. """ self._check_closed() for parameters in seq_of_parameters: @@ -1105,7 +1130,9 @@ def _execute(self, operation, sql, parameters=None, *, column_types=None): try: # If variable interpolation fails, then translate the exception to # an InterfaceError to signal that it's a client-side problem.
- sql = sql % {name: "@" + name for name in parameters} + # If binding by index then no need to modify sql + if not isinstance(parameters, (list, tuple)): + sql = sql % {name: "@" + name for name in parameters} except KeyError as keyerr: msg = "No value provided for parameter %s" % keyerr raise InterfaceError(msg) from keyerr diff --git a/tests/test_cdb2.py b/tests/test_cdb2.py index da8b736..e614aa8 100644 --- a/tests/test_cdb2.py +++ b/tests/test_cdb2.py @@ -62,8 +62,11 @@ def test_binding_parameters(): hndl.execute("insert into simple(key, val) values(@k, @v)", dict(k=3, v=4)) assert hndl.get_effects()[0] == 1 + hndl.execute("insert into simple(key, val) values(?, ?)", [5, 6]) + assert hndl.get_effects()[0] == 1 + rows = list(hndl.execute("select key, val from simple order by key")) - assert rows == [[1, 2], [3, 4]] + assert rows == [[1, 2], [3, 4], [5, 6]] def test_commit_failures(): @@ -304,6 +307,17 @@ def test_parameter_name_in_binding_errors_noexception(): ) +def test_binding_array_by_index(): + hndl = cdb2.Handle("mattdb", "dev") + + with pytest.raises(Exception) as exc: + hndl.execute("select * from carray(?)", [[1, 2, 3]]) + + assert exc.value.args[0] == ( + "Binding arrays by index is currently unsupported. Bind arrays by name." + ) + + def test_specifying_column_types(): # GIVEN hndl = cdb2.Handle("mattdb", "dev") diff --git a/tests/test_dbapi2.py b/tests/test_dbapi2.py index 3dbd175..361525a 100644 --- a/tests/test_dbapi2.py +++ b/tests/test_dbapi2.py @@ -28,6 +28,7 @@ DataError, Datetime, DatetimeUs, + Error, ForeignKeyConstraintError, IntegrityError, InterfaceError, @@ -263,6 +264,24 @@ def test_unescaped_percent(): with pytest.raises(InterfaceError): cursor.execute("select 1%2") + # no escaped percent with binding sequence + cursor.execute("select ? % ?", [5, 3]) # Should work + with pytest.raises(Error): + cursor.execute("select ? 
%% ?", [5, 3]) + + +def test_different_sequences(): + conn = connect("mattdb", "dev") + cursor = conn.cursor() + cursor.execute("select ?, ?", [1,2]) + assert cursor.fetchall() == [[1, 2]] + + cursor.execute("select ?, ?", (1,2)) + assert cursor.fetchall() == [[1, 2]] + + with pytest.raises(AttributeError): + cursor.execute("select ?, ?", "hi") + def test_reading_and_writing_datetimes(): conn = connect("mattdb", "dev") @@ -377,6 +396,22 @@ def test_all_datatypes_as_parameters(): assert row == list(v for k, v in params) assert cursor.fetchone() is None + cursor.execute("delete from all_datatypes") + conn.commit() + + cursor.execute( + "insert into all_datatypes(" + ", ".join(COLUMN_LIST) + ")" + " values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + list(v for k, v in params), + ) + + conn.commit() + + cursor.execute("select * from all_datatypes") + row2 = cursor.fetchone() + assert row2 == row + assert cursor.fetchone() is None + def test_naive_datetime_as_parameter(): conn = connect("mattdb", "dev") @@ -414,6 +449,22 @@ def test_naive_datetime_as_parameter(): assert row == [Datetime(2009, 2, 13, 18, 31, 30, 234000, pytz.UTC)] assert cursor.fetchone() is None + cursor.execute("delete from all_datatypes") + conn.commit() + + cursor.execute( + "insert into all_datatypes(" + ", ".join(COLUMN_LIST) + ")" + " values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + list(v for k, v in params), + ) + + conn.commit() + + cursor.execute("select datetime_col from all_datatypes") + row2 = cursor.fetchone() + assert row2 == row + assert cursor.fetchone() is None + def test_datetime_with_non_olson_tzname(): conn = connect("mattdb", "dev") @@ -431,6 +482,12 @@ def test_datetime_with_non_olson_tzname(): assert row[1].tzname() == "UTC" assert row[1] == edt_dt + row = cursor.execute("select ?, ?", [est_dt, edt_dt]).fetchone() + assert row[0].tzname() == "UTC" + assert row[0] == est_dt + assert row[1].tzname() == "UTC" + assert row[1] == edt_dt + def 
test_rounding_datetime_to_nearest_millisecond(): conn = connect("mattdb", "dev") @@ -443,10 +500,16 @@ def test_rounding_datetime_to_nearest_millisecond(): cursor.execute("select @date", {"date": curr_microsecond}) assert cursor.fetchall() == [[prev_millisecond]] + cursor.execute("select ?", [curr_microsecond]) + assert cursor.fetchall() == [[prev_millisecond]] + curr_microsecond += datetime.timedelta(microseconds=1) cursor.execute("select @date", {"date": curr_microsecond}) assert cursor.fetchall() == [[next_millisecond]] + cursor.execute("select ?", [curr_microsecond]) + assert cursor.fetchall() == [[next_millisecond]] + def test_cursor_description(): conn = connect("mattdb", "dev") @@ -501,6 +564,9 @@ def test_binding_number_that_overflows_long_long(): with pytest.raises(DataError): cursor.execute("select @i", dict(i=2**64 + 1)) + with pytest.raises(DataError): + cursor.execute("select ?", [2**64 + 1]) + def test_retrieving_null(): conn = connect("mattdb", "dev") @@ -969,6 +1035,16 @@ def test_parameter_binding_invalid_arrays(values, exc_msg): cursor.execute("select * from carray(%(values)s)", dict(values=values)) +def test_parameter_binding_arrays_by_index(): + # GIVEN + conn = connect("mattdb", "dev") + cursor = conn.cursor() + + # WHEN/THEN + with pytest.raises(ValueError) as exc: + cursor.execute("select * from carray(?)", [[1, 2, 3]]) + + def test_specifying_column_types(): # GIVEN conn = connect("mattdb", "dev")