This is an automated email from the ASF dual-hosted git repository.

mjward pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new a00cfbf  feat: aggregates as windows (#871)
a00cfbf is described below

commit a00cfbfdbf3143c8c56d9ea043be3fb69da008ee
Author: Tim Saucer <[email protected]>
AuthorDate: Wed Sep 18 16:56:43 2024 -0400

    feat: aggregates as windows (#871)
    
    * Add `over` to turn any aggregate function into a window function
    
    * Rename Window to WindowExpr so we can define Window to mean a window definition to be reused
    
    * Add unit test to cover default frames
    
    * Improve error report
---
 python/datafusion/expr.py                 | 57 ++++++++++++++++++++++-
 python/datafusion/tests/test_dataframe.py | 75 ++++++++++++++++++++++---------
 src/expr.rs                               | 46 ++++++++++++++++++-
 src/expr/window.rs                        | 20 ++++-----
 src/functions.rs                          | 29 +++++++-----
 src/sql/logical.rs                        |  4 +-
 6 files changed, 183 insertions(+), 48 deletions(-)

diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index fd5e6f0..152aa38 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -92,7 +92,7 @@ TryCast = expr_internal.TryCast
 Union = expr_internal.Union
 Unnest = expr_internal.Unnest
 UnnestExpr = expr_internal.UnnestExpr
-Window = expr_internal.Window
+WindowExpr = expr_internal.WindowExpr
 
 __all__ = [
     "Expr",
@@ -154,6 +154,7 @@ __all__ = [
     "Partitioning",
     "Repartition",
     "Window",
+    "WindowExpr",
     "WindowFrame",
     "WindowFrameBound",
 ]
@@ -542,6 +543,36 @@ class Expr:
         """
         return 
ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))
 
+    def over(self, window: Window) -> Expr:
+        """Turn an aggregate function into a window function.
+
+        This function turns any aggregate function into a window function. 
With the
+        exception of ``partition_by``, how each of the parameters is used is 
determined
+        by the underlying aggregate function.
+
+        Args:
+            window: Window definition
+        """
+        partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
+        order_by_raw = sort_list_to_raw_sort_list(window._order_by)
+        window_frame_raw = (
+            window._window_frame.window_frame
+            if window._window_frame is not None
+            else None
+        )
+        null_treatment_raw = (
+            window._null_treatment.value if window._null_treatment is not None 
else None
+        )
+
+        return Expr(
+            self.expr.over(
+                partition_by=partition_by_raw,
+                order_by=order_by_raw,
+                window_frame=window_frame_raw,
+                null_treatment=null_treatment_raw,
+            )
+        )
+
 
 class ExprFuncBuilder:
     def __init__(self, builder: expr_internal.ExprFuncBuilder):
@@ -584,6 +615,30 @@ class ExprFuncBuilder:
         return Expr(self.builder.build())
 
 
+class Window:
+    """Define reusable window parameters."""
+
+    def __init__(
+        self,
+        partition_by: Optional[list[Expr]] = None,
+        window_frame: Optional[WindowFrame] = None,
+        order_by: Optional[list[SortExpr | Expr]] = None,
+        null_treatment: Optional[NullTreatment] = None,
+    ) -> None:
+        """Construct a window definition.
+
+        Args:
+            partition_by: Partitions for window operation
+            window_frame: Define the start and end bounds of the window frame
+            order_by: Set ordering
+            null_treatment: Indicate how nulls are to be treated
+        """
+        self._partition_by = partition_by
+        self._window_frame = window_frame
+        self._order_by = order_by
+        self._null_treatment = null_treatment
+
+
 class WindowFrame:
     """Defines a window frame for performing window operations."""
 
diff --git a/python/datafusion/tests/test_dataframe.py 
b/python/datafusion/tests/test_dataframe.py
index 90954d0..ad7f728 100644
--- a/python/datafusion/tests/test_dataframe.py
+++ b/python/datafusion/tests/test_dataframe.py
@@ -31,6 +31,7 @@ from datafusion import (
     literal,
     udf,
 )
+from datafusion.expr import Window
 
 
 @pytest.fixture
@@ -386,38 +387,32 @@ data_test_window_functions = [
         ),
         [-1, -1, None, 7, -1, -1, None],
     ),
-    # TODO update all aggregate functions as windows once upstream merges 
https://github.com/apache/datafusion-python/issues/833
-    pytest.param(
+    (
         "first_value",
-        f.window(
-            "first_value",
-            [column("a")],
-            order_by=[f.order_by(column("b"))],
-            partition_by=[column("c")],
+        f.first_value(column("a")).over(
+            Window(partition_by=[column("c")], order_by=[column("b")])
         ),
         [1, 1, 1, 1, 5, 5, 5],
     ),
-    pytest.param(
+    (
         "last_value",
-        f.window("last_value", [column("a")])
-        .window_frame(WindowFrame("rows", 0, None))
-        .order_by(column("b"))
-        .partition_by(column("c"))
-        .build(),
+        f.last_value(column("a")).over(
+            Window(
+                partition_by=[column("c")],
+                order_by=[column("b")],
+                window_frame=WindowFrame("rows", None, None),
+            )
+        ),
         [3, 3, 3, 3, 6, 6, 6],
     ),
-    pytest.param(
+    (
         "3rd_value",
-        f.window(
-            "nth_value",
-            [column("b"), literal(3)],
-            order_by=[f.order_by(column("a"))],
-        ),
+        f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
         [None, None, 7, 7, 7, 7, 7],
     ),
-    pytest.param(
+    (
         "avg",
-        f.round(f.window("avg", [column("b")], order_by=[column("a")]), 
literal(3)),
+        f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), 
literal(3)),
         [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
     ),
 ]
@@ -473,6 +468,44 @@ def test_invalid_window_frame(units, start_bound, 
end_bound):
         WindowFrame(units, start_bound, end_bound)
 
 
+def test_window_frame_defaults_match_postgres(partitioned_df):
+    # ref: https://github.com/apache/datafusion-python/issues/688
+
+    window_frame = WindowFrame("rows", None, None)
+
+    col_a = column("a")
+
+    # Using `f.window` with or without an unbounded window_frame produces the 
same
+    # results. These tests are included as a regression check but can be 
removed when
+    # f.window() is deprecated in favor of using the .over() approach.
+    no_frame = f.window("avg", [col_a]).alias("no_frame")
+    with_frame = f.window("avg", [col_a], 
window_frame=window_frame).alias("with_frame")
+    df_1 = partitioned_df.select(col_a, no_frame, with_frame)
+
+    expected = {
+        "a": [0, 1, 2, 3, 4, 5, 6],
+        "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
+        "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
+    }
+
+    assert df_1.sort(col_a).to_pydict() == expected
+
+    # When order is not set, the default frame should be unounded preceeding to
+    # unbounded following. When order is set, the default frame is unbounded 
preceeding
+    # to current row.
+    no_order = f.avg(col_a).over(Window()).alias("over_no_order")
+    with_order = 
f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
+    df_2 = partitioned_df.select(col_a, no_order, with_order)
+
+    expected = {
+        "a": [0, 1, 2, 3, 4, 5, 6],
+        "over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
+        "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
+    }
+
+    assert df_2.sort(col_a).to_pydict() == expected
+
+
 def test_get_dataframe(tmp_path):
     ctx = SessionContext()
 
diff --git a/src/expr.rs b/src/expr.rs
index 304d147..49fa4b8 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 use datafusion::logical_expr::utils::exprlist_to_fields;
-use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
+use datafusion::logical_expr::{
+    ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
+};
 use pyo3::{basic::CompareOp, prelude::*};
 use std::convert::{From, Into};
 use std::sync::Arc;
@@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
 use crate::expr::binary_expr::PyBinaryExpr;
 use crate::expr::column::PyColumn;
 use crate::expr::literal::PyLiteral;
+use crate::functions::add_builder_fns_to_window;
 use crate::sql::logical::PyLogicalPlan;
 
 use self::alias::PyAlias;
@@ -558,6 +561,45 @@ impl PyExpr {
     pub fn window_frame(&self, window_frame: PyWindowFrame) -> 
PyExprFuncBuilder {
         self.expr.clone().window_frame(window_frame.into()).into()
     }
+
+    #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, 
null_treatment=None))]
+    pub fn over(
+        &self,
+        partition_by: Option<Vec<PyExpr>>,
+        window_frame: Option<PyWindowFrame>,
+        order_by: Option<Vec<PySortExpr>>,
+        null_treatment: Option<NullTreatment>,
+    ) -> PyResult<PyExpr> {
+        match &self.expr {
+            Expr::AggregateFunction(agg_fn) => {
+                let window_fn = Expr::WindowFunction(WindowFunction::new(
+                    
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
+                    agg_fn.args.clone(),
+                ));
+
+                add_builder_fns_to_window(
+                    window_fn,
+                    partition_by,
+                    window_frame,
+                    order_by,
+                    null_treatment,
+                )
+            }
+            Expr::WindowFunction(_) => add_builder_fns_to_window(
+                self.expr.clone(),
+                partition_by,
+                window_frame,
+                order_by,
+                null_treatment,
+            ),
+            _ => Err(
+                
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
+                    format!("Using {} with `over` is not allowed. Must use an 
aggregate or window function.", self.expr.variant_name()),
+                ))
+                .into(),
+            ),
+        }
+    }
 }
 
 #[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
@@ -749,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> 
PyResult<()> {
     m.add_class::<drop_table::PyDropTable>()?;
     m.add_class::<repartition::PyPartitioning>()?;
     m.add_class::<repartition::PyRepartition>()?;
-    m.add_class::<window::PyWindow>()?;
+    m.add_class::<window::PyWindowExpr>()?;
     m.add_class::<window::PyWindowFrame>()?;
     m.add_class::<window::PyWindowFrameBound>()?;
     Ok(())
diff --git a/src/expr/window.rs b/src/expr/window.rs
index 950db12..6486dbb 100644
--- a/src/expr/window.rs
+++ b/src/expr/window.rs
@@ -32,9 +32,9 @@ use super::py_expr_list;
 
 use crate::errors::py_datafusion_err;
 
-#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
+#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
 #[derive(Clone)]
-pub struct PyWindow {
+pub struct PyWindowExpr {
     window: Window,
 }
 
@@ -62,15 +62,15 @@ pub struct PyWindowFrameBound {
     frame_bound: WindowFrameBound,
 }
 
-impl From<PyWindow> for Window {
-    fn from(window: PyWindow) -> Window {
+impl From<PyWindowExpr> for Window {
+    fn from(window: PyWindowExpr) -> Window {
         window.window
     }
 }
 
-impl From<Window> for PyWindow {
-    fn from(window: Window) -> PyWindow {
-        PyWindow { window }
+impl From<Window> for PyWindowExpr {
+    fn from(window: Window) -> PyWindowExpr {
+        PyWindowExpr { window }
     }
 }
 
@@ -80,7 +80,7 @@ impl From<WindowFrameBound> for PyWindowFrameBound {
     }
 }
 
-impl Display for PyWindow {
+impl Display for PyWindowExpr {
     fn fmt(&self, f: &mut Formatter) -> fmt::Result {
         write!(
             f,
@@ -103,7 +103,7 @@ impl Display for PyWindowFrame {
 }
 
 #[pymethods]
-impl PyWindow {
+impl PyWindowExpr {
     /// Returns the schema of the Window
     pub fn schema(&self) -> PyResult<PyDFSchema> {
         Ok(self.window.schema.as_ref().clone().into())
@@ -283,7 +283,7 @@ impl PyWindowFrameBound {
     }
 }
 
-impl LogicalNode for PyWindow {
+impl LogicalNode for PyWindowExpr {
     fn inputs(&self) -> Vec<PyLogicalPlan> {
         vec![self.window.input.as_ref().clone().into()]
     }
diff --git a/src/functions.rs b/src/functions.rs
index 32f6519..6f8dd7a 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -711,14 +711,15 @@ pub fn string_agg(
     add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, 
null_treatment)
 }
 
-fn add_builder_fns_to_window(
+pub(crate) fn add_builder_fns_to_window(
     window_fn: Expr,
     partition_by: Option<Vec<PyExpr>>,
+    window_frame: Option<PyWindowFrame>,
     order_by: Option<Vec<PySortExpr>>,
+    null_treatment: Option<NullTreatment>,
 ) -> PyResult<PyExpr> {
-    // Since ExprFuncBuilder::new() is private, set an empty partition and then
-    // override later if appropriate.
-    let mut builder = window_fn.partition_by(vec![]);
+    let null_treatment = null_treatment.map(|n| n.into());
+    let mut builder = window_fn.null_treatment(null_treatment);
 
     if let Some(partition_cols) = partition_by {
         builder = builder.partition_by(
@@ -734,6 +735,10 @@ fn add_builder_fns_to_window(
         builder = builder.order_by(order_by_cols);
     }
 
+    if let Some(window_frame) = window_frame {
+        builder = builder.window_frame(window_frame.into());
+    }
+
     builder.build().map(|e| e.into()).map_err(|err| err.into())
 }
 
@@ -748,7 +753,7 @@ pub fn lead(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::lead(arg.expr, Some(shift_offset), 
default_value);
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -762,7 +767,7 @@ pub fn lag(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::lag(arg.expr, Some(shift_offset), 
default_value);
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -773,7 +778,7 @@ pub fn row_number(
 ) -> PyResult<PyExpr> {
     let window_fn = datafusion::functions_window::expr_fn::row_number();
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -784,7 +789,7 @@ pub fn rank(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::rank();
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -795,7 +800,7 @@ pub fn dense_rank(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::dense_rank();
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -806,7 +811,7 @@ pub fn percent_rank(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::percent_rank();
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -817,7 +822,7 @@ pub fn cume_dist(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::cume_dist();
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 #[pyfunction]
@@ -829,7 +834,7 @@ pub fn ntile(
 ) -> PyResult<PyExpr> {
     let window_fn = window_function::ntile(arg.into());
 
-    add_builder_fns_to_window(window_fn, partition_by, order_by)
+    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
 }
 
 pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 89655ab..d00f0af 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -34,7 +34,7 @@ use crate::expr::subquery::PySubquery;
 use crate::expr::subquery_alias::PySubqueryAlias;
 use crate::expr::table_scan::PyTableScan;
 use crate::expr::unnest::PyUnnest;
-use crate::expr::window::PyWindow;
+use crate::expr::window::PyWindowExpr;
 use datafusion::logical_expr::LogicalPlan;
 use pyo3::prelude::*;
 
@@ -80,7 +80,7 @@ impl PyLogicalPlan {
             LogicalPlan::Subquery(plan) => 
PySubquery::from(plan.clone()).to_variant(py),
             LogicalPlan::SubqueryAlias(plan) => 
PySubqueryAlias::from(plan.clone()).to_variant(py),
             LogicalPlan::Unnest(plan) => 
PyUnnest::from(plan.clone()).to_variant(py),
-            LogicalPlan::Window(plan) => 
PyWindow::from(plan.clone()).to_variant(py),
+            LogicalPlan::Window(plan) => 
PyWindowExpr::from(plan.clone()).to_variant(py),
             LogicalPlan::Repartition(_)
             | LogicalPlan::Union(_)
             | LogicalPlan::Statement(_)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to