This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a64f34197 Remove Arc wrapping from create_udf's return_type (#12489)
0a64f34197 is described below

commit 0a64f34197c52be592e7020456a0b610f6b7e827
Author: Piotr Findeisen <[email protected]>
AuthorDate: Tue Sep 17 18:37:37 2024 +0200

    Remove Arc wrapping from create_udf's return_type (#12489)
    
    The argument types are already passed to `create_udf` by value (moved),
    so taking `return_type` by value as well makes the API more consistent.
    
    Internally, `create_udf` immediately unwrapped (or cloned) the passed-in
    `Arc<DataType>` to get an owned `DataType`. Accepting a shared pointer
    therefore gave no benefit beyond the shape of the API itself.
---
 datafusion-examples/examples/simple_udf.rs                   |  2 +-
 datafusion/core/src/dataframe/mod.rs                         |  2 +-
 datafusion/core/tests/expr_api/simplification.rs             |  2 +-
 .../core/tests/user_defined/user_defined_scalar_functions.rs | 12 ++++++------
 datafusion/expr/src/expr_fn.rs                               |  3 +--
 datafusion/proto/src/bytes/mod.rs                            |  2 +-
 datafusion/proto/tests/cases/roundtrip_logical_plan.rs       |  2 +-
 datafusion/proto/tests/cases/roundtrip_physical_plan.rs      |  2 +-
 datafusion/proto/tests/cases/serialize.rs                    |  2 +-
 datafusion/sqllogictest/src/test_context.rs                  |  2 +-
 10 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/datafusion-examples/examples/simple_udf.rs 
b/datafusion-examples/examples/simple_udf.rs
index 64cf7857e2..6879a17f34 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -109,7 +109,7 @@ async fn main() -> Result<()> {
         // expects two f64
         vec![DataType::Float64, DataType::Float64],
         // returns f64
-        Arc::new(DataType::Float64),
+        DataType::Float64,
         Volatility::Immutable,
         pow,
     );
diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs
index e7aa1172a8..e229b28490 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2772,7 +2772,7 @@ mod tests {
         ctx.register_udf(create_udf(
             "my_fn",
             vec![DataType::Float64],
-            Arc::new(DataType::Float64),
+            DataType::Float64,
             Volatility::Immutable,
             my_fn,
         ));
diff --git a/datafusion/core/tests/expr_api/simplification.rs 
b/datafusion/core/tests/expr_api/simplification.rs
index b6068e4859..d7995d4663 100644
--- a/datafusion/core/tests/expr_api/simplification.rs
+++ b/datafusion/core/tests/expr_api/simplification.rs
@@ -155,7 +155,7 @@ fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
 // Make a UDF that adds its two values together, with the specified volatility
 fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
     let input_types = vec![DataType::Int32, DataType::Int32];
-    let return_type = Arc::new(DataType::Int32);
+    let return_type = DataType::Int32;
 
     let fun = Arc::new(|args: &[ColumnarValue]| {
         let args = ColumnarValue::values_to_arrays(args)?;
diff --git 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 0f1c3b8e53..013aec48d5 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -120,7 +120,7 @@ async fn scalar_udf() -> Result<()> {
     ctx.register_udf(create_udf(
         "my_add",
         vec![DataType::Int32, DataType::Int32],
-        Arc::new(DataType::Int32),
+        DataType::Int32,
         Volatility::Immutable,
         myfunc,
     ));
@@ -237,7 +237,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> 
Result<()> {
     ctx.register_udf(create_udf(
         "buggy_func",
         vec![DataType::Int32],
-        Arc::new(DataType::Int32),
+        DataType::Int32,
         Volatility::Immutable,
         buggy_udf,
     ));
@@ -321,7 +321,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> 
Result<()> {
     ctx.register_udf(create_udf(
         "abs",
         vec![DataType::Int32],
-        Arc::new(DataType::Int32),
+        DataType::Int32,
         Volatility::Immutable,
         Arc::new(move |_| 
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
     ));
@@ -414,7 +414,7 @@ async fn 
case_sensitive_identifiers_user_defined_functions() -> Result<()> {
     ctx.register_udf(create_udf(
         "MY_FUNC",
         vec![DataType::Int32],
-        Arc::new(DataType::Int32),
+        DataType::Int32,
         Volatility::Immutable,
         myfunc,
     ));
@@ -459,7 +459,7 @@ async fn test_user_defined_functions_with_alias() -> 
Result<()> {
     let udf = create_udf(
         "dummy",
         vec![DataType::Int32],
-        Arc::new(DataType::Int32),
+        DataType::Int32,
         Volatility::Immutable,
         myfunc,
     )
@@ -1149,7 +1149,7 @@ fn create_udf_context() -> SessionContext {
     ctx.register_udf(create_udf(
         "custom_sqrt",
         vec![DataType::Float64],
-        Arc::new(DataType::Float64),
+        DataType::Float64,
         Volatility::Immutable,
         Arc::new(custom_sqrt),
     ));
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8d01712b95..5fd3177bc2 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -390,11 +390,10 @@ pub fn unnest(expr: Expr) -> Expr {
 pub fn create_udf(
     name: &str,
     input_types: Vec<DataType>,
-    return_type: Arc<DataType>,
+    return_type: DataType,
     volatility: Volatility,
     fun: ScalarFunctionImplementation,
 ) -> ScalarUDF {
-    let return_type = Arc::unwrap_or_clone(return_type);
     ScalarUDF::from(SimpleScalarUDF::new(
         name,
         input_types,
diff --git a/datafusion/proto/src/bytes/mod.rs 
b/datafusion/proto/src/bytes/mod.rs
index 9188480431..12ddb4cb2e 100644
--- a/datafusion/proto/src/bytes/mod.rs
+++ b/datafusion/proto/src/bytes/mod.rs
@@ -116,7 +116,7 @@ impl Serializeable for Expr {
                 Ok(Arc::new(create_udf(
                     name,
                     vec![],
-                    Arc::new(arrow::datatypes::DataType::Null),
+                    arrow::datatypes::DataType::Null,
                     Volatility::Immutable,
                     Arc::new(|_| unimplemented!()),
                 )))
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 1ff39e9e65..71c8dbe6ec 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -2172,7 +2172,7 @@ fn roundtrip_scalar_udf() {
     let udf = create_udf(
         "dummy",
         vec![DataType::Utf8],
-        Arc::new(DataType::Utf8),
+        DataType::Utf8,
         Volatility::Immutable,
         scalar_fn,
     );
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 58f6015ee3..f4b32e662e 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -871,7 +871,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
     let udf = create_udf(
         "dummy",
         vec![DataType::Int64],
-        Arc::new(DataType::Int64),
+        DataType::Int64,
         Volatility::Immutable,
         scalar_fn.clone(),
     );
diff --git a/datafusion/proto/tests/cases/serialize.rs 
b/datafusion/proto/tests/cases/serialize.rs
index f28098d83b..d1b50105d0 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -238,7 +238,7 @@ fn context_with_udf() -> SessionContext {
     let udf = create_udf(
         "dummy",
         vec![DataType::Utf8],
-        Arc::new(DataType::Utf8),
+        DataType::Utf8,
         Volatility::Immutable,
         scalar_fn,
     );
diff --git a/datafusion/sqllogictest/src/test_context.rs 
b/datafusion/sqllogictest/src/test_context.rs
index ef2fa863e6..19016d328f 100644
--- a/datafusion/sqllogictest/src/test_context.rs
+++ b/datafusion/sqllogictest/src/test_context.rs
@@ -359,7 +359,7 @@ fn create_example_udf() -> ScalarUDF {
         // Expects two f64 values:
         vec![DataType::Float64, DataType::Float64],
         // Returns an f64 value:
-        Arc::new(DataType::Float64),
+        DataType::Float64,
         Volatility::Immutable,
         adder,
     )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to