This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 19c4253512 allow window UDF to return null (#6915)
19c4253512 is described below

commit 19c4253512b130c266982d92196696730e59fb53
Author: Martin Hilton <[email protected]>
AuthorDate: Tue Jul 11 18:28:04 2023 +0100

    allow window UDF to return null (#6915)
    
    Update the schema produced by the WindowUDF expression to allow
    user-defined window functions to return NULL.
---
 datafusion/core/src/physical_plan/windows/mod.rs   |  2 +-
 .../user_defined/user_defined_window_functions.rs  | 48 +++++++++++++++++++++-
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 88fafe99b4..ff7936e5ce 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -258,7 +258,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr {
     }
 
     fn field(&self) -> Result<Field> {
-        let nullable = false;
+        let nullable = true;
         Ok(Field::new(
             &self.name,
             self.data_type.as_ref().clone(),
diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
index dfa1781285..8736ede690 100644
--- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
@@ -297,6 +297,39 @@ async fn test_udwf_bounded_query_include_rank() {
     assert_eq!(test_state.evaluate_all_with_rank_called(), 2);
 }
 
+/// Basic user defined window function that can return NULL.
+#[tokio::test]
+async fn test_udwf_bounded_window_returns_null() {
+    let test_state = TestState::new()
+        .with_uses_window_frame()
+        .with_null_for_zero();
+    let TestContext { ctx, test_state } = TestContext::new(test_state);
+
+    let expected = vec![
+    
"+---+---+-----+--------------------------------------------------------------------------------------------------------------+",
+    "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC 
NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |",
+    
"+---+---+-----+--------------------------------------------------------------------------------------------------------------+",
+    "| 1 | a | 0   | 1                                                         
                                                   |",
+    "| 1 | b | 1   | 1                                                         
                                                   |",
+    "| 1 | c | 2   | 1                                                         
                                                   |",
+    "| 2 | d | 3   | 1                                                         
                                                   |",
+    "| 2 | e | 4   | 2                                                         
                                                   |",
+    "| 2 | f | 5   | 1                                                         
                                                   |",
+    "| 2 | g | 6   | 1                                                         
                                                   |",
+    "| 2 | h | 6   |                                                           
                                                   |",
+    "| 2 | i | 6   |                                                           
                                                   |",
+    "| 2 | j | 6   |                                                           
                                                   |",
+    
"+---+---+-----+--------------------------------------------------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(
+        expected,
+        &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap()
+    );
+    // Evaluate is called once for each input row
+    assert_eq!(test_state.evaluate_called(), 10);
+    assert_eq!(test_state.evaluate_all_called(), 0);
+}
+
 async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
 }
@@ -365,6 +398,8 @@ struct TestState {
     supports_bounded_execution: bool,
     /// should the functions they need include rank
     include_rank: bool,
+    /// should the functions return NULL for 0s?
+    null_for_zero: bool,
 }
 
 impl TestState {
@@ -390,6 +425,12 @@ impl TestState {
         self
     }
 
+    /// Configure the function to return NULL instead of 0 when the odd count is zero.
+    fn with_null_for_zero(mut self) -> Self {
+        self.null_for_zero = true;
+        self
+    }
+
     /// return the evaluate_all_called counter
     fn evaluate_all_called(&self) -> usize {
         self.evaluate_all_called.load(Ordering::SeqCst)
@@ -476,7 +517,12 @@ impl PartitionEvaluator for OddCounter {
         self.test_state.inc_evaluate_called();
         let values: &Int64Array = values.get(0).unwrap().as_primitive();
         let values = values.slice(range.start, range.len());
-        let scalar = ScalarValue::Int64(Some(odd_count(&values)));
+        let scalar = ScalarValue::Int64(
+            match (odd_count(&values), self.test_state.null_for_zero) {
+                (0, true) => None,
+                (n, _) => Some(n),
+            },
+        );
         Ok(scalar)
     }
 

Reply via email to