This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8dee7bbb66 Causality Analysis for Builtin Window Functions (#9048)
8dee7bbb66 is described below

commit 8dee7bbb669e54fd6f813dbcb8aaa37a665e486d
Author: Mustafa Akur <[email protected]>
AuthorDate: Tue Jan 30 16:37:21 2024 +0300

    Causality Analysis for Builtin Window Functions (#9048)
    
    * Add new tests and causality support for ROW_NUMBER
    
    * Add causality support for RANK and add a test
    
    * Update comments
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
    
    * Remove leftover code
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/tests/fuzz_cases/window_fuzz.rs   | 205 ++++++++++++++++------
 datafusion/expr/src/partition_evaluator.rs        |   5 +
 datafusion/physical-expr/src/window/built_in.rs   |   7 +-
 datafusion/physical-expr/src/window/lead_lag.rs   |   5 +
 datafusion/physical-expr/src/window/rank.rs       |   4 +
 datafusion/physical-expr/src/window/row_number.rs |   5 +
 6 files changed, 174 insertions(+), 57 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 7358ec2884..d22d0c0f2e 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -154,67 +154,162 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
         schema.clone(),
         None,
     )?);
-    let window_fn = 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
-    let fn_name = "COUNT".to_string();
-    let args = vec![col("x", &schema)?];
+
+    // Different window functions to test causality
+    let window_functions = vec![
+        // Simulate cases of the following form:
+        // COUNT(x) OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+            // its name
+            "COUNT",
+            // window function argument
+            vec![col("x", &schema)?],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            None,
+        ),
+        // Simulate cases of the following form:
+        // ROW_NUMBER() OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            WindowFunctionDefinition::BuiltInWindowFunction(
+                BuiltInWindowFunction::RowNumber,
+            ),
+            // its name
+            "ROW_NUMBER",
+            // no argument
+            vec![],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            Some(true),
+        ),
+        // Simulate cases of the following form:
+        // LAG(x) OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
+            // its name
+            "LAG",
+            // window function argument
+            vec![col("x", &schema)?],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            Some(true),
+        ),
+        // Simulate cases of the following form:
+        // LEAD(x) OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+            // its name
+            "LEAD",
+            // window function argument
+            vec![col("x", &schema)?],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            Some(false),
+        ),
+        // Simulate cases of the following form:
+        // RANK() OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Rank),
+            // its name
+            "RANK",
+            // no argument
+            vec![],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            Some(true),
+        ),
+        // Simulate cases of the following form:
+        // DENSE_RANK() OVER (
+        //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> 
PRECEDING/FOLLOWING
+        // )
+        (
+            // Window function
+            WindowFunctionDefinition::BuiltInWindowFunction(
+                BuiltInWindowFunction::DenseRank,
+            ),
+            // its name
+            "DENSE_RANK",
+            // no argument
+            vec![],
+            // Expected causality, for None cases causality will be determined 
from window frame boundaries
+            Some(true),
+        ),
+    ];
+
     let partitionby_exprs = vec![];
     let orderby_exprs = vec![];
     // Window frame starts with "UNBOUNDED PRECEDING":
     let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
 
-    // Simulate cases of the following form:
-    // COUNT(x) OVER (
-    //     ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> PRECEDING/FOLLOWING
-    // )
-    for is_preceding in [false, true] {
-        for end_bound in [0, 1, 2, 3] {
-            let end_bound = if is_preceding {
-                
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(end_bound)))
-            } else {
-                
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound)))
-            };
-            let window_frame = WindowFrame::new_bounds(
-                WindowFrameUnits::Rows,
-                start_bound.clone(),
-                end_bound,
-            );
-            let causal = window_frame.is_causal();
+    for (window_fn, fn_name, args, expected_causal) in window_functions {
+        for is_preceding in [false, true] {
+            for end_bound in [0, 1, 2, 3] {
+                let end_bound = if is_preceding {
+                    
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(end_bound)))
+                } else {
+                    
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound)))
+                };
+                let window_frame = WindowFrame::new_bounds(
+                    WindowFrameUnits::Rows,
+                    start_bound.clone(),
+                    end_bound,
+                );
+                let causal = if let Some(expected_causal) = expected_causal {
+                    expected_causal
+                } else {
+                    // If there is no expected causality
+                    // calculate it from window frame
+                    window_frame.is_causal()
+                };
 
