martin-g commented on code in PR #21679:
URL: https://github.com/apache/datafusion/pull/21679#discussion_r3098957130


##########
datafusion/sql/src/expr/function.rs:
##########
@@ -363,6 +369,146 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
             }
         }
 
+        if let Some(fm) = self.context_provider.get_higher_order_meta(&name) {
+            // plan non-lambda arguments first so we can get theirs datatype 
and call
+            // HigherOrderUDF::lambda_parameters to then plan the lambda 
arguments with
+            // resolved lambda variables
+            enum ExprOrLambda {
+                Expr(Expr),
+                Lambda(sqlparser::ast::LambdaFunction),
+            }
+
+            let partially_planned = args
+                .into_iter()
+                .map(|a| match a {
+                    FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda(
+                        lambda,
+                    ))) => {
+                        if !all_unique(&lambda.params) {
+                            return plan_err!(
+                                "lambda parameters names must be unique, got 
{}",
+                                lambda.params
+                            );
+                        }
+
+                        Ok(ExprOrLambda::Lambda(lambda))
+                    }
+                    _ => Ok(ExprOrLambda::Expr(self.sql_fn_arg_to_logical_expr(
+                        a,
+                        schema,
+                        planner_context,
+                    )?)),
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            let current_fields = partially_planned
+                .iter()
+                .map(|e| match e {
+                    ExprOrLambda::Expr(expr) => {
+                        Ok(ValueOrLambda::Value(expr.to_field(schema)?.1))
+                    }
+                    ExprOrLambda::Lambda(_lambda_function) => {
+                        Ok(ValueOrLambda::Lambda(()))
+                    }
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            let coerced_values =
+                value_fields_with_higher_order_udf(&current_fields, 
fm.as_ref())?
+                    .into_iter()
+                    .filter_map(|arg| match arg {
+                        ValueOrLambda::Value(value) => Some(value),
+                        ValueOrLambda::Lambda(_lambda) => None,
+                    })
+                    .collect::<Vec<_>>();
+
+            // lambda_parameters refers only to lambdas and not to values, so 
instead
+            // of zipping it with partially_planned, we iterate over 
partially_planned and only
+            // consume from lambda_parameters when a given argument is a lambda
+            // to reconstruct the arguments list with the correct order
+            // this supports any value and lambda positioning including
+            // multiple lambdas interleaved with values
+            let mut lambda_parameters =
+                fm.lambda_parameters(&coerced_values)?.into_iter();
+
+            let num_lambdas = partially_planned.len() - coerced_values.len();
+
+            // functions can support multiple lambdas where some trailing ones 
are optional,
+            // but to simplify the implementor, lambda_parameters returns the 
parameters of all of them,
+            // so we can't do equality check. one example is spark reduce:
+            // https://spark.apache.org/docs/latest/api/sql/index.html#reduce
+            if lambda_parameters.len() < num_lambdas {
+                return plan_err!(
+                    "{} invocation defined {num_lambdas} but lambda_parameters 
returned only {}",
+                    fm.name(),
+                    lambda_parameters.len()
+                );
+            }
+
+            let args = partially_planned
+                .into_iter()
+                .map(|arg| match arg {
+                    ExprOrLambda::Expr(expr) => Ok(expr),
+                    ExprOrLambda::Lambda(lambda) => {
+                        let lambda_params =
+                            lambda_parameters.next().ok_or_else(|| {
+                                internal_datafusion_err!(
+                                    "lambda_parameters len should have been 
checked above"
+                                )
+                            })?;
+
+                        if lambda.params.len() > lambda_params.len() {
+                            return plan_err!(
+                                "lambda defined {} params but UDF support only 
{}",
+                                lambda.params.len(),
+                                lambda_params.len()
+                            );
+                        }
+
+                        let params =
+                            lambda.params.iter().map(|p| 
p.value.clone()).collect();

Review Comment:
   ```suggestion
                               lambda.params.iter().map(|p| 
crate::utils::normalize_ident(p.clone())).collect();
   ```



##########
docs/source/user-guide/sql/scalar_functions.md:
##########
@@ -4375,6 +4377,34 @@ array_to_string(array, delimiter[, null_string])
 - array_join
 - list_join
 
+### `array_transform`
+
+transforms the values of a array

Review Comment:
   ```suggestion
   transforms the values of an array
   ```



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -171,6 +178,20 @@ impl CaseBody {
                                 projected,
                             ))));
                         }
+                    } else if let Some(lambda_variable) =
+                        expr.downcast_ref::<LambdaVariable>()
+                    {
+                        let original = lambda_variable.index();
+                        let projected = 
*column_index_map.get(&original).unwrap();
+                        if projected != original {
+                            return 
Ok(Transformed::yes(Arc::new(LambdaVariable::new(
+                                projected,
+                                Arc::clone(lambda_variable.field()),
+                            ))));
+                        }
+                    } else if expr.is::<LambdaExpr>() {

Review Comment:
   ```suggestion
                       } else if e.is::<LambdaExpr>() {
   ```



##########
datafusion/sql/src/unparser/expr.rs:
##########
@@ -552,6 +555,30 @@ impl Unparser<'_> {
             }
             Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col),
             Expr::Unnest(unnest) => self.unnest_to_sql(unnest),
