This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 22c4214fe1 Refactor `nvl2` Function to Support Lazy Evaluation and Simplification via CASE Expression (#18191)
22c4214fe1 is described below

commit 22c4214fe1ca3953932f3f12ccd5b68dbfbefdf3
Author: kosiew <[email protected]>
AuthorDate: Fri Oct 24 17:58:57 2025 +0800

    Refactor `nvl2` Function to Support Lazy Evaluation and Simplification via CASE Expression (#18191)
    
    ## Which issue does this PR close?
    
    * Closes #17983
    
    ## Rationale for this change
    
    The current implementation of the `nvl2` function in DataFusion eagerly
    evaluates all its arguments. This can lead to unnecessary computation
    and incorrect behavior for arguments that should only be evaluated
    conditionally. For example, a cast in the branch that is not selected
    can still raise an error, even though its value would be discarded.
    This PR introduces **lazy evaluation** for `nvl2`, aligning its
    behavior with other conditional expressions like `coalesce` and
    improving both performance and correctness.
    
    This change also introduces a **simplification rule** that rewrites
    `nvl2` expressions into equivalent `CASE` statements, allowing for
    better optimization during query planning and execution.
    
    ## What changes are included in this PR?
    
    * Refactored `nvl2` implementation in
      `datafusion/functions/src/core/nvl2.rs`:
    
      * Added support for **short-circuit (lazy) evaluation** using
        `short_circuits()`.
      * Implemented **simplify()** method to rewrite expressions into `CASE`
        form (`CASE WHEN test IS NOT NULL THEN if_non_null ELSE if_null END`).
      * Introduced **return_field_from_args()** for correct nullability and
        type inference.
      * Removed the previous eager `nvl2_func()` implementation.
        `invoke_with_args()` now returns an internal error, because `nvl2`
        is expected to be rewritten to `CASE` before execution.
    
    * Added **unit tests**:
    
      * `test_nvl2_short_circuit` in `dataframe_functions.rs` verifies correct
        short-circuit behavior.
      * `test_create_physical_expr_nvl2` in `expr_api/mod.rs` verifies that
        evaluating an `nvl2` physical expression that has not been simplified
        to `CASE` fails with an internal error.
    
    ## Are these changes tested?
    
    ✅ Yes, multiple new tests are included:
    
    * **`test_nvl2_short_circuit`** ensures `nvl2` does not evaluate
      unnecessary branches.
    * **`test_create_physical_expr_nvl2`** checks that evaluating an
      un-simplified `nvl2` physical expression returns the expected
      "nvl2 should have been simplified to case" error.
    
    All existing and new tests pass successfully.
    
    ## Are there any user-facing changes?
    
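    Yes. They are **performance-enhancing**, and **non-breaking** for queries
    whose expressions go through simplification:
    
    * `nvl2` now evaluates lazily, meaning only the required branch is
      computed based on the nullity of the test expression.
    * Expression simplification will yield more optimized query plans.
    
    There are **no changes to public function signatures**. Note, however,
    that `nvl2` now depends on being rewritten to `CASE` by the simplifier.
    A physical `nvl2` expression evaluated without that rewrite returns an
    internal error (see `test_create_physical_expr_nvl2`). Users may observe
    improved performance and reduced computation for expressions involving
    `nvl2`.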
    
    ---------
    
    Co-authored-by: Jeffrey Vo <[email protected]>
---
 .../core/tests/dataframe/dataframe_functions.rs    | 27 +++++++
 datafusion/core/tests/expr_api/mod.rs              | 20 +++++
 datafusion/functions/src/core/nvl2.rs              | 85 ++++++++++------------
 3 files changed, 84 insertions(+), 48 deletions(-)

diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs 
b/datafusion/core/tests/dataframe/dataframe_functions.rs
index b664fccdfa..d95eb38c19 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -274,6 +274,33 @@ async fn test_nvl2() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_nvl2_short_circuit() -> Result<()> {
+    let expr = nvl2(
+        col("a"),
+        arrow_cast(lit("1"), lit("Int32")),
+        arrow_cast(col("a"), lit("Int32")),
+    );
+
+    let batches = get_batches(expr).await?;
+
+    assert_snapshot!(
+        batches_to_string(&batches),
+        @r#"
+    
+-----------------------------------------------------------------------------------+
+    | 
nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32")))
 |
+    
+-----------------------------------------------------------------------------------+
+    | 1                                                                        
         |
+    | 1                                                                        
         |
+    | 1                                                                        
         |
+    | 1                                                                        
         |
+    
+-----------------------------------------------------------------------------------+
+    "#
+    );
+
+    Ok(())
+}
 #[tokio::test]
 async fn test_fn_arrow_typeof() -> Result<()> {
     let expr = arrow_typeof(col("l"));
diff --git a/datafusion/core/tests/expr_api/mod.rs 
b/datafusion/core/tests/expr_api/mod.rs
index 4aee274de9..84e644480a 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -320,6 +320,26 @@ async fn test_create_physical_expr() {
     create_simplified_expr_test(lit(1i32) + lit(2i32), "3");
 }
 
+#[test]
+fn test_create_physical_expr_nvl2() {
+    let batch = &TEST_BATCH;
+    let df_schema = DFSchema::try_from(batch.schema()).unwrap();
+    let ctx = SessionContext::new();
+
+    let expect_err = |expr| {
+        let physical_expr = ctx.create_physical_expr(expr, 
&df_schema).unwrap();
+        let err = physical_expr.evaluate(batch).unwrap_err();
+        assert!(
+            err.to_string()
+                .contains("nvl2 should have been simplified to case"),
+            "unexpected error: {err:?}"
+        );
+    };
+
+    expect_err(nvl2(col("i"), lit(1i64), lit(0i64)));
+    expect_err(nvl2(lit(1i64), col("i"), lit(0i64)));
+}
+
 #[tokio::test]
 async fn test_create_physical_expr_coercion() {
     // create_physical_expr does apply type coercion and unwrapping in cast
diff --git a/datafusion/functions/src/core/nvl2.rs 
b/datafusion/functions/src/core/nvl2.rs
index 82aa8d2a4c..45cb6760d0 100644
--- a/datafusion/functions/src/core/nvl2.rs
+++ b/datafusion/functions/src/core/nvl2.rs
@@ -15,17 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::Array;
-use arrow::compute::is_not_null;
-use arrow::compute::kernels::zip::zip;
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::{internal_err, utils::take_function_args, Result};
 use datafusion_expr::{
-    type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
-    ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    conditional_expressions::CaseBuilder,
+    simplify::{ExprSimplifyResult, SimplifyInfo},
+    type_coercion::binary::comparison_coercion,
+    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
+    ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_macros::user_doc;
-use std::sync::Arc;
 
 #[user_doc(
     doc_section(label = "Conditional Functions"),
@@ -95,8 +94,37 @@ impl ScalarUDFImpl for NVL2Func {
         Ok(arg_types[1].clone())
     }
 
-    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        nvl2_func(&args.args)
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> 
Result<FieldRef> {
+        let nullable =
+            args.arg_fields[1].is_nullable() || 
args.arg_fields[2].is_nullable();
+        let return_type = args.arg_fields[1].data_type().clone();
+        Ok(Field::new(self.name(), return_type, nullable).into())
+    }
+
+    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
+        internal_err!("nvl2 should have been simplified to case")
+    }
+
+    fn simplify(
+        &self,
+        args: Vec<Expr>,
+        _info: &dyn SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        let [test, if_non_null, if_null] = take_function_args(self.name(), 
args)?;
+
+        let expr = CaseBuilder::new(
+            None,
+            vec![test.is_not_null()],
+            vec![if_non_null],
+            Some(Box::new(if_null)),
+        )
+        .end()?;
+
+        Ok(ExprSimplifyResult::Simplified(expr))
+    }
+
+    fn short_circuits(&self) -> bool {
+        true
     }
 
     fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -123,42 +151,3 @@ impl ScalarUDFImpl for NVL2Func {
         self.doc()
     }
 }
-
-fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    let mut len = 1;
-    let mut is_array = false;
-    for arg in args {
-        if let ColumnarValue::Array(array) = arg {
-            len = array.len();
-            is_array = true;
-            break;
-        }
-    }
-    if is_array {
-        let args = args
-            .iter()
-            .map(|arg| match arg {
-                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len),
-                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
-        let to_apply = is_not_null(&tested)?;
-        let value = zip(&to_apply, &if_non_null, &if_null)?;
-        Ok(ColumnarValue::Array(value))
-    } else {
-        let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
-        match &tested {
-            ColumnarValue::Array(_) => {
-                internal_err!("except Scalar value, but got Array")
-            }
-            ColumnarValue::Scalar(scalar) => {
-                if scalar.is_null() {
-                    Ok(if_null.clone())
-                } else {
-                    Ok(if_non_null.clone())
-                }
-            }
-        }
-    }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to