-            let window_expr = create_window_expr(
-                &window_fn,
-                fn_name.clone(),
-                &args,
-                &partitionby_exprs,
-                &orderby_exprs,
-                Arc::new(window_frame),
-                schema.as_ref(),
-            )?;
-            let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
-                vec![window_expr],
-                memory_exec.clone(),
-                vec![],
-                InputOrderMode::Linear,
-            )?);
-            let task_ctx = ctx.task_ctx();
-            let mut collected_results = collect(running_window_exec, 
task_ctx).await?;
-            collected_results.retain(|batch| batch.num_rows() > 0);
-            let input_batch_sizes = batches
-                .iter()
-                .map(|batch| batch.num_rows())
-                .collect::<Vec<_>>();
-            let result_batch_sizes = collected_results
-                .iter()
-                .map(|batch| batch.num_rows())
-                .collect::<Vec<_>>();
-            if causal {
-                // For causal window frames, we can generate results 
immediately
-                // for each input batch. Hence, batch sizes should match.
-                assert_eq!(input_batch_sizes, result_batch_sizes);
-            } else {
-                // For non-causal window frames, we cannot generate results
-                // immediately for each input batch. Hence, batch sizes 
shouldn't
-                // match.
-                assert_ne!(input_batch_sizes, result_batch_sizes);
+                let window_expr = create_window_expr(
+                    &window_fn,
+                    fn_name.to_string(),
+                    &args,
+                    &partitionby_exprs,
+                    &orderby_exprs,
+                    Arc::new(window_frame),
+                    schema.as_ref(),
+                )?;
+                let running_window_exec = 
Arc::new(BoundedWindowAggExec::try_new(
+                    vec![window_expr],
+                    memory_exec.clone(),
+                    vec![],
+                    InputOrderMode::Linear,
+                )?);
+                let task_ctx = ctx.task_ctx();
+                let mut collected_results =
+                    collect(running_window_exec, task_ctx).await?;
+                collected_results.retain(|batch| batch.num_rows() > 0);
+                let input_batch_sizes = batches
+                    .iter()
+                    .map(|batch| batch.num_rows())
+                    .collect::<Vec<_>>();
+                let result_batch_sizes = collected_results
+                    .iter()
+                    .map(|batch| batch.num_rows())
+                    .collect::<Vec<_>>();
+                if causal {
+                    // For causal window frames, we can generate results 
immediately
+                    // for each input batch. Hence, batch sizes should match.
+                    assert_eq!(input_batch_sizes, result_batch_sizes);
+                } else {
+                    // For non-causal window frames, we cannot generate results
+                    // immediately for each input batch. Hence, batch sizes 
shouldn't
+                    // match.
+                    assert_ne!(input_batch_sizes, result_batch_sizes);
+                }
             }
         }
     }
diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs
index 0a765b30b0..4b5357ddf8 100644
--- a/datafusion/expr/src/partition_evaluator.rs
+++ b/datafusion/expr/src/partition_evaluator.rs
@@ -118,6 +118,11 @@ pub trait PartitionEvaluator: Debug + Send {
         }
     }
 
+    /// Returns `false` if the evaluator needs future rows to compute its result, `true` otherwise.
+    fn is_causal(&self) -> bool {
+        false
+    }
+
     /// Evaluate a window function on an entire input partition.
     ///
     /// This function is called once per input *partition* for window
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index c3c7400026..065260a73e 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -206,7 +206,11 @@ impl WindowExpr for BuiltInWindowExpr {
             let record_batch = &partition_batch_state.record_batch;
             let num_rows = record_batch.num_rows();
             let mut row_wise_results: Vec<ScalarValue> = vec![];
-            let mut is_causal = self.window_frame.is_causal();
+            let is_causal = if evaluator.uses_window_frame() {
+                self.window_frame.is_causal()
+            } else {
+                evaluator.is_causal()
+            };
             for idx in state.last_calculated_index..num_rows {
                 let frame_range = if evaluator.uses_window_frame() {
                     state
@@ -225,7 +229,6 @@ impl WindowExpr for BuiltInWindowExpr {
                             idx,
                         )
                 } else {
-                    is_causal = false;
                     evaluator.get_range(idx, num_rows)
                 }?;
 
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index c218b5555a..6a33f26ca1 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -194,6 +194,11 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         }
     }
 
+    fn is_causal(&self) -> bool {
+        // Lagging windows are causal by definition:
+        self.shift_offset > 0
+    }
+
     fn evaluate(
         &mut self,
         values: &[ArrayRef],
diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs
index 1f643f0280..437fdbe0b9 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -132,6 +132,10 @@ pub(crate) struct RankEvaluator {
 }
 
 impl PartitionEvaluator for RankEvaluator {
+    fn is_causal(&self) -> bool {
+        matches!(self.rank_type, RankType::Basic | RankType::Dense)
+    }
+
     /// Evaluates the window function inside the given range.
     fn evaluate(
         &mut self,
diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs
index 759f447ab0..0140342405 100644
--- a/datafusion/physical-expr/src/window/row_number.rs
+++ b/datafusion/physical-expr/src/window/row_number.rs
@@ -92,6 +92,11 @@ pub(crate) struct NumRowsEvaluator {
 }
 
 impl PartitionEvaluator for NumRowsEvaluator {
+    fn is_causal(&self) -> bool {
+        // The ROW_NUMBER function doesn't need "future" values to emit results:
+        true
+    }
+
     /// evaluate window function result inside given range
     fn evaluate(
         &mut self,

Reply via email to