+            Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+                let func_name = func.name();
+
+                if let Some(expr) = self
+                    .dialect
+                    .higher_order_function_to_sql_overrides(self, func_name, 
args)?
+                {
+                    return Ok(expr);
+                }
+
+                self.function_to_sql_internal(func_name, args)
+            }
+            Expr::Lambda(Lambda { params, body }) => {
+                Ok(ast::Expr::Lambda(ast::LambdaFunction {
+                    params: ast::OneOrManyWithParens::Many(
+                        params.iter().map(|param| 
param.as_str().into()).collect(),

Review Comment:
   ```suggestion
                           params
                               .iter()
                               .map(|param| 
self.new_ident_quoted_if_needs(param.clone()))
                               .collect(),
   ```
   ?



##########
datafusion/physical-expr/src/expressions/lambda.rs:
##########
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical lambda expression: [`LambdaExpr`]
+
+use std::hash::Hash;
+use std::sync::Arc;
+
+use crate::physical_expr::PhysicalExpr;
+use arrow::{
+    datatypes::{DataType, Schema},
+    record_batch::RecordBatch,
+};
+use datafusion_common::plan_err;
+use datafusion_common::{HashSet, Result, internal_err};
+use datafusion_expr::ColumnarValue;
+
+/// Represents a lambda with the given parameters names and body
+#[derive(Debug, Eq, Clone)]
+pub struct LambdaExpr {
+    params: Vec<String>,
+    body: Arc<dyn PhysicalExpr>,
+}
+
+// Manually derive PartialEq and Hash to work around 
https://github.com/rust-lang/rust/issues/78808 
[https://github.com/apache/datafusion/issues/13196]
+impl PartialEq for LambdaExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.params.eq(&other.params) && self.body.eq(&other.body)
+    }
+}
+
+impl Hash for LambdaExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.params.hash(state);
+        self.body.hash(state);
+    }
+}
+
+impl LambdaExpr {
+    /// Create a new lambda expression with the given parameters and body
+    pub fn try_new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> 
Result<Self> {
+        if all_unique(&params) {
+            Ok(Self::new(params, body))
+        } else {
+            plan_err!("lambda params must be unique, got ({})", params.join(", 
"))
+        }
+    }
+
+    fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
+        Self { params, body }
+    }
+
+    /// Get the lambda's params names
+    pub fn params(&self) -> &[String] {
+        &self.params
+    }
+
+    /// Get the lambda's body
+    pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
+        &self.body
+    }
+}
+
+impl std::fmt::Display for LambdaExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "({}) -> {}", self.params.join(", "), self.body)
+    }
+}
+
+impl PhysicalExpr for LambdaExpr {
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Null)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        internal_err!("LambdaExpr::evaluate() should not be called")
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.body]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(Self::new(
+            self.params.clone(),
+            Arc::clone(&children[0]),
+        )))

Review Comment:
   ```suggestion
           let [body] = children.as_slice() else {
               return internal_err!(
                   "LambdaExpr expects exactly 1 child, got {}",
                   children.len()
               );
           };
   
           Ok(Arc::new(Self::new(
               self.params.clone(),
               Arc::clone(body),
           )))
   ```



