martin-g commented on code in PR #21679:
URL: https://github.com/apache/datafusion/pull/21679#discussion_r3098957130
##########
datafusion/sql/src/expr/function.rs:
##########
@@ -363,6 +369,146 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
+ if let Some(fm) = self.context_provider.get_higher_order_meta(&name) {
+ // plan non-lambda arguments first so we can get theirs datatype
and call
+ // HigherOrderUDF::lambda_parameters to then plan the lambda
arguments with
+ // resolved lambda variables
+ enum ExprOrLambda {
+ Expr(Expr),
+ Lambda(sqlparser::ast::LambdaFunction),
+ }
+
+ let partially_planned = args
+ .into_iter()
+ .map(|a| match a {
+ FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda(
+ lambda,
+ ))) => {
+ if !all_unique(&lambda.params) {
+ return plan_err!(
+ "lambda parameters names must be unique, got
{}",
+ lambda.params
+ );
+ }
+
+ Ok(ExprOrLambda::Lambda(lambda))
+ }
+ _ => Ok(ExprOrLambda::Expr(self.sql_fn_arg_to_logical_expr(
+ a,
+ schema,
+ planner_context,
+ )?)),
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let current_fields = partially_planned
+ .iter()
+ .map(|e| match e {
+ ExprOrLambda::Expr(expr) => {
+ Ok(ValueOrLambda::Value(expr.to_field(schema)?.1))
+ }
+ ExprOrLambda::Lambda(_lambda_function) => {
+ Ok(ValueOrLambda::Lambda(()))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let coerced_values =
+ value_fields_with_higher_order_udf(&current_fields,
fm.as_ref())?
+ .into_iter()
+ .filter_map(|arg| match arg {
+ ValueOrLambda::Value(value) => Some(value),
+ ValueOrLambda::Lambda(_lambda) => None,
+ })
+ .collect::<Vec<_>>();
+
+ // lambda_parameters refers only to lambdas and not to values, so
instead
+ // of zipping it with partially_planned, we iterate over
partially_planned and only
+ // consume from lambda_parameters when a given argument is a lambda
+ // to reconstruct the arguments list with the correct order
+ // this supports any value and lambda positioning including
+ // multiple lambdas interleaved with values
+ let mut lambda_parameters =
+ fm.lambda_parameters(&coerced_values)?.into_iter();
+
+ let num_lambdas = partially_planned.len() - coerced_values.len();
+
+ // functions can support multiple lambdas where some trailing ones
are optional,
+ // but to simplify the implementor, lambda_parameters returns the
parameters of all of them,
+ // so we can't do equality check. one example is spark reduce:
+ // https://spark.apache.org/docs/latest/api/sql/index.html#reduce
+ if lambda_parameters.len() < num_lambdas {
+ return plan_err!(
+ "{} invocation defined {num_lambdas} but lambda_parameters
returned only {}",
+ fm.name(),
+ lambda_parameters.len()
+ );
+ }
+
+ let args = partially_planned
+ .into_iter()
+ .map(|arg| match arg {
+ ExprOrLambda::Expr(expr) => Ok(expr),
+ ExprOrLambda::Lambda(lambda) => {
+ let lambda_params =
+ lambda_parameters.next().ok_or_else(|| {
+ internal_datafusion_err!(
+ "lambda_parameters len should have been
checked above"
+ )
+ })?;
+
+ if lambda.params.len() > lambda_params.len() {
+ return plan_err!(
+ "lambda defined {} params but UDF support only
{}",
+ lambda.params.len(),
+ lambda_params.len()
+ );
+ }
+
+ let params =
+ lambda.params.iter().map(|p|
p.value.clone()).collect();
Review Comment:
```suggestion
lambda.params.iter().map(|p|
crate::utils::normalize_ident(p.clone())).collect();
```
##########
docs/source/user-guide/sql/scalar_functions.md:
##########
@@ -4375,6 +4377,34 @@ array_to_string(array, delimiter[, null_string])
- array_join
- list_join
+### `array_transform`
+
+transforms the values of a array
Review Comment:
```suggestion
transforms the values of an array
```
##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -171,6 +178,20 @@ impl CaseBody {
projected,
))));
}
+ } else if let Some(lambda_variable) =
+ expr.downcast_ref::<LambdaVariable>()
+ {
+ let original = lambda_variable.index();
+ let projected =
*column_index_map.get(&original).unwrap();
+ if projected != original {
+ return
Ok(Transformed::yes(Arc::new(LambdaVariable::new(
+ projected,
+ Arc::clone(lambda_variable.field()),
+ ))));
+ }
+ } else if expr.is::<LambdaExpr>() {
Review Comment:
```suggestion
} else if e.is::<LambdaExpr>() {
```
##########
datafusion/sql/src/unparser/expr.rs:
##########
@@ -552,6 +555,30 @@ impl Unparser<'_> {
}
Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col),
Expr::Unnest(unnest) => self.unnest_to_sql(unnest),
+ Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+ let func_name = func.name();
+
+ if let Some(expr) = self
+ .dialect
+ .higher_order_function_to_sql_overrides(self, func_name,
args)?
+ {
+ return Ok(expr);
+ }
+
+ self.function_to_sql_internal(func_name, args)
+ }
+ Expr::Lambda(Lambda { params, body }) => {
+ Ok(ast::Expr::Lambda(ast::LambdaFunction {
+ params: ast::OneOrManyWithParens::Many(
+ params.iter().map(|param|
param.as_str().into()).collect(),
Review Comment:
```suggestion
params
.iter()
.map(|param|
self.new_ident_quoted_if_needs(param.clone()))
.collect(),
```
?
##########
datafusion/physical-expr/src/expressions/lambda.rs:
##########
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical lambda expression: [`LambdaExpr`]
+
+use std::hash::Hash;
+use std::sync::Arc;
+
+use crate::physical_expr::PhysicalExpr;
+use arrow::{
+ datatypes::{DataType, Schema},
+ record_batch::RecordBatch,
+};
+use datafusion_common::plan_err;
+use datafusion_common::{HashSet, Result, internal_err};
+use datafusion_expr::ColumnarValue;
+
+/// Represents a lambda with the given parameters names and body
+#[derive(Debug, Eq, Clone)]
+pub struct LambdaExpr {
+ params: Vec<String>,
+ body: Arc<dyn PhysicalExpr>,
+}
+
+// Manually derive PartialEq and Hash to work around
https://github.com/rust-lang/rust/issues/78808
[https://github.com/apache/datafusion/issues/13196]
+impl PartialEq for LambdaExpr {
+ fn eq(&self, other: &Self) -> bool {
+ self.params.eq(&other.params) && self.body.eq(&other.body)
+ }
+}
+
+impl Hash for LambdaExpr {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.params.hash(state);
+ self.body.hash(state);
+ }
+}
+
+impl LambdaExpr {
+ /// Create a new lambda expression with the given parameters and body
+ pub fn try_new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) ->
Result<Self> {
+ if all_unique(&params) {
+ Ok(Self::new(params, body))
+ } else {
+ plan_err!("lambda params must be unique, got ({})", params.join(",
"))
+ }
+ }
+
+ fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
+ Self { params, body }
+ }
+
+ /// Get the lambda's params names
+ pub fn params(&self) -> &[String] {
+ &self.params
+ }
+
+ /// Get the lambda's body
+ pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
+ &self.body
+ }
+}
+
+impl std::fmt::Display for LambdaExpr {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "({}) -> {}", self.params.join(", "), self.body)
+ }
+}
+
+impl PhysicalExpr for LambdaExpr {
+ fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+ Ok(DataType::Null)
+ }
+
+ fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+ Ok(true)
+ }
+
+ fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+ internal_err!("LambdaExpr::evaluate() should not be called")
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.body]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ Ok(Arc::new(Self::new(
+ self.params.clone(),
+ Arc::clone(&children[0]),
+ )))
Review Comment:
```suggestion
let [body] = children.as_slice() else {
return internal_err!(
"LambdaExpr expects exactly 1 child, got {}",
children.len()
);
};
Ok(Arc::new(Self::new(
self.params.clone(),
Arc::clone(body),
)))
```
##########
datafusion/optimizer/src/analyzer/type_coercion.rs:
##########
@@ -763,6 +766,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
});
Ok(Transformed::yes(new_expr))
}
+ Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+ let current_fields = args
+ .iter()
+ .map(|arg| match arg {
+ Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())),
+ _ =>
Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)),
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let new_fields =
value_fields_with_higher_order_udf(&current_fields,
func.as_ref())?;
+
+ let new_args = std::iter::zip(args, new_fields)
+ .map(|(arg, new_field)| match (&arg, new_field) {
+ (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) =>
Ok(arg),
+ (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) =>
plan_err!("value_fields_with_higher_order_udf return a value for a lambda
argument"),
+ (_, ValueOrLambda::Value(new_field)) =>
arg.cast_to(new_field.data_type(), self.schema),
+ (_, ValueOrLambda::Lambda(_)) =>
plan_err!("value_fields_with_higher_order_udf return a lambda for a value
argument"),
Review Comment:
```suggestion
(_, ValueOrLambda::Lambda(_)) =>
plan_err!("value_fields_with_higher_order_udf returned a lambda for a value
argument"),
```
same note about `value_fields_with_higher_order_udf` as above
##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -171,6 +178,20 @@ impl CaseBody {
projected,
))));
}
+ } else if let Some(lambda_variable) =
+ expr.downcast_ref::<LambdaVariable>()
Review Comment:
```suggestion
e.downcast_ref::<LambdaVariable>()
```
##########
datafusion/sql/src/expr/function.rs:
##########
@@ -956,3 +1102,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
}
+
+fn all_unique(params: &[sqlparser::ast::Ident]) -> bool {
+ match params.len() {
+ 0 | 1 => true,
+ 2 => params[0].value != params[1].value,
+ _ => {
+ let mut set = HashSet::with_capacity(params.len());
+
+ params.iter().all(|p| set.insert(p.value.as_str()))
Review Comment:
```suggestion
params
.iter()
.map(|p| crate::utils::normalize_ident(p.clone()))
.all(|p| set.insert(p))
```
##########
datafusion/optimizer/src/analyzer/type_coercion.rs:
##########
@@ -763,6 +766,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
});
Ok(Transformed::yes(new_expr))
}
+ Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
+ let current_fields = args
+ .iter()
+ .map(|arg| match arg {
+ Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())),
+ _ =>
Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)),
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let new_fields =
value_fields_with_higher_order_udf(&current_fields,
func.as_ref())?;
+
+ let new_args = std::iter::zip(args, new_fields)
+ .map(|(arg, new_field)| match (&arg, new_field) {
+ (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) =>
Ok(arg),
+ (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) =>
plan_err!("value_fields_with_higher_order_udf return a value for a lambda
argument"),
Review Comment:
```suggestion
(Expr::Lambda(_lambda), ValueOrLambda::Value(_)) =>
plan_err!("value_fields_with_higher_order_udf returned a value for a lambda
argument"),
```
`plan_err` suggests that this is a user error (e.g. wrong SQL query) but
using `value_fields_with_higher_order_udf` in the error message is an internal
technical detail.
Could this be made more user-friendly ? For example by displaying the
problematic expression.
##########
datafusion/sql/src/expr/function.rs:
##########
@@ -956,3 +1102,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
}
+
+fn all_unique(params: &[sqlparser::ast::Ident]) -> bool {
+ match params.len() {
+ 0 | 1 => true,
+ 2 => params[0].value != params[1].value,
Review Comment:
```suggestion
2 => {
crate::utils::normalize_ident(params[0].clone())
!= crate::utils::normalize_ident(params[1].clone())
}
```
##########
datafusion/session/src/session.rs:
##########
@@ -111,6 +113,9 @@ pub trait Session: Send + Sync {
/// Return reference to scalar_functions
fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>>;
+ /// Return reference to higher_order_functions
+ fn higher_order_functions(&self) -> &HashMap<String, Arc<dyn
HigherOrderUDF>>;
Review Comment:
I wonder whether it would be a good idea to return an empty HashMap by
default; that would prevent some broken builds for third-party implementations
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]