This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 99daafd7b7 feat(spark): implement Spark conditional function if (#16946)
99daafd7b7 is described below
commit 99daafd7b738831f5d6f95007536e0712a90ba5c
Author: Chen Chongchen <[email protected]>
AuthorDate: Sat Aug 30 20:44:20 2025 +0800
feat(spark): implement Spark conditional function if (#16946)
---
datafusion/spark/src/function/conditional/if.rs | 101 ++++++++++++++
datafusion/spark/src/function/conditional/mod.rs | 13 +-
.../test_files/spark/conditional/if.slt | 147 ++++++++++++++++++++-
3 files changed, 255 insertions(+), 6 deletions(-)
diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs
new file mode 100644
index 0000000000..aee43dd8d0
--- /dev/null
+++ b/datafusion/spark/src/function/conditional/if.rs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+use datafusion_common::{internal_err, plan_err, Result};
+use datafusion_expr::{
+ binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue,
+ Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkIf {
+ signature: Signature,
+}
+
+impl Default for SparkIf {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl SparkIf {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::user_defined(Volatility::Immutable),
+ }
+ }
+}
+
+impl ScalarUDFImpl for SparkIf {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "if"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ if arg_types.len() != 3 {
+ return plan_err!(
+ "Function 'if' expects 3 arguments but received {}",
+ arg_types.len()
+ );
+ }
+
+ if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null {
+ return plan_err!(
+ "For function 'if' {} is not a boolean or null",
+ arg_types[0]
+ );
+ }
+
+ let target_types = try_type_union_resolution(&arg_types[1..])?;
+ let mut result = vec![DataType::Boolean];
+ result.extend(target_types);
+ Ok(result)
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[1].clone())
+ }
+
+ fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+ internal_err!("if should have been simplified to case")
+ }
+
+ fn simplify(
+ &self,
+ args: Vec<Expr>,
+ _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+ ) -> Result<ExprSimplifyResult> {
+ let condition = args[0].clone();
+ let then_expr = args[1].clone();
+ let else_expr = args[2].clone();
+
+ // Convert IF(condition, then_expr, else_expr) to
+ // CASE WHEN condition THEN then_expr ELSE else_expr END
+ let case_expr = when(condition, then_expr).otherwise(else_expr)?;
+
+ Ok(ExprSimplifyResult::Simplified(case_expr))
+ }
+}
diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs
index a87df9a2c8..4301d7642b 100644
--- a/datafusion/spark/src/function/conditional/mod.rs
+++ b/datafusion/spark/src/function/conditional/mod.rs
@@ -16,10 +16,19 @@
// under the License.
use datafusion_expr::ScalarUDF;
+use datafusion_functions::make_udf_function;
use std::sync::Arc;
-pub mod expr_fn {}
+mod r#if;
+
+make_udf_function!(r#if::SparkIf, r#if);
+
+pub mod expr_fn {
+ use datafusion_functions::export_functions;
+
+ export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3));
+}
pub fn functions() -> Vec<Arc<ScalarUDF>> {
- vec![]
+ vec![r#if()]
}
diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt
index 7baedad745..b4380e065b 100644
--- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt
+++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt
@@ -21,7 +21,146 @@
# For more information, please see:
# https://github.com/apache/datafusion/issues/15914
-## Original Query: SELECT if(1 < 2, 'a', 'b');
-## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'}
-#query
-#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string);
+## Basic IF function tests
+
+# Test basic true condition
+query T
+SELECT if(true, 'yes', 'no');
+----
+yes
+
+# Test basic false condition
+query T
+SELECT if(false, 'yes', 'no');
+----
+no
+
+# Test with comparison operators
+query T
+SELECT if(1 < 2, 'a', 'b');
+----
+a
+
+query T
+SELECT if(1 > 2, 'a', 'b');
+----
+b
+
+
+## Numeric type tests
+
+# Test with integers
+query I
+SELECT if(true, 10, 20);
+----
+10
+
+query I
+SELECT if(false, 10, 20);
+----
+20
+
+# Test with different integer types
+query I
+SELECT if(true, 100, 200);
+----
+100
+
+## Float type tests
+
+# Test with floating point numbers
+query R
+SELECT if(true, 1.5, 2.5);
+----
+1.5
+
+query R
+SELECT if(false, 1.5, 2.5);
+----
+2.5
+
+## String type tests
+
+# Test with different string values
+query T
+SELECT if(true, 'hello', 'world');
+----
+hello
+
+query T
+SELECT if(false, 'hello', 'world');
+----
+world
+
+## NULL handling tests
+
+# Test with NULL condition
+query T
+SELECT if(NULL, 'yes', 'no');
+----
+no
+
+query T
+SELECT if(NOT NULL, 'yes', 'no');
+----
+no
+
+# Test with NULL true value
+query T
+SELECT if(true, NULL, 'no');
+----
+NULL
+
+# Test with NULL false value
+query T
+SELECT if(false, 'yes', NULL);
+----
+NULL
+
+# Test with all NULL
+query ?
+SELECT if(true, NULL, NULL);
+----
+NULL
+
+## Type coercion tests
+
+# Test integer to float coercion
+query R
+SELECT if(true, 10, 20.5);
+----
+10
+
+query R
+SELECT if(false, 10, 20.5);
+----
+20.5
+
+# Test float to integer coercion
+query R
+SELECT if(true, 10.5, 20);
+----
+10.5
+
+query R
+SELECT if(false, 10.5, 20);
+----
+20
+
+statement error Int64 is not a boolean or null
+SELECT if(1, 10.5, 20);
+
+
+statement error Utf8 is not a boolean or null
+SELECT if('x', 10.5, 20);
+
+query II
+SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v)
+----
+1 1
+2 1
+
+query I
+SELECT IF(true, 1 / 1, 1 / 0);
+----
+1
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]