##########
datafusion/optimizer/src/analyzer/type_coercion.rs:
##########
@@ -763,6 +766,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
                 });
                 Ok(Transformed::yes(new_expr))
             }
+            Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+                let current_fields = args
+                    .iter()
+                    .map(|arg| match arg {
+                        Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())),
+                        _ => 
Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)),
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                let new_fields =
+                    value_fields_with_higher_order_udf(&current_fields, 
func.as_ref())?;
+
+                let new_args = std::iter::zip(args, new_fields)
+                    .map(|(arg, new_field)| match (&arg, new_field) {
+                        (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => 
Ok(arg),
+                        (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => 
plan_err!("value_fields_with_higher_order_udf return a value for a lambda 
argument"),
+                        (_, ValueOrLambda::Value(new_field)) => 
arg.cast_to(new_field.data_type(), self.schema),
+                        (_, ValueOrLambda::Lambda(_)) => 
plan_err!("value_fields_with_higher_order_udf return a lambda for a value 
argument"),

Review Comment:
   ```suggestion
                           (_, ValueOrLambda::Lambda(_)) => 
plan_err!("value_fields_with_higher_order_udf returned a lambda for a value 
argument"),
   ```
   same note about `value_fields_with_higher_order_udf` as above



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -171,6 +178,20 @@ impl CaseBody {
                                 projected,
                             ))));
                         }
+                    } else if let Some(lambda_variable) =
+                        expr.downcast_ref::<LambdaVariable>()

Review Comment:
   ```suggestion
                           e.downcast_ref::<LambdaVariable>()
   ```



##########
datafusion/sql/src/expr/function.rs:
##########
@@ -956,3 +1102,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         }
     }
 }
+
+fn all_unique(params: &[sqlparser::ast::Ident]) -> bool {
+    match params.len() {
+        0 | 1 => true,
+        2 => params[0].value != params[1].value,
+        _ => {
+            let mut set = HashSet::with_capacity(params.len());
+
+            params.iter().all(|p| set.insert(p.value.as_str()))

Review Comment:
   ```suggestion
               params
                   .iter()
                   .map(|p| crate::utils::normalize_ident(p.clone()))
                   .all(|p| set.insert(p))
   ```



##########
datafusion/optimizer/src/analyzer/type_coercion.rs:
##########
@@ -763,6 +766,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
                 });
                 Ok(Transformed::yes(new_expr))
             }
+            Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+                let current_fields = args
+                    .iter()
+                    .map(|arg| match arg {
+                        Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())),
+                        _ => 
Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)),
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                let new_fields =
+                    value_fields_with_higher_order_udf(&current_fields, 
func.as_ref())?;
+
+                let new_args = std::iter::zip(args, new_fields)
+                    .map(|(arg, new_field)| match (&arg, new_field) {
+                        (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => 
Ok(arg),
+                        (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => 
plan_err!("value_fields_with_higher_order_udf return a value for a lambda 
argument"),

Review Comment:
   ```suggestion
                           (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => 
plan_err!("value_fields_with_higher_order_udf returned a value for a lambda 
argument"),
   ```
   `plan_err` suggests that this is a user error (e.g. a wrong SQL query), but 
mentioning `value_fields_with_higher_order_udf` in the error message exposes an 
internal technical detail. 
   Could this be made more user-friendly? For example, by displaying the 
problematic expression.



##########
datafusion/sql/src/expr/function.rs:
##########
@@ -956,3 +1102,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         }
     }
 }
+
+fn all_unique(params: &[sqlparser::ast::Ident]) -> bool {
+    match params.len() {
+        0 | 1 => true,
+        2 => params[0].value != params[1].value,

Review Comment:
   ```suggestion
           2 => {
               crate::utils::normalize_ident(params[0].clone())
                   != crate::utils::normalize_ident(params[1].clone())
           }
   ```



##########
datafusion/session/src/session.rs:
##########
@@ -111,6 +113,9 @@ pub trait Session: Send + Sync {
     /// Return reference to scalar_functions
     fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>>;
 
+    /// Return reference to higher_order_functions
+    fn higher_order_functions(&self) -> &HashMap<String, Arc<dyn 
HigherOrderUDF>>;

Review Comment:
   I wonder whether it would be a good idea to return an empty HashMap by 
default — that would prevent some broken builds for third-party implementations



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to