This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 687da3753 feat(python/adbc_driver_manager): enable DB-API without
PyArrow (#2609)
687da3753 is described below
commit 687da37536b370821be6ac95bd5c5de53207b2e7
Author: David Li <[email protected]>
AuthorDate: Fri Mar 21 04:35:21 2025 -0400
feat(python/adbc_driver_manager): enable DB-API without PyArrow (#2609)
Enable limited use of the DB-API interface without PyArrow for people
using other Arrow-based libraries, like polars (which is used here to
test the new functionality).
This doesn't enable everything (e.g. get_objects), but it does enable
(parameterized) queries and ingestion, which I would guess are the most
important things.
Fixes #2413.
---
ci/conda_env_docs.txt | 1 +
ci/scripts/python_sdist_test.sh | 2 +-
ci/scripts/python_util.sh | 29 +-
ci/scripts/python_wheel_unix_test.sh | 7 +-
ci/scripts/python_wheel_windows_test.bat | 2 +-
docs/source/conf.py | 17 ++
.../adbc_driver_manager/_lib.pyi | 1 +
.../adbc_driver_manager/_lib.pyx | 14 +-
.../adbc_driver_manager/dbapi.py | 292 ++++++++++++++-------
python/adbc_driver_manager/pyproject.toml | 1 +
python/adbc_driver_manager/tests/test_dbapi.py | 42 +++
.../tests/test_dbapi_nopyarrow.py | 156 +++++++++++
12 files changed, 461 insertions(+), 103 deletions(-)
diff --git a/ci/conda_env_docs.txt b/ci/conda_env_docs.txt
index 26d20f8c6..9badc3f31 100644
--- a/ci/conda_env_docs.txt
+++ b/ci/conda_env_docs.txt
@@ -21,6 +21,7 @@ make
# Needed to install mermaid
nodejs
numpydoc
+polars
pytest
sphinx>=8.1
sphinx-autobuild
diff --git a/ci/scripts/python_sdist_test.sh b/ci/scripts/python_sdist_test.sh
index 46ce43835..38e71e692 100755
--- a/ci/scripts/python_sdist_test.sh
+++ b/ci/scripts/python_sdist_test.sh
@@ -47,7 +47,7 @@ echo "=== Installing sdists ==="
for component in ${COMPONENTS}; do
pip install --no-deps --force-reinstall
${source_dir}/python/${component}/dist/*.tar.gz
done
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
echo "=== (${PYTHON_VERSION}) Testing sdists ==="
test_packages
diff --git a/ci/scripts/python_util.sh b/ci/scripts/python_util.sh
index abeabd78b..3bf1dee2e 100644
--- a/ci/scripts/python_util.sh
+++ b/ci/scripts/python_util.sh
@@ -161,10 +161,33 @@ import $component.dbapi
# --import-mode required, else tries to import from the source dir
instead of installed package
if [[ "${component}" = "adbc_driver_manager" ]]; then
- export PYTEST_ADDOPTS="-k 'not duckdb and not sqlite'"
- elif [[ "${component}" = "adbc_driver_postgresql" ]]; then
- export PYTEST_ADDOPTS="-k 'not polars'"
+ export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} -k 'not duckdb and not
sqlite'"
fi
python -m pytest -vvx --import-mode append
${source_dir}/python/$component/tests
done
}
+
+function test_packages_pyarrowless {
+ local -r driver_path=$(python -c "import os; import adbc_driver_sqlite;
print(os.path.dirname(adbc_driver_sqlite._driver_path()))")
+ export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${driver_path}"
+ export DYLD_LIBRARY_PATH="${DYLD_LIBRARY_PATH}:${driver_path}"
+ # For macOS (because we name the file ".so" on every platform regardless
of the actual type)
+ ln -s "${driver_path}/libadbc_driver_sqlite.so"
"${driver_path}/libadbc_driver_sqlite.dylib"
+ for component in ${COMPONENTS}; do
+ echo "=== Testing $component (no PyArrow) ==="
+
+ python -c "
+import $component
+import $component.dbapi
+"
+
+ local test_files=$(find ${source_dir}/python/$component/tests -type f |
+ grep -e 'nopyarrow\.py$')
+ if [[ -z "${test_files}" ]]; then
+ continue
+ fi
+
+ # --import-mode required, else tries to import from the source dir
instead of installed package
+ python -m pytest -vvx --import-mode append "${test_files[@]}"
+ done
+}
diff --git a/ci/scripts/python_wheel_unix_test.sh
b/ci/scripts/python_wheel_unix_test.sh
index 15eea984d..3cc72fd9b 100755
--- a/ci/scripts/python_wheel_unix_test.sh
+++ b/ci/scripts/python_wheel_unix_test.sh
@@ -49,8 +49,13 @@ for component in ${COMPONENTS}; do
echo "NOTE: assuming wheels are already installed"
fi
done
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
echo "=== (${PYTHON_VERSION}) Testing wheels ==="
test_packages
+
+echo "=== (${PYTHON_VERSION}) Testing wheels (no PyArrow) ==="
+pip uninstall -y pyarrow
+export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} -k pyarrowless"
+test_packages_pyarrowless
diff --git a/ci/scripts/python_wheel_windows_test.bat
b/ci/scripts/python_wheel_windows_test.bat
index 963067b7b..852991bfc 100644
--- a/ci/scripts/python_wheel_windows_test.bat
+++ b/ci/scripts/python_wheel_windows_test.bat
@@ -27,7 +27,7 @@ FOR %%c IN (adbc_driver_bigquery adbc_driver_manager
adbc_driver_flightsql adbc_
)
)
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
echo "=== (%PYTHON_VERSION%) Testing wheels ==="
diff --git a/docs/source/conf.py b/docs/source/conf.py
index bcbefaf96..10771e168 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -54,6 +54,23 @@ extensions = [
]
templates_path = ["_templates"]
+
+def on_missing_reference(app, env, node, contnode):
+ if str(contnode) == "polars.DataFrame":
+ # Polars does something odd with Sphinx such that polars.DataFrame
+ # isn't xrefable; suppress the warning.
+ return contnode
+ elif str(contnode) == "CapsuleType":
+ # CapsuleType is only in 3.13+
+ return contnode
+ else:
+ return None
+
+
+def setup(app):
+ app.connect("missing-reference", on_missing_reference)
+
+
# -- Options for autodoc ----------------------------------------------------
try:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 0a19f92ed..6c7d6c4ab 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -214,3 +214,4 @@ def _blocking_call(
kwargs: dict,
cancel: typing.Callable[[], None],
) -> _T: ...
+def is_pycapsule(obj: Any, name: bytes) -> bool: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 21afe9d3c..d2ac2401b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -31,7 +31,7 @@ from typing import List, Optional, Tuple
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.pycapsule cimport (
- PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+ PyCapsule_GetPointer, PyCapsule_IsValid, PyCapsule_New
)
from libc.stdint cimport int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
@@ -337,6 +337,12 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")
+def is_pycapsule(obj, bytes name) -> bool:
+    """Check if an object is a PyCapsule of a specific type.
+
+    ``name`` is the capsule name to match, e.g. ``b"arrow_array_stream"``.
+    Adapted from nanoarrow.
+    """
+    # PyCapsule_IsValid never fails: it returns a *nonzero* value (the C API
+    # does not promise exactly 1) for a capsule with a non-NULL pointer and a
+    # matching name, and 0 otherwise, including for non-capsule objects.
+    return PyCapsule_IsValid(obj, name) != 0
+
+
cdef void pycapsule_schema_deleter(object capsule) noexcept:
cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
capsule, "arrow_schema"
@@ -1125,7 +1131,7 @@ cdef class AdbcStatement(_AdbcHandle):
)
schema, data = data.__arrow_c_array__()
- if PyCapsule_CheckExact(data):
+ if is_pycapsule(data, b"arrow_array"):
c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
@@ -1137,7 +1143,7 @@ cdef class AdbcStatement(_AdbcHandle):
f"Protocol), a PyCapsule, int or ArrowArrayHandle, not
{type(data)}"
)
- if PyCapsule_CheckExact(schema):
+ if is_pycapsule(schema, b"arrow_schema"):
c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema,
"arrow_schema")
elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
@@ -1172,7 +1178,7 @@ cdef class AdbcStatement(_AdbcHandle):
):
stream = stream.__arrow_c_stream__()
- if PyCapsule_CheckExact(stream):
+ if is_pycapsule(stream, b"arrow_array_stream"):
c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
stream, "arrow_array_stream"
)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 9dfc4e551..679cc1bff 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -17,6 +17,15 @@
"""PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver Manager.
+PyArrow Requirement
+===================
+
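+This module requires PyArrow for full functionality. Without PyArrow only a
+subset works: you can execute queries (parameters must then be objects
+implementing the Arrow PyCapsule interface rather than Python values), fetch
+the result with ``Cursor.fetch_arrow`` (a PyCapsule-interface stream) or
+``Cursor.fetch_polars`` (requires Polars), and ingest such Arrow data with
+``Cursor.adbc_ingest``. Methods that need PyArrow (e.g. ``fetchone``,
+``fetchall``, ``fetch_arrow_table``, ``fetch_df``) raise ``ProgrammingError``.
+Also, the DB-API type definitions (``BINARY``, ``DATETIME``, etc.) will be
+present, but empty.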
+
Resource Management
===================
@@ -40,26 +49,23 @@ from typing import Any, Dict, List, Literal, Optional,
Tuple, Union
try:
import pyarrow
-except ImportError as e:
- raise ImportError("PyArrow is required for the DBAPI-compatible
interface") from e
-
-try:
import pyarrow.dataset
except ImportError:
- _pya_dataset = ()
- _pya_scanner = ()
+ _has_pyarrow = False
else:
- _pya_dataset = (pyarrow.dataset.Dataset,)
- _pya_scanner = (pyarrow.dataset.Scanner,)
+ _has_pyarrow = True
+ from . import _reader
import adbc_driver_manager
-from . import _lib, _reader
+from . import _lib
from ._lib import _blocking_call
if typing.TYPE_CHECKING:
import pandas
- from typing_extensions import Self
+ import polars
+ import pyarrow
+ from typing_extensions import CapsuleType, Self
# ----------------------------------------------------------
# Globals
@@ -131,37 +137,44 @@ class _TypeSet(frozenset):
return False
-#: The type of binary columns.
-BINARY = _TypeSet({pyarrow.binary().id, pyarrow.large_binary().id})
-#: The type of datetime columns.
-DATETIME = _TypeSet(
- [
- pyarrow.date32().id,
- pyarrow.date64().id,
- pyarrow.time32("s").id,
- pyarrow.time64("ns").id,
- pyarrow.timestamp("s").id,
- ]
-)
-#: The type of numeric columns.
-NUMBER = _TypeSet(
- [
- pyarrow.int8().id,
- pyarrow.int16().id,
- pyarrow.int32().id,
- pyarrow.int64().id,
- pyarrow.uint8().id,
- pyarrow.uint16().id,
- pyarrow.uint32().id,
- pyarrow.uint64().id,
- pyarrow.float32().id,
- pyarrow.float64().id,
- ]
-)
-#: The type of "row ID" columns.
-ROWID = _TypeSet([pyarrow.int64().id])
-#: The type of string columns.
-STRING = _TypeSet([pyarrow.string().id, pyarrow.large_string().id])
+if _has_pyarrow:
+ #: The type of binary columns.
+ BINARY = _TypeSet({pyarrow.binary().id, pyarrow.large_binary().id})
+ #: The type of datetime columns.
+ DATETIME = _TypeSet(
+ [
+ pyarrow.date32().id,
+ pyarrow.date64().id,
+ pyarrow.time32("s").id,
+ pyarrow.time64("ns").id,
+ pyarrow.timestamp("s").id,
+ ]
+ )
+ #: The type of numeric columns.
+ NUMBER = _TypeSet(
+ [
+ pyarrow.int8().id,
+ pyarrow.int16().id,
+ pyarrow.int32().id,
+ pyarrow.int64().id,
+ pyarrow.uint8().id,
+ pyarrow.uint16().id,
+ pyarrow.uint32().id,
+ pyarrow.uint64().id,
+ pyarrow.float32().id,
+ pyarrow.float64().id,
+ ]
+ )
+ #: The type of "row ID" columns.
+ ROWID = _TypeSet([pyarrow.int64().id])
+ #: The type of string columns.
+ STRING = _TypeSet([pyarrow.string().id, pyarrow.large_string().id])
+else:
+ BINARY = _TypeSet()
+ DATETIME = _TypeSet()
+ NUMBER = _TypeSet()
+ ROWID = _TypeSet()
+ STRING = _TypeSet()
# ----------------------------------------------------------
# Functions
@@ -396,6 +409,8 @@ class Connection(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
+
handle = _blocking_call(self._conn.get_info, (), {}, self._conn.cancel)
reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
@@ -418,7 +433,7 @@ class Connection(_Closeable):
table_name_filter: Optional[str] = None,
table_types_filter: Optional[List[str]] = None,
column_name_filter: Optional[str] = None,
- ) -> pyarrow.RecordBatchReader:
+ ) -> "pyarrow.RecordBatchReader":
"""
List catalogs, schemas, tables, etc. in the database.
@@ -441,6 +456,8 @@ class Connection(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
+
if depth in ("all", "columns"):
c_depth = _lib.GetObjectsDepth.ALL
elif depth == "catalogs":
@@ -471,7 +488,7 @@ class Connection(_Closeable):
*,
catalog_filter: Optional[str] = None,
db_schema_filter: Optional[str] = None,
- ) -> pyarrow.Schema:
+ ) -> "pyarrow.Schema":
"""
Get the Arrow schema of a table by name.
@@ -488,6 +505,8 @@ class Connection(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
+
handle = _blocking_call(
self._conn.get_table_schema,
(
@@ -508,6 +527,8 @@ class Connection(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
+
handle = _blocking_call(
self._conn.get_table_types,
(),
@@ -660,17 +681,10 @@ class Cursor(_Closeable):
self._stmt.bind(parameters)
elif hasattr(parameters, "__arrow_c_stream__"):
self._stmt.bind_stream(parameters)
- elif isinstance(parameters, pyarrow.RecordBatch):
- arr_handle = _lib.ArrowArrayHandle()
- sch_handle = _lib.ArrowSchemaHandle()
- parameters._export_to_c(arr_handle.address, sch_handle.address)
- self._stmt.bind(arr_handle, sch_handle)
+ elif _lib.is_pycapsule(parameters, b"arrow_array_stream"):
+ self._stmt.bind_stream(parameters)
else:
- if isinstance(parameters, pyarrow.Table):
- parameters = parameters.to_reader()
- stream_handle = _lib.ArrowArrayStreamHandle()
- parameters._export_to_c(stream_handle.address)
- self._stmt.bind_stream(stream_handle)
+ raise TypeError(f"Cannot bind {type(parameters)}")
def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -690,6 +704,7 @@ class Cursor(_Closeable):
if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
+ _requires_pyarrow()
rb = pyarrow.record_batch(
[[param_value] for param_value in parameters],
names=[str(i) for i in range(len(parameters))],
@@ -716,9 +731,7 @@ class Cursor(_Closeable):
handle, self._rowcount = _blocking_call(
self._stmt.execute_query, (), {}, self._stmt.cancel
)
- self._results = _RowIterator(
- self._stmt,
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
- )
+ self._results = _RowIterator(self._stmt, handle)
def executemany(self, operation: Union[bytes, str], seq_of_parameters) ->
None:
"""
@@ -746,6 +759,7 @@ class Cursor(_Closeable):
if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
+ _requires_pyarrow()
arrow_parameters = pyarrow.RecordBatch.from_pydict(
{
str(col_idx): pyarrow.array(x)
@@ -753,6 +767,7 @@ class Cursor(_Closeable):
},
)
else:
+ _requires_pyarrow()
arrow_parameters = pyarrow.record_batch([])
self._bind(arrow_parameters)
@@ -836,7 +851,12 @@ class Cursor(_Closeable):
def adbc_ingest(
self,
table_name: str,
- data: Union[pyarrow.RecordBatch, pyarrow.Table,
pyarrow.RecordBatchReader],
+ data: Union[
+ "pyarrow.RecordBatch",
+ "pyarrow.Table",
+ "pyarrow.RecordBatchReader",
+ "CapsuleType",
+ ],
mode: Literal["append", "create", "replace", "create_append"] =
"create",
*,
catalog_name: Optional[str] = None,
@@ -932,24 +952,24 @@ class Cursor(_Closeable):
self._stmt.bind(data)
elif hasattr(data, "__arrow_c_stream__"):
self._stmt.bind_stream(data)
- elif isinstance(data, pyarrow.RecordBatch):
- array = _lib.ArrowArrayHandle()
- schema = _lib.ArrowSchemaHandle()
- data._export_to_c(array.address, schema.address)
- self._stmt.bind(array, schema)
+ elif _lib.is_pycapsule(data, b"arrow_array_stream"):
+ self._stmt.bind_stream(data)
else:
- if isinstance(data, pyarrow.Table):
- data = data.to_reader()
- elif isinstance(data, pyarrow.dataset.Dataset):
- data = data.scanner().to_reader()
+ _requires_pyarrow()
+ if isinstance(data, pyarrow.dataset.Dataset):
+ data = typing.cast(pyarrow.dataset.Dataset,
data).scanner().to_reader()
elif isinstance(data, pyarrow.dataset.Scanner):
- data = data.to_reader()
+ data = typing.cast(pyarrow.dataset.Scanner, data).to_reader()
elif not hasattr(data, "_export_to_c"):
- data = pyarrow.Table.from_batches(data)
- data = data.to_reader()
- handle = _lib.ArrowArrayStreamHandle()
- data._export_to_c(handle.address)
- self._stmt.bind_stream(handle)
+ data = pyarrow.Table.from_batches(data).to_reader()
+ if hasattr(data, "_export_to_c"):
+ handle = _lib.ArrowArrayStreamHandle()
+ # pyright doesn't seem to handle flow-sensitive typing here
+ data._export_to_c(handle.address) # type: ignore
+ self._stmt.bind_stream(handle)
+ else:
+ # Should be impossible from above but let's be explicit
+ raise TypeError(f"Cannot bind {type(data)}")
self._last_query = None
return _blocking_call(self._stmt.execute_update, (), {},
self._stmt.cancel)
@@ -958,7 +978,7 @@ class Cursor(_Closeable):
self,
operation,
parameters=None,
- ) -> Tuple[List[bytes], pyarrow.Schema]:
+ ) -> Tuple[List[bytes], "pyarrow.Schema"]:
"""
Execute a query and get the partitions of a distributed result set.
@@ -975,6 +995,7 @@ class Cursor(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
self._prepare_execute(operation, parameters)
partitions, schema_handle, self._rowcount = _blocking_call(
self._stmt.execute_partitions, (), {}, self._stmt.cancel
@@ -985,7 +1006,7 @@ class Cursor(_Closeable):
schema = None
return partitions, schema
- def adbc_execute_schema(self, operation, parameters=None) ->
pyarrow.Schema:
+ def adbc_execute_schema(self, operation, parameters=None) ->
"pyarrow.Schema":
"""
Get the schema of the result set of a query without executing it.
@@ -998,11 +1019,12 @@ class Cursor(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
self._prepare_execute(operation, parameters)
schema = _blocking_call(self._stmt.execute_schema, (), {},
self._stmt.cancel)
return pyarrow.Schema._import_from_c(schema.address)
- def adbc_prepare(self, operation: Union[bytes, str]) ->
Optional[pyarrow.Schema]:
+ def adbc_prepare(self, operation: Union[bytes, str]) ->
Optional["pyarrow.Schema"]:
"""
Prepare a query without executing it.
@@ -1020,6 +1042,7 @@ class Cursor(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
self._prepare_execute(operation)
try:
@@ -1038,14 +1061,13 @@ class Cursor(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
self._results = None
handle = _blocking_call(
self._conn._conn.read_partition, (partition,), {},
self._stmt.cancel
)
self._rowcount = -1
- self._results = _RowIterator(
- self._stmt,
pyarrow.RecordBatchReader._import_from_c(handle.address)
- )
+ self._results = _RowIterator(self._stmt, handle)
@property
def adbc_statement(self) -> _lib.AdbcStatement:
@@ -1076,7 +1098,7 @@ class Cursor(_Closeable):
self._stmt.set_sql_query(operation)
_blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
- def fetchallarrow(self) -> pyarrow.Table:
+ def fetchallarrow(self) -> "pyarrow.Table":
"""
Fetch all rows of the result as a PyArrow Table.
@@ -1088,7 +1110,7 @@ class Cursor(_Closeable):
"""
return self.fetch_arrow_table()
- def fetch_arrow_table(self) -> pyarrow.Table:
+ def fetch_arrow_table(self) -> "pyarrow.Table":
"""
Fetch all rows of the result as a PyArrow Table.
@@ -1122,7 +1144,22 @@ class Cursor(_Closeable):
)
return self._results.fetch_df()
- def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
+ def fetch_polars(self) -> "polars.DataFrame":
+ """
+ Fetch all rows of the result as a Polars DataFrame.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if self._results is None:
+ raise ProgrammingError(
+ "Cannot fetch_polars() before execute()",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
+ return self._results.fetch_polars()
+
+ def fetch_record_batch(self) -> "pyarrow.RecordBatchReader":
"""
Fetch the result as a PyArrow RecordBatchReader.
@@ -1133,6 +1170,7 @@ class Cursor(_Closeable):
-----
This is an extension and not part of the DBAPI standard.
"""
+ _requires_pyarrow()
if self._results is None:
raise ProgrammingError(
"Cannot fetch_record_batch() before execute()",
@@ -1141,7 +1179,27 @@ class Cursor(_Closeable):
# XXX(https://github.com/apache/arrow-adbc/issues/1523): return the
# "real" PyArrow reader since PyArrow may try to poke the internal C++
# reader pointer
- return self._results._reader._reader
+ return self._results.reader._reader
+
+ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
+ """
+ Fetch the result as an object implementing the Arrow PyCapsule
interface.
+
+ This can only be called once. It must be called before any other
+ method that inspect the data (e.g. description, fetchone,
+ fetch_arrow_table, etc.). Once this is called, other methods that
+ inspect the data may not be called.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if self._results is None:
+ raise ProgrammingError(
+ "Cannot fetch_arrow() before execute()",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
+ return self._results.fetch_arrow()
# ----------------------------------------------------------
@@ -1151,24 +1209,41 @@ class Cursor(_Closeable):
class _RowIterator(_Closeable):
"""Track state needed to iterate over the result set."""
- def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
+ def __init__(self, stmt, handle: _lib.ArrowArrayStreamHandle) -> None:
self._stmt = stmt
- self._reader = reader
+ self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
+ self._reader: Optional["_reader.AdbcRecordBatchReader"] = None
self._current_batch = None
self._next_row = 0
self._finished = False
self.rownumber = 0
def close(self) -> None:
- if hasattr(self._reader, "close"):
+ if self._reader is not None and hasattr(self._reader, "close"):
# Only in recent PyArrow
self._reader.close()
+ self._reader = None
+
+ @property
+ def reader(self) -> "_reader.AdbcRecordBatchReader":
+ if self._reader is None:
+ _requires_pyarrow()
+ if self._handle is None:
+ raise ProgrammingError(
+ "Result set has been closed or consumed",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
+ else:
+ handle, self._handle = self._handle, None
+ klass = _reader.AdbcRecordBatchReader # type: ignore
+ self._reader = klass._import_from_c(handle.address)
+ return self._reader
@property
def description(self) -> List[tuple]:
return [
(field.name, field.type, None, None, None, None, None)
- for field in self._reader.schema
+ for field in self.reader.schema
]
def fetchone(self) -> Optional[tuple]:
@@ -1176,7 +1251,7 @@ class _RowIterator(_Closeable):
try:
while True:
self._current_batch = _blocking_call(
- self._reader.read_next_batch, (), {}, self._stmt.cancel
+ self.reader.read_next_batch, (), {}, self._stmt.cancel
)
if self._current_batch.num_rows > 0:
break
@@ -1211,11 +1286,33 @@ class _RowIterator(_Closeable):
rows.append(row)
return rows
- def fetch_arrow_table(self) -> pyarrow.Table:
- return _blocking_call(self._reader.read_all, (), {}, self._stmt.cancel)
+ def fetch_arrow_table(self) -> "pyarrow.Table":
+ return _blocking_call(self.reader.read_all, (), {}, self._stmt.cancel)
def fetch_df(self) -> "pandas.DataFrame":
- return _blocking_call(self._reader.read_pandas, (), {},
self._stmt.cancel)
+ return _blocking_call(self.reader.read_pandas, (), {},
self._stmt.cancel)
+
+ def fetch_polars(self) -> "polars.DataFrame":
+ import polars
+
+ return _blocking_call(
+ lambda: typing.cast(
+ polars.DataFrame,
+ polars.from_arrow(self.fetch_arrow()),
+ ),
+ (),
+ {},
+ self._stmt.cancel,
+ )
+
+ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
+ if self._handle is None:
+ raise ProgrammingError(
+ "Result set has been closed or consumed",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
+ handle, self._handle = self._handle, None
+ return handle
_PYTEST_ENV_VAR = "PYTEST_CURRENT_TEST"
@@ -1234,10 +1331,19 @@ def _warn_unclosed(name):
def _is_arrow_data(data):
+ # No need to check for PyArrow types explicitly since they support the
+ # dunder methods
return (
hasattr(data, "__arrow_c_array__")
or hasattr(data, "__arrow_c_stream__")
- or isinstance(
- data, (pyarrow.RecordBatch, pyarrow.Table,
pyarrow.RecordBatchReader)
- )
+ or _lib.is_pycapsule(data, b"arrow_array")
+ or _lib.is_pycapsule(data, b"arrow_array_stream")
)
+
+
+def _requires_pyarrow():
+ if not _has_pyarrow:
+ raise ProgrammingError(
+ "This API requires PyArrow to be installed",
+ status_code=_lib.AdbcStatusCode.INVALID_STATE,
+ )
diff --git a/python/adbc_driver_manager/pyproject.toml
b/python/adbc_driver_manager/pyproject.toml
index 744024c97..a99be7c9b 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -41,6 +41,7 @@ build-backend = "setuptools.build_meta"
markers = [
"duckdb: tests that require DuckDB",
"panicdummy: tests that require the testing-only panicdummy driver",
+ "pyarrowless: tests of functionality when PyArrow is NOT installed",
"sqlite: tests that require the SQLite driver",
]
xfail_strict = true
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py
b/python/adbc_driver_manager/tests/test_dbapi.py
index 2db92388f..8325a2d19 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -16,6 +16,8 @@
# under the License.
import pandas
+import polars
+import polars.testing
import pyarrow
import pyarrow.dataset
import pytest
@@ -165,6 +167,9 @@ class StreamWrapper:
lambda: StreamWrapper(
pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
+ lambda: pyarrow.table(
+ [[1, 2], ["foo", ""]], names=["ints", "strs"]
+ ).__arrow_c_stream__(),
],
)
@pytest.mark.sqlite
@@ -226,6 +231,27 @@ def test_query_fetch_py(sqlite):
@pytest.mark.sqlite
def test_query_fetch_arrow(sqlite):
+ with sqlite.cursor() as cur:
+ with pytest.raises(sqlite.ProgrammingError):
+ cur.fetch_arrow()
+
+ cur.execute("SELECT 1, 'foo' AS foo, 2.0")
+ capsule = cur.fetch_arrow().__arrow_c_stream__()
+ reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
+ assert reader.read_all() == pyarrow.table(
+ {
+ "1": [1],
+ "foo": ["foo"],
+ "2.0": [2.0],
+ }
+ )
+
+ with pytest.raises(sqlite.ProgrammingError):
+ cur.fetch_arrow()
+
+
+@pytest.mark.sqlite
+def test_query_fetch_arrow_table(sqlite):
with sqlite.cursor() as cur:
cur.execute("SELECT 1, 'foo' AS foo, 2.0")
assert cur.fetch_arrow_table() == pyarrow.table(
@@ -253,6 +279,22 @@ def test_query_fetch_df(sqlite):
)
[email protected]
+def test_query_fetch_polars(sqlite):
+ with sqlite.cursor() as cur:
+ cur.execute("SELECT 1, 'foo' AS foo, 2.0")
+ polars.testing.assert_frame_equal(
+ cur.fetch_polars(),
+ polars.DataFrame(
+ {
+ "1": [1],
+ "foo": ["foo"],
+ "2.0": [2.0],
+ }
+ ),
+ )
+
+
@pytest.mark.sqlite
@pytest.mark.parametrize(
"parameters",
diff --git a/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py
b/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py
new file mode 100644
index 000000000..f65763fd4
--- /dev/null
+++ b/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import typing
+
+import polars
+import polars.testing
+import pytest
+
+from adbc_driver_manager import dbapi
+
+pytestmark = pytest.mark.pyarrowless
+
+
[email protected](scope="module", autouse=True)
+def no_pyarrow() -> None:
+ try:
+ import pyarrow # noqa:F401
+ except ImportError:
+ return
+ else:
+ pytest.skip("Skipping because pyarrow is installed")
+
+
[email protected]
+def sqlite() -> typing.Generator[dbapi.Connection, None, None]:
+ with dbapi.connect(driver="adbc_driver_sqlite") as conn:
+ yield conn
+
+
[email protected](
+ "data",
+ [
+ pytest.param(polars.DataFrame({"theresult": [1]}),
id="polars.DataFrame"),
+ pytest.param(polars.Series([{"theresult": 1}]), id="polars.Series"),
+ pytest.param(
+ polars.DataFrame({"theresult": [1]}).__arrow_c_stream__(),
+ id="PyCapsule_Stream",
+ ),
+ ],
+)
+def test_ingest(sqlite: dbapi.Connection, data: typing.Any) -> None:
+ with sqlite.cursor() as cursor:
+ cursor.adbc_ingest("mytable", data)
+ cursor.execute("SELECT * FROM mytable")
+ df = cursor.fetch_polars()
+ polars.testing.assert_frame_equal(
+ df,
+ polars.DataFrame(
+ {
+ "theresult": [1],
+ }
+ ),
+ )
+
+
+def test_query(sqlite: dbapi.Connection) -> None:
+ with sqlite.cursor() as cursor:
+ cursor.execute("SELECT 1 AS theresult")
+ capsule = cursor.fetch_arrow()
+ df = typing.cast(polars.DataFrame, polars.from_arrow(capsule))
+ polars.testing.assert_frame_equal(
+ df,
+ polars.DataFrame(
+ {
+ "theresult": [1],
+ }
+ ),
+ )
+
+ cursor.execute("SELECT 1 AS theresult")
+ df = cursor.fetch_polars()
+ polars.testing.assert_frame_equal(
+ df,
+ polars.DataFrame(
+ {
+ "theresult": [1],
+ }
+ ),
+ )
+
+
[email protected](
+ "parameters",
+ [
+ pytest.param(polars.DataFrame({"$0": [1]}), id="polars.DataFrame"),
+ pytest.param(polars.Series([{"$0": 1}]), id="polars.Series"),
+ pytest.param(
+ polars.DataFrame({"$0": [1]}).__arrow_c_stream__(),
id="PyCapsule_Stream"
+ ),
+ ],
+)
+def test_query_bind(sqlite: dbapi.Connection, parameters: typing.Any) -> None:
+ with sqlite.cursor() as cursor:
+ cursor.execute("SELECT 1 + ? AS theresult", parameters=parameters)
+
+ df = cursor.fetch_polars()
+ polars.testing.assert_frame_equal(
+ df,
+ polars.DataFrame(
+ {
+ "theresult": [2],
+ }
+ ),
+ )
+
+
+def test_query_not_permitted(sqlite: dbapi.Connection) -> None:
+ with sqlite.cursor() as cursor:
+ cursor.execute("SELECT 1 AS theresult")
+
+ with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+ cursor.fetchone()
+
+ with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+ cursor.fetchall()
+
+ with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+ cursor.fetchallarrow()
+
+ with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+ cursor.fetch_arrow_table()
+
+ with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+ cursor.fetch_df()
+
+ capsule = cursor.fetch_arrow()
+ # Import the result to free memory
+ polars.from_arrow(capsule)
+
+
+def test_query_double_capsule(sqlite: dbapi.Connection) -> None:
+ with sqlite.cursor() as cursor:
+ cursor.execute("SELECT 1 AS theresult")
+
+ capsule = cursor.fetch_arrow()
+
+ with pytest.raises(dbapi.ProgrammingError, match="has been closed"):
+ cursor.fetch_arrow()
+
+ # Import the result to free memory
+ polars.from_arrow(capsule)