This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 7366f89 Expose unnest feature (#641)
7366f89 is described below
commit 7366f8919d9b679baa5d312c515b29cc055867a2
Author: Tim Saucer <[email protected]>
AuthorDate: Tue Apr 23 10:52:08 2024 -0400
Expose unnest feature (#641)
* Expose unnest feature
* Update dataframe operation name to match rust implementation
---
datafusion/tests/test_dataframe.py | 34 +++++++++++++++
src/dataframe.rs | 12 ++++++
src/expr.rs | 2 +
src/expr/unnest.rs | 85 ++++++++++++++++++++++++++++++++++++++
src/sql/logical.rs | 2 +
5 files changed, 135 insertions(+)
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index c8c74fa..efb1679 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -62,6 +62,20 @@ def struct_df():
return ctx.create_dataframe([[batch]])
[email protected]
+def nested_df():
+ ctx = SessionContext()
+
+ # create a RecordBatch and a new DataFrame from it
+ # Intentionally make each array of different length
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([[1], [2, 3], [4, 5, 6], None]), pa.array([7, 8, 9, 10])],
+ names=["a", "b"],
+ )
+
+ return ctx.create_dataframe([[batch]])
+
+
@pytest.fixture
def aggregate_df():
ctx = SessionContext()
@@ -160,6 +174,26 @@ def test_with_column_renamed(df):
assert result.schema.field(2).name == "sum"
+def test_unnest(nested_df):
+ nested_df = nested_df.unnest_column("a")
+
+ # execute and collect the first (and only) batch
+ result = nested_df.collect()[0]
+
+ assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6, None])
+ assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10])
+
+
+def test_unnest_without_nulls(nested_df):
+ nested_df = nested_df.unnest_column("a", preserve_nulls=False)
+
+ # execute and collect the first (and only) batch
+ result = nested_df.collect()[0]
+
+ assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6])
+ assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
+
+
def test_udf(df):
# is_null is a pa function over arrays
is_null = udf(
diff --git a/src/dataframe.rs b/src/dataframe.rs
index a239a35..a319b3d 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -25,6 +25,7 @@ use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel,
ZstdLevel};
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::*;
+use datafusion_common::UnnestOptions;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
@@ -293,6 +294,17 @@ impl PyDataFrame {
Ok(Self::new(new_df))
}
+ #[pyo3(signature = (column, preserve_nulls=true))]
+    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
+ let unnest_options = UnnestOptions { preserve_nulls };
+ let df = self
+ .df
+ .as_ref()
+ .clone()
+ .unnest_column_with_options(column, unnest_options)?;
+ Ok(Self::new(df))
+ }
+
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s
must have exactly the same schema
fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
let new_df = self
diff --git a/src/expr.rs b/src/expr.rs
index c0e7019..0958c4a 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -89,6 +89,7 @@ pub mod subquery;
pub mod subquery_alias;
pub mod table_scan;
pub mod union;
+pub mod unnest;
pub mod window;
/// A PyExpr that can be used on a DataFrame
@@ -684,6 +685,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<join::PyJoinConstraint>()?;
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
+ m.add_class::<unnest::PyUnnest>()?;
m.add_class::<extension::PyExtension>()?;
m.add_class::<filter::PyFilter>()?;
m.add_class::<projection::PyProjection>()?;
diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs
new file mode 100644
index 0000000..33fb82f
--- /dev/null
+++ b/src/expr/unnest.rs
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_expr::logical_plan::Unnest;
+use pyo3::prelude::*;
+use std::fmt::{self, Display, Formatter};
+
+use crate::common::df_schema::PyDFSchema;
+use crate::expr::logical_node::LogicalNode;
+use crate::sql::logical::PyLogicalPlan;
+
+#[pyclass(name = "Unnest", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyUnnest {
+ unnest_: Unnest,
+}
+
+impl From<Unnest> for PyUnnest {
+ fn from(unnest_: Unnest) -> PyUnnest {
+ PyUnnest { unnest_ }
+ }
+}
+
+impl From<PyUnnest> for Unnest {
+ fn from(unnest_: PyUnnest) -> Self {
+ unnest_.unnest_
+ }
+}
+
+impl Display for PyUnnest {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Unnest
+ Inputs: {:?}
+ Schema: {:?}",
+ &self.unnest_.input, &self.unnest_.schema,
+ )
+ }
+}
+
+#[pymethods]
+impl PyUnnest {
+ /// Retrieves the input `LogicalPlan` to this `Unnest` node
+ fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
+ Ok(Self::inputs(self))
+ }
+
+ /// Resulting Schema for this `Unnest` node instance
+ fn schema(&self) -> PyResult<PyDFSchema> {
+ Ok(self.unnest_.schema.as_ref().clone().into())
+ }
+
+ fn __repr__(&self) -> PyResult<String> {
+ Ok(format!("Unnest({})", self))
+ }
+
+ fn __name__(&self) -> PyResult<String> {
+ Ok("Unnest".to_string())
+ }
+}
+
+impl LogicalNode for PyUnnest {
+ fn inputs(&self) -> Vec<PyLogicalPlan> {
+ vec![PyLogicalPlan::from((*self.unnest_.input).clone())]
+ }
+
+ fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+ Ok(self.clone().into_py(py))
+ }
+}
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 3aa8a69..62515c3 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -33,6 +33,7 @@ use crate::expr::sort::PySort;
use crate::expr::subquery::PySubquery;
use crate::expr::subquery_alias::PySubqueryAlias;
use crate::expr::table_scan::PyTableScan;
+use crate::expr::unnest::PyUnnest;
use crate::expr::window::PyWindow;
use datafusion_expr::LogicalPlan;
use pyo3::prelude::*;
@@ -78,6 +79,7 @@ impl PyLogicalPlan {
LogicalPlan::TableScan(plan) =>
PyTableScan::from(plan.clone()).to_variant(py),
LogicalPlan::Subquery(plan) =>
PySubquery::from(plan.clone()).to_variant(py),
LogicalPlan::SubqueryAlias(plan) =>
PySubqueryAlias::from(plan.clone()).to_variant(py),
+        LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
LogicalPlan::Window(plan) =>
PyWindow::from(plan.clone()).to_variant(py),
other => Err(py_unsupported_variant_err(format!(
"Cannot convert this plan to a LogicalNode: {:?}",
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]