This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 9b6acec0 Support types other than String and Int for partition columns (#1154)
9b6acec0 is described below
commit 9b6acec075f49d551a2b90608b0c7114de84d718
Author: Michele Gregori <[email protected]>
AuthorDate: Thu Jun 19 19:58:22 2025 +0200
Support types other than String and Int for partition columns (#1154)
* impl
* fix test
* format rust
* support for old logic
* also on io
* fix formatting
---------
Co-authored-by: michele gregori <[email protected]>
---
python/datafusion/context.py | 66 +++++++++++++++++++++++++++-----
python/datafusion/io.py | 8 ++--
python/tests/test_sql.py | 26 +++++++------
src/context.rs | 89 ++++++++++++++++++++++++++++----------------
4 files changed, 132 insertions(+), 57 deletions(-)
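A minimal usage sketch of the change (table name and path are illustrative, not from this commit): partition column types are now given as pyarrow DataTypes, while the old "string"/"int" literals keep working under a DeprecationWarning.

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.register_parquet(
        "events",        # hypothetical table name
        "/data/events",  # hypothetical partitioned directory
        table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
    )

Any string literal other than "string" or "int" now raises a ValueError.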
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 4ed465c9..5b99b0d2 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -19,8 +19,11 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any, Protocol
+import pyarrow as pa
+
try:
from warnings import deprecated # Python 3.13+
except ImportError:
@@ -42,7 +45,6 @@ if TYPE_CHECKING:
import pandas as pd
import polars as pl
- import pyarrow as pa
from datafusion.plan import ExecutionPlan, LogicalPlan
@@ -539,7 +541,7 @@ class SessionContext:
self,
name: str,
path: str | pathlib.Path,
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".parquet",
schema: pa.Schema | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -560,6 +562,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
file_sort_order_raw = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
@@ -778,7 +781,7 @@ class SessionContext:
self,
name: str,
path: str | pathlib.Path,
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -806,6 +809,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_parquet(
name,
str(path),
@@ -869,7 +873,7 @@ class SessionContext:
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> None:
"""Register a JSON file as a table.
@@ -890,6 +894,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_json(
name,
str(path),
@@ -906,7 +911,7 @@ class SessionContext:
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_extension: str = ".avro",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
"""Register an Avro file as a table.
@@ -922,6 +927,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
self.ctx.register_avro(
name, str(path), schema, file_extension, table_partition_cols
)
@@ -981,7 +987,7 @@ class SessionContext:
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -1001,6 +1007,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
return DataFrame(
self.ctx.read_json(
str(path),
@@ -1020,7 +1027,7 @@ class SessionContext:
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -1045,6 +1052,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
path = [str(p) for p in path] if isinstance(path, list) else str(path)
@@ -1064,7 +1072,7 @@ class SessionContext:
def read_parquet(
self,
path: str | pathlib.Path,
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -1093,6 +1101,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
+ table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
file_sort_order = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
@@ -1114,7 +1123,7 @@ class SessionContext:
self,
path: str | pathlib.Path,
schema: pa.Schema | None = None,
- file_partition_cols: list[tuple[str, str]] | None = None,
+ file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1130,6 +1139,7 @@ class SessionContext:
"""
if file_partition_cols is None:
file_partition_cols = []
+ file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
return DataFrame(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)
@@ -1146,3 +1156,41 @@ class SessionContext:
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
"""Execute the ``plan`` and return the results."""
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+ @staticmethod
+ def _convert_table_partition_cols(
+ table_partition_cols: list[tuple[str, str | pa.DataType]],
+ ) -> list[tuple[str, pa.DataType]]:
+ warn = False
+ converted_table_partition_cols = []
+
+ for col, data_type in table_partition_cols:
+ if isinstance(data_type, str):
+ warn = True
+ if data_type == "string":
+ converted_data_type = pa.string()
+ elif data_type == "int":
+ converted_data_type = pa.int32()
+ else:
+ message = (
+ f"Unsupported literal data type '{data_type}' for
partition "
+ "column. Supported types are 'string' and 'int'"
+ )
+ raise ValueError(message)
+ else:
+ converted_data_type = data_type
+
+ converted_table_partition_cols.append((col, converted_data_type))
+
+ if warn:
+ message = (
+ "using literals for table_partition_cols data types is
deprecated,"
+ "use pyarrow types instead"
+ )
+ warnings.warn(
+ message,
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return converted_table_partition_cols
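For illustration, the observable behavior of the new helper (a private static method; the calls below are a sketch based on the code above, not part of the commit):

    # Legacy literals are converted, with a DeprecationWarning:
    SessionContext._convert_table_partition_cols([("grp", "string"), ("n", "int")])
    # -> [("grp", pa.string()), ("n", pa.int32())]

    # pyarrow types pass through unchanged, with no warning:
    SessionContext._convert_table_partition_cols([("date", pa.date64())])
    # -> [("date", pa.date64())]

    # Any other literal fails fast:
    SessionContext._convert_table_partition_cols([("grp", "float")])
    # -> ValueError: Unsupported literal data type 'float' for partition column. ...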
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index ef5ebf96..551e20a6 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
def read_parquet(
path: str | pathlib.Path,
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
parquet_pruning: bool = True,
file_extension: str = ".parquet",
skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
schema: pa.Schema | None = None,
schema_infer_max_records: int = 1000,
file_extension: str = ".json",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
delimiter: str = ",",
schema_infer_max_records: int = 1000,
file_extension: str = ".csv",
- table_partition_cols: list[tuple[str, str]] | None = None,
+ table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -171,7 +171,7 @@ def read_csv(
def read_avro(
path: str | pathlib.Path,
schema: pa.Schema | None = None,
- file_partition_cols: list[tuple[str, str]] | None = None,
+ file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".avro",
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading Avro data source.
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index b6348e3a..41cee4ef 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path):
assert result.to_pydict() == {"cnt": [100]}
[email protected]("path_to_str", [True, False])
-def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
[email protected](
+ ("path_to_str", "legacy_data_type"), [(True, False), (False, False),
(False, True)]
+)
+def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
(dir_root / "grp=a").mkdir(exist_ok=False)
@@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
dir_root = str(dir_root) if path_to_str else dir_root
+ partition_data_type = "string" if legacy_data_type else pa.string()
+
ctx.register_parquet(
"datapp",
dir_root,
- table_partition_cols=[("grp", "string")],
+ table_partition_cols=[("grp", partition_data_type)],
parquet_pruning=True,
file_extension=".parquet",
)
@@ -488,9 +492,9 @@ def test_register_listing_table(
):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
- (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
- (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
- (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
+ (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True)
+ (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True)
+ (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True)
table = pa.Table.from_arrays(
[
@@ -501,13 +505,13 @@ def test_register_listing_table(
names=["int", "str", "float"],
)
pa.parquet.write_table(
- table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
+ table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet"
)
pa.parquet.write_table(
- table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
+ table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet"
)
pa.parquet.write_table(
- table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
+ table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet"
)
dir_root = f"file://{dir_root}/" if path_to_str else dir_root
@@ -515,7 +519,7 @@ def test_register_listing_table(
ctx.register_listing_table(
"my_table",
dir_root,
- table_partition_cols=[("grp", "string"), ("date_id", "int")],
+ table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
file_extension=".parquet",
schema=table.schema if pass_schema else None,
file_sort_order=file_sort_order,
@@ -531,7 +535,7 @@ def test_register_listing_table(
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
result = ctx.sql(
- "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005
GROUP BY grp" # noqa: E501
+ "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05'
GROUP BY grp" # noqa: E501
).collect()
result = pa.Table.from_batches(result)
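The commit does not exercise the deprecation path directly; a hypothetical pytest sketch of such a check (test name assumed, not in this diff) could look like:

    import pytest
    import pyarrow as pa
    from datafusion import SessionContext

    def test_partition_col_literals_deprecated():
        # Legacy literals still convert, but now emit a warning.
        with pytest.warns(DeprecationWarning):
            cols = SessionContext._convert_table_partition_cols([("grp", "string")])
        assert cols == [("grp", pa.string())]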
diff --git a/src/context.rs b/src/context.rs
index 55c92a8f..6ce1f12b 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -380,7 +380,7 @@ impl PySessionContext {
&mut self,
name: &str,
path: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_extension: &str,
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
@@ -388,7 +388,12 @@ impl PySessionContext {
) -> PyDataFusionResult<()> {
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
.with_file_extension(file_extension)
- .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .with_table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ )
.with_file_sort_order(
file_sort_order
.unwrap_or_default()
@@ -656,7 +661,7 @@ impl PySessionContext {
&mut self,
name: &str,
path: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
@@ -665,7 +670,12 @@ impl PySessionContext {
py: Python,
) -> PyDataFusionResult<()> {
let mut options = ParquetReadOptions::default()
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ )
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
@@ -745,7 +755,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
file_extension: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<()> {
@@ -755,7 +765,12 @@ impl PySessionContext {
let mut options = NdJsonReadOptions::default()
.file_compression_type(parse_file_compression_type(file_compression_type)?)
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ .table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ );
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);
@@ -778,15 +793,19 @@ impl PySessionContext {
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
- let mut options = AvroReadOptions::default()
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ let mut options = AvroReadOptions::default().table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ );
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);
@@ -887,7 +906,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
file_extension: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
@@ -895,7 +914,12 @@ impl PySessionContext {
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
let mut options = NdJsonReadOptions::default()
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ )
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
@@ -928,7 +952,7 @@ impl PySessionContext {
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_compression_type: Option<String>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
@@ -944,7 +968,12 @@ impl PySessionContext {
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension)
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ )
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema = schema.as_ref().map(|x| &x.0);
@@ -974,7 +1003,7 @@ impl PySessionContext {
pub fn read_parquet(
&self,
path: &str,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
@@ -983,7 +1012,12 @@ impl PySessionContext {
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ )
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
@@ -1005,12 +1039,16 @@ impl PySessionContext {
&self,
path: &str,
schema: Option<PyArrowType<Schema>>,
- table_partition_cols: Vec<(String, String)>,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
file_extension: &str,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
- let mut options = AvroReadOptions::default()
- .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ let mut options = AvroReadOptions::default().table_partition_cols(
+ table_partition_cols
+ .into_iter()
+ .map(|(name, ty)| (name, ty.0))
+ .collect::<Vec<(String, DataType)>>(),
+ );
options.file_extension = file_extension;
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
@@ -1109,21 +1147,6 @@ impl PySessionContext {
}
}
-pub fn convert_table_partition_cols(
- table_partition_cols: Vec<(String, String)>,
-) -> PyDataFusionResult<Vec<(String, DataType)>> {
- table_partition_cols
- .into_iter()
- .map(|(name, ty)| match ty.as_str() {
- "string" => Ok((name, DataType::Utf8)),
- "int" => Ok((name, DataType::Int32)),
- _ => Err(crate::errors::PyDataFusionError::Common(format!(
- "Unsupported data type '{ty}' for partition column. Supported
types are 'string' and 'int'"
- ))),
- })
- .collect::<Result<Vec<_>, _>>()
-}
-
pub fn parse_file_compression_type(
file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]