This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f1f0965331 feat: function name hints for UDFs (#9407)
f1f0965331 is described below
commit f1f09653319aea3186c2b1f9ca103ef7030c2da1
Author: SteveLauC <[email protected]>
AuthorDate: Sun Mar 10 19:31:56 2024 +0800
feat: function name hints for UDFs (#9407)
* feat: function name hints for UDFs
* refactor: rebase fn to xxx_names()
* style: fix clippy
* style: fix clippy
* Add test
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-cli/Cargo.lock | 1 +
datafusion-examples/examples/rewrite_expr.rs | 12 +++++
datafusion/core/src/execution/context/mod.rs | 12 +++++
datafusion/expr/src/function.rs | 37 ++------------
.../optimizer/tests/optimizer_integration.rs | 12 +++++
datafusion/sql/Cargo.toml | 1 +
datafusion/sql/examples/sql.rs | 12 +++++
datafusion/sql/src/expr/function.rs | 58 ++++++++++++++++++++--
datafusion/sql/src/expr/mod.rs | 12 +++++
datafusion/sql/src/planner.rs | 4 ++
datafusion/sql/tests/sql_integration.rs | 12 +++++
datafusion/sqllogictest/test_files/functions.slt | 2 +-
12 files changed, 135 insertions(+), 40 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 5e3c8648fc..b4af789682 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1363,6 +1363,7 @@ dependencies = [
"datafusion-expr",
"log",
"sqlparser",
+ "strum 0.26.1",
]
[[package]]
diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs
index cc1396f770..541448ebf1 100644
--- a/datafusion-examples/examples/rewrite_expr.rs
+++ b/datafusion-examples/examples/rewrite_expr.rs
@@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
}
struct MyTableSource {
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 7b37e4914c..49d1b12e66 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2098,6 +2098,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
fn options(&self) -> &ConfigOptions {
self.state.config_options()
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ self.state.scalar_functions().keys().cloned().collect()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ self.state.aggregate_functions().keys().cloned().collect()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ self.state.window_functions().keys().cloned().collect()
+ }
}
impl FunctionRegistry for SessionState {
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 3e30a5574b..a3760eeb35 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -17,13 +17,12 @@
//! Function module contains typing and signature for built-in and user defined functions.
-use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
-use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
+use crate::{
+ Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator,
Signature,
+};
use arrow::datatypes::DataType;
-use datafusion_common::utils::datafusion_strsim;
use datafusion_common::Result;
use std::sync::Arc;
-use strum::IntoEnumIterator;
/// Scalar function
///
@@ -75,33 +74,3 @@ pub fn return_type(
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
fun.signature()
}
-
-/// Suggest a valid function based on an invalid input function name
-pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool)
-> String {
- let valid_funcs = if is_window_func {
- // All aggregate functions and builtin window functions
- AggregateFunction::iter()
- .map(|func| func.to_string())
- .chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
- .collect()
- } else {
- // All scalar functions and aggregate functions
- BuiltinScalarFunction::iter()
- .map(|func| func.to_string())
- .chain(AggregateFunction::iter().map(|func| func.to_string()))
- .collect()
- };
- find_closest_match(valid_funcs, input_function_name)
-}
-
-/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
-/// Input `candidates` must not be empty otherwise it will panic
-fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
- let target = target.to_lowercase();
- candidates
- .into_iter()
- .min_by_key(|candidate| {
- datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
- })
- .expect("No candidates provided.") // Panic if `candidates` argument is empty
-}
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index db7bfa8b3b..b02623854b 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
}
struct MyTableSource {
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index fb300e2c87..7739058a5c 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
log = { workspace = true }
sqlparser = { workspace = true }
+strum = { version = "0.26.1", features = ["derive"] }
[dev-dependencies]
ctor = { workspace = true }
diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs
index 8744a90548..5bab2f19cf 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
}
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index bcf641e4b5..ffc951a6fa 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -20,20 +20,67 @@ use arrow_schema::DataType;
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
};
-use datafusion_expr::expr::{ScalarFunction, Unnest};
-use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::window_frame::{check_window_frame,
regularize_window_order_by};
use datafusion_expr::{
- expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable,
WindowFrame,
- WindowFunctionDefinition,
+ expr, AggregateFunction, Expr, ExprSchemable, WindowFrame,
WindowFunctionDefinition,
+};
+use datafusion_expr::{
+ expr::{ScalarFunction, Unnest},
+ BuiltInWindowFunction, BuiltinScalarFunction,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr,
WindowType,
};
use std::str::FromStr;
+use strum::IntoEnumIterator;
use super::arrow_cast::ARROW_CAST_NAME;
+/// Suggest a valid function based on an invalid input function name
+pub fn suggest_valid_function(
+ input_function_name: &str,
+ is_window_func: bool,
+ ctx: &dyn ContextProvider,
+) -> String {
+ let valid_funcs = if is_window_func {
+ // All aggregate functions and builtin window functions
+ let mut funcs = Vec::new();
+
+ funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
+ funcs.extend(ctx.udafs_names());
+ funcs.extend(BuiltInWindowFunction::iter().map(|func|
func.to_string()));
+ funcs.extend(ctx.udwfs_names());
+
+ funcs
+ } else {
+ // All scalar functions and aggregate functions
+ let mut funcs = Vec::new();
+
+ funcs.extend(BuiltinScalarFunction::iter().map(|func|
func.to_string()));
+ funcs.extend(ctx.udfs_names());
+ funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
+ funcs.extend(ctx.udafs_names());
+
+ funcs
+ };
+ find_closest_match(valid_funcs, input_function_name)
+}
+
+/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
+/// Input `candidates` must not be empty otherwise it will panic
+fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
+ let target = target.to_lowercase();
+ candidates
+ .into_iter()
+ .min_by_key(|candidate| {
+ datafusion_common::utils::datafusion_strsim::levenshtein(
+ &candidate.to_lowercase(),
+ &target,
+ )
+ })
+ .expect("No candidates provided.") // Panic if `candidates` argument is empty
+}
+
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(super) fn sql_function_to_expr(
&self,
@@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
// Could not find the relevant function, so return an error
- let suggested_func_name = suggest_valid_function(&name,
is_function_window);
+ let suggested_func_name =
+ suggest_valid_function(&name, is_function_window,
self.context_provider);
plan_err!("Invalid function '{name}'.\nDid you mean
'{suggested_func_name}'?")
}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index d6aa006ec3..e838a4cafb 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -983,6 +983,18 @@ mod tests {
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
}
fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 2db2c01c5e..f94c6ec4e8 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -85,6 +85,10 @@ pub trait ContextProvider {
/// Get configuration options
fn options(&self) -> &ConfigOptions;
+
+ fn udfs_names(&self) -> Vec<String>;
+ fn udafs_names(&self) -> Vec<String>;
+ fn udwfs_names(&self) -> Vec<String>;
}
/// SQL parser options
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 655eb63cc3..6681c3d025 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -2901,6 +2901,18 @@ impl ContextProvider for MockContextProvider {
) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(EmptyTable::new(schema)))
}
+
+ fn udfs_names(&self) -> Vec<String> {
+ self.udfs.keys().cloned().collect()
+ }
+
+ fn udafs_names(&self) -> Vec<String> {
+ self.udafs.keys().cloned().collect()
+ }
+
+ fn udwfs_names(&self) -> Vec<String> {
+ Vec::new()
+ }
}
#[test]
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 96aa3e2752..21433ba168 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'?
SELECT arrowtypeof(v1) from test;
# Scalar function
-statement error Invalid function 'to_timestamps_second'
+statement error Did you mean 'to_timestamp_seconds'?
SELECT to_TIMESTAMPS_second(v2) from test;
# Aggregate function