This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 74b98ef  feat(python/adbc_driver_manager): expose StatementGetParameterSchema (#555)
74b98ef is described below

commit 74b98eff927494707881e20c23ac5913b18a8afe
Author: David Li <[email protected]>
AuthorDate: Mon Mar 27 21:14:24 2023 -0400

    feat(python/adbc_driver_manager): expose StatementGetParameterSchema (#555)
    
    Fixes #537.
---
 .../adbc_driver_manager/_lib.pyx                   | 36 ++++++++++++++++++++++
 .../adbc_driver_manager/dbapi.py                   | 26 ++++++++++++++++
 python/adbc_driver_manager/tests/test_dbapi.py     | 13 ++++++++
 python/adbc_driver_manager/tests/test_lowlevel.py  | 20 ++++++++++++
 4 files changed, 95 insertions(+)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 63c965d..2e8658f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -199,6 +199,10 @@ cdef extern from "adbc.h" nogil:
         CAdbcStatement* statement,
         CArrowArrayStream* out, int64_t* rows_affected,
         CAdbcError* error)
+    CAdbcStatusCode AdbcStatementGetParameterSchema(
+        CAdbcStatement* statement,
+        CArrowSchema* schema,
+        CAdbcError* error);
     CAdbcStatusCode AdbcStatementNew(
         CAdbcConnection* connection,
         CAdbcStatement* statement,
@@ -1046,6 +1050,38 @@ cdef class AdbcStatement(_AdbcHandle):
         check_error(status, &c_error)
         return rows_affected
 
+    def get_parameter_schema(self) -> ArrowSchemaHandle:
+        """Get the Arrow schema for bound parameters.
+
+        This retrieves an Arrow schema describing the number, names,
+        and types of the parameters in a parameterized statement.  The
+        fields of the schema should be in order of the ordinal
+        position of the parameters; named parameters should appear
+        only once.
+
+        If the parameter does not have a name, or the name cannot be
+        determined, the name of the corresponding field in the schema
+        will be an empty string.  If the type cannot be determined,
+        the type of the corresponding field will be NA (NullType).
+
+        This should be called after :meth:`prepare`.
+
+        Raises
+        ------
+        NotSupportedError
+            If the schema could not be determined.
+
+        """
+        cdef CAdbcError c_error = empty_error()
+        cdef CAdbcStatusCode status
+        cdef ArrowSchemaHandle handle = ArrowSchemaHandle()
+
+        with nogil:
+            status = AdbcStatementGetParameterSchema(
+                &self.statement, &handle.schema, &c_error)
+        check_error(status, &c_error)
+        return handle
+
     def prepare(self) -> None:
         """Turn this statement into a prepared statement."""
         cdef CAdbcError c_error = empty_error()
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index eddb32b..4e7e5ff 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -789,6 +789,32 @@ class Cursor(_Closeable):
         partitions, schema, self._rowcount = self._stmt.execute_partitions()
         return partitions, pyarrow.Schema._import_from_c(schema.address)
 
+    def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]:
+        """
+        Prepare a query without executing it.
+
+        To execute the query afterwards, call :meth:`execute` or
+        :meth:`executemany` with the same query.  This will not
+        prepare the query a second time.
+
+        Returns
+        -------
+        pyarrow.Schema or None
+            The schema of the bind parameters, or None if the schema
+            could not be determined.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._prepare_execute(operation)
+
+        try:
+            handle = self._stmt.get_parameter_schema()
+        except NotSupportedError:
+            return None
+        return pyarrow.Schema._import_from_c(handle.address)
+
     def adbc_read_partition(self, partition: bytes) -> None:
         """
         Read a partition of a distributed result set.
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index a75bfed..478067f 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -294,6 +294,19 @@ def test_executemany(sqlite):
         assert next(cur) == (5, 6)
 
 
+@pytest.mark.sqlite
+def test_prepare(sqlite):
+    with sqlite.cursor() as cur:
+        schema = cur.adbc_prepare("SELECT 1")
+        assert schema == pyarrow.schema([])
+
+        schema = cur.adbc_prepare("SELECT 1 + ?")
+        assert schema == pyarrow.schema([("0", "null")])
+
+        cur.execute("SELECT 1 + ?", (1,))
+        assert cur.fetchone() == (2,)
+
+
 @pytest.mark.sqlite
 def test_close_warning(sqlite):
     with pytest.warns(
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 08209fb..819873e 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -259,6 +259,26 @@ def test_statement_ingest(sqlite):
         assert table == pyarrow.Table.from_batches([data])
 
 
+@pytest.mark.sqlite
+def test_statement_adbc_prepare(sqlite):
+    _, conn = sqlite
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT 1")
+        stmt.prepare()
+        handle = stmt.get_parameter_schema()
+        assert _import(handle) == pyarrow.schema([])
+
+        stmt.set_sql_query("SELECT 1 + ?")
+        stmt.prepare()
+        handle = stmt.get_parameter_schema()
+        assert _import(handle) == pyarrow.schema([("0", "null")])
+
+        _bind(stmt, pyarrow.record_batch([[41]], names=["0"]))
+        handle, _ = stmt.execute_query()
+        table = _import(handle).read_all()
+        assert table == pyarrow.table([[42]], names=["1 + ?"])
+
+
 @pytest.mark.sqlite
 def test_statement_autocommit(sqlite):
     _, conn = sqlite

Reply via email to