This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 45a6844 Pyarrow filter pushdowns (#735)
45a6844 is described below
commit 45a684445e25032961a7bb44ced3ce06f5ed9e6d
Author: Michael J Ward <[email protected]>
AuthorDate: Wed Jun 19 11:20:39 2024 -0500
Pyarrow filter pushdowns (#735)
* fix pushdown for pyarrow filter IsNull
The conversion was incorrectly passing in the expression itself as the
`nan_as_null` argument. This caused the pushdown to silently fail.
* expand the Expr::Literal's that can be used in PyArrowFilterExpression
Closes #703
---
python/datafusion/tests/test_context.py | 54 +++++++++++++++++++++++++++++++++
src/pyarrow_filter_expression.rs | 29 ++++++------------
2 files changed, 64 insertions(+), 19 deletions(-)
diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py
index df7e181..abc324d 100644
--- a/python/datafusion/tests/test_context.py
+++ b/python/datafusion/tests/test_context.py
@@ -16,6 +16,7 @@
# under the License.
import gzip
import os
+import datetime as dt
import pyarrow as pa
import pyarrow.dataset as ds
@@ -303,6 +304,59 @@ def test_dataset_filter(ctx, capfd):
assert result[0].column(1) == pa.array([-3])
+def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
+ """Ensure that pyarrow filter gets pushed down for `IsNull`"""
+ # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])],
+ names=["a", "b", "c"],
+ )
+ dataset = ds.dataset([batch])
+ ctx.register_dataset("t", dataset)
+ # Make sure the filter was pushed down in Physical Plan
+ df = ctx.sql("SELECT a FROM t WHERE c is NULL")
+ df.explain()
+ captured = capfd.readouterr()
+ assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out
+
+ result = df.collect()
+ assert result[0].column(0) == pa.array([2])
+
+
+def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
+ """Ensure that pyarrow filter gets pushed down for timestamp"""
+ # Ref: https://github.com/apache/datafusion-python/issues/703
+
+ # create pyarrow dataset with no actual files
+ col_type = pa.timestamp("ns", "+00:00")
+    nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type)
+    pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
+ pa_dataset_format = pa.dataset.ParquetFileFormat()
+ pa_dataset_partition = pa.dataset.field("a") <= nyd_2000
+ fragments = [
+ # NOTE: we never actually make this file.
+ # Working predicate pushdown means it never gets accessed
+ pa_dataset_format.make_fragment(
+ "1.parquet",
+ filesystem=pa_dataset_fs,
+ partition_expression=pa_dataset_partition,
+ )
+ ]
+ pa_dataset = pa.dataset.FileSystemDataset(
+ fragments,
+ pa.schema([pa.field("a", col_type)]),
+ pa_dataset_format,
+ pa_dataset_fs,
+ )
+
+ ctx.register_dataset("t", pa_dataset)
+
+ # the partition for our only fragment is for a < 2000-01-01.
+ # so querying for a > 2024-01-01 should not touch any files
+ df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
+ assert df.collect() == []
+
+
def test_dataset_filter_nested_data(ctx):
# create Arrow StructArrays to test nested data types
data = pa.StructArray.from_arrays(
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index fca8851..ff447e1 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -21,6 +21,7 @@ use pyo3::prelude::*;
use std::convert::TryFrom;
use std::result::Result;
+use arrow::pyarrow::ToPyArrow;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
@@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
let ret: Result<Vec<PyObject>, DataFusionError> = exprs
.iter()
.map(|expr| match expr {
+ // TODO: should we also leverage `ScalarValue::to_pyarrow` here?
Expr::Literal(v) => match v {
ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
@@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
let op_module = Python::import_bound(py, "operator")?;
             let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
                 Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
-                Expr::Literal(v) => match v {
-                    ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
-                    ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
-                    _ => Err(DataFusionError::Common(format!(
-                        "PyArrow can't handle ScalarValue: {v:?}"
-                    ))),
-                },
+                Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let operator = operator_to_py(op, &op_module)?;
let left =
PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -138,8 +124,13 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
                     let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                         .0
                         .into_bound(py);
-                    // TODO: this expression does not seems like it should be `call_method0`
- Ok(expr.clone().call_method1("is_null", (expr,))?)
+
+                    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression.is_null
+ // Whether floating-point NaNs are considered null.
+ let nan_is_null = false;
+
+ let res = expr.call_method1("is_null", (nan_is_null,))?;
+ Ok(res)
}
Expr::Between(Between {
expr,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]