This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new abb0c1f62b UDAF and UDWF support aliases (#9489)
abb0c1f62b is described below
commit abb0c1f62bf622bd0e40769560cf0804dac2ecbf
Author: 张林伟 <[email protected]>
AuthorDate: Tue Mar 12 02:55:41 2024 +0800
UDAF and UDWF support aliases (#9489)
* UDAF and UDWF support aliases
* Add tests for udaf and udwf aliases
* Fix clippy lint
---
datafusion/core/src/execution/context/mod.rs | 22 ++++++-
.../tests/user_defined/user_defined_aggregates.rs | 37 +++++++++++
.../user_defined/user_defined_window_functions.rs | 39 ++++++++++++
datafusion/execution/src/task.rs | 6 ++
datafusion/expr/src/udaf.rs | 71 ++++++++++++++++++++++
datafusion/expr/src/udwf.rs | 71 +++++++++++++++++++++-
6 files changed, 242 insertions(+), 4 deletions(-)
diff --git a/datafusion/core/src/execution/context/mod.rs
b/datafusion/core/src/execution/context/mod.rs
index dc4e39d37c..d64dd7c896 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2155,10 +2155,16 @@ impl FunctionRegistry for SessionState {
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
+ // Each alias points at the same shared function as the primary name; only
+ // the entry previously stored under the primary name is returned.
+ for alias in udaf.aliases() {
+ self.aggregate_functions.insert(alias.clone(), Arc::clone(&udaf));
+ }
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) ->
Result<Option<Arc<WindowUDF>>> {
+ // Each alias points at the same shared function as the primary name; only
+ // the entry previously stored under the primary name is returned.
+ for alias in udwf.aliases() {
+ self.window_functions.insert(alias.clone(), Arc::clone(&udwf));
+ }
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}
@@ -2173,11 +2179,23 @@ impl FunctionRegistry for SessionState {
}
fn deregister_udaf(&mut self, name: &str) ->
Result<Option<Arc<AggregateUDF>>> {
- Ok(self.aggregate_functions.remove(name))
+ let udaf = self.aggregate_functions.remove(name);
+ if let Some(udaf) = &udaf {
+ // `name` may be the primary name *or* one of the aliases, so the primary
+ // name must be removed too, not only the aliases. Only drop entries that
+ // still point at this very function: a later registration may have
+ // reused one of these names, and that function must stay registered.
+ let other_names = std::iter::once(udaf.name())
+ .chain(udaf.aliases().iter().map(String::as_str));
+ for other in other_names {
+ if matches!(self.aggregate_functions.get(other), Some(f) if Arc::ptr_eq(f, udaf)) {
+ self.aggregate_functions.remove(other);
+ }
+ }
+ }
+ Ok(udaf)
}
fn deregister_udwf(&mut self, name: &str) ->
Result<Option<Arc<WindowUDF>>> {
- Ok(self.window_functions.remove(name))
+ let udwf = self.window_functions.remove(name);
+ if let Some(udwf) = &udwf {
+ // `name` may be the primary name *or* one of the aliases, so the primary
+ // name must be removed too, not only the aliases. Only drop entries that
+ // still point at this very function: a later registration may have
+ // reused one of these names, and that function must stay registered.
+ let other_names = std::iter::once(udwf.name())
+ .chain(udwf.aliases().iter().map(String::as_str));
+ for other in other_names {
+ if matches!(self.window_functions.get(other), Some(f) if Arc::ptr_eq(f, udwf)) {
+ self.window_functions.remove(other);
+ }
+ }
+ }
+ Ok(udwf)
}
}
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 9e231d25f2..3f40c55a3e 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -27,6 +27,7 @@ use std::sync::{
};
use datafusion::datasource::MemTable;
+use datafusion::test_util::plan_and_collect;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
@@ -320,6 +321,42 @@ async fn
case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_user_defined_functions_with_alias() -> Result<()> {
+ let ctx = SessionContext::new();
+ let arr = Int32Array::from(vec![1]);
+ let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+ ctx.register_batch("t", batch)?;
+
+ let my_avg = create_udaf(
+ "dummy",
+ vec![DataType::Float64],
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+ Arc::new(vec![DataType::UInt64, DataType::Float64]),
+ )
+ .with_aliases(vec!["dummy_alias"]);
+
+ ctx.register_udaf(my_avg);
+
+ // The output column is labeled with the primary name `dummy` regardless of
+ // which name the query used, so both queries share one expected table.
+ let expected = [
+ "+------------+",
+ "| dummy(t.i) |",
+ "+------------+",
+ "| 1.0        |",
+ "+------------+",
+ ];
+
+ let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
+ assert_batches_eq!(expected, &result);
+
+ let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?;
+ assert_batches_eq!(expected, &alias_result);
+
+ Ok(())
+}
+
#[tokio::test]
async fn test_groups_accumulator() -> Result<()> {
let ctx = SessionContext::new();
diff --git
a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
index cfd74f8861..3c607301fc 100644
--- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
@@ -41,6 +41,10 @@ const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
odd_counter(val) OVER (PARTITION BY x ORDER BY y) \
from t ORDER BY x, y";
+/// Same query as `UNBOUNDED_WINDOW_QUERY`, but invoking the window function
+/// through its alias `odd_counter_alias` instead of `odd_counter`
+const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \
+ odd_counter_alias(val) OVER (PARTITION BY x ORDER BY y) \
+ from t ORDER BY x, y";
+
/// A query with a window function evaluated over a moving window
const BOUNDED_WINDOW_QUERY: &str =
"SELECT x, y, val, \
@@ -118,6 +122,35 @@ async fn test_deregister_udwf() -> Result<()> {
Ok(())
}
+/// Invokes the user defined window function through its alias
+/// `odd_counter_alias` (registered by `SimpleWindowUDF`) over an unbounded
+/// window. Note the output column is labeled with the primary name
+/// `odd_counter`, not with the alias used in the query.
+#[tokio::test]
+async fn test_udwf_with_alias() {
+ let test_state = TestState::new();
+ let TestContext { ctx, .. } = TestContext::new(test_state);
+
+ let expected = vec![
+
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+ "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y
ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |",
+
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+ "| 1 | a | 0 | 1
|",
+ "| 1 | b | 1 | 1
|",
+ "| 1 | c | 2 | 1
|",
+ "| 2 | d | 3 | 2
|",
+ "| 2 | e | 4 | 2
|",
+ "| 2 | f | 5 | 2
|",
+ "| 2 | g | 6 | 2
|",
+ "| 2 | h | 6 | 2
|",
+ "| 2 | i | 6 | 2
|",
+ "| 2 | j | 6 | 2
|",
+
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+ ];
+ assert_batches_eq!(
+ expected,
+ &execute(&ctx, UNBOUNDED_WINDOW_QUERY_WITH_ALIAS)
+ .await
+ .unwrap()
+ );
+}
+
/// Basic user defined window function with bounded window
#[tokio::test]
async fn test_udwf_bounded_window_ignores_frame() {
@@ -491,6 +524,7 @@ impl OddCounter {
signature: Signature,
return_type: DataType,
test_state: Arc<TestState>,
+ aliases: Vec<String>,
}
impl SimpleWindowUDF {
@@ -502,6 +536,7 @@ impl OddCounter {
signature,
return_type,
test_state,
+ aliases: vec!["odd_counter_alias".to_string()],
}
}
}
@@ -526,6 +561,10 @@ impl OddCounter {
fn partition_evaluator(&self) -> Result<Box<dyn
PartitionEvaluator>> {
Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
}
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
}
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs
index b39b4a0032..cae410655d 100644
--- a/datafusion/execution/src/task.rs
+++ b/datafusion/execution/src/task.rs
@@ -207,9 +207,15 @@ impl FunctionRegistry for TaskContext {
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
+ // Each alias points at the same shared function as the primary name; only
+ // the entry previously stored under the primary name is returned.
+ for alias in udaf.aliases() {
+ self.aggregate_functions.insert(alias.clone(), Arc::clone(&udaf));
+ }
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) ->
Result<Option<Arc<WindowUDF>>> {
+ // Each alias points at the same shared function as the primary name; only
+ // the entry previously stored under the primary name is returned.
+ for alias in udwf.aliases() {
+ self.window_functions.insert(alias.clone(), Arc::clone(&udwf));
+ }
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}
fn register_udf(&mut self, udf: Arc<ScalarUDF>) ->
Result<Option<Arc<ScalarUDF>>> {
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e56723063e..c46dd9cd3a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -118,6 +118,14 @@ impl AggregateUDF {
self.inner.clone()
}
+ /// Returns a new function that can also be invoked by each of `aliases`,
+ /// in addition to its primary name.
+ ///
+ /// The new aliases are appended after any aliases the function already
+ /// reports. If you implement [`AggregateUDFImpl`] yourself, prefer returning
+ /// the aliases from [`AggregateUDFImpl::aliases`] instead of wrapping.
+ pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
+ // `self` is consumed, so move the inner implementation instead of
+ // cloning the `Arc`.
+ Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner, aliases))
+ }
+
/// creates an [`Expr`] that calls the aggregate function.
///
/// This utility allows using the UDAF without requiring access to
@@ -139,6 +147,11 @@ impl AggregateUDF {
self.inner.name()
}
+ /// Returns the aliases (alternate names) this function can be invoked by.
+ ///
+ /// Delegates to [`AggregateUDFImpl::aliases`]; by that method's convention
+ /// the primary name is not included.
+ pub fn aliases(&self) -> &[String] {
+ self.inner.aliases()
+ }
+
/// Returns this function's signature (what input types are accepted)
///
/// See [`AggregateUDFImpl::signature`] for more details.
@@ -277,6 +290,64 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?}
yet")
}
+
+ /// Returns any aliases (alternate names) for this function.
+ ///
+ /// When the function is registered, it is also inserted into the registry
+ /// under each of these names, so queries can call it by any of them.
+ ///
+ /// Note: `aliases` should only include names other than [`Self::name`].
+ /// Defaults to `[]` (no aliases)
+ fn aliases(&self) -> &[String] {
+ &[]
+ }
+}
+
+/// An [`AggregateUDFImpl`] that wraps another implementation and adds aliases
+/// (alternate names) to it, forwarding every other call to the wrapped
+/// implementation.
+///
+/// Created by [`AggregateUDF::with_aliases`]. It is better to implement
+/// [`AggregateUDFImpl::aliases`] directly if possible.
+#[derive(Debug)]
+struct AliasedAggregateUDFImpl {
+ /// The wrapped implementation all calls are forwarded to.
+ inner: Arc<dyn AggregateUDFImpl>,
+ /// `inner`'s own aliases followed by the newly added ones.
+ aliases: Vec<String>,
+}
+
+impl AliasedAggregateUDFImpl {
+ /// Wraps `inner`, appending `new_aliases` after `inner`'s existing aliases.
+ pub fn new(
+ inner: Arc<dyn AggregateUDFImpl>,
+ new_aliases: impl IntoIterator<Item = &'static str>,
+ ) -> Self {
+ let mut aliases = inner.aliases().to_vec();
+ aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
+
+ Self { inner, aliases }
+ }
+}
+
+impl AggregateUDFImpl for AliasedAggregateUDFImpl {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ self.inner.name()
+ }
+
+ fn signature(&self) -> &Signature {
+ self.inner.signature()
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ self.inner.return_type(arg_types)
+ }
+
+ fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
+ self.inner.accumulator(arg)
+ }
+
+ fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
+ self.inner.state_type(return_type)
+ }
+
+ // The GroupsAccumulator hooks must be forwarded too. Otherwise the trait
+ // defaults apply (`create_groups_accumulator` returns `not_impl_err!`), and
+ // aliasing a UDAF would silently disable its grouped-aggregation fast path.
+ fn groups_accumulator_supported(&self) -> bool {
+ self.inner.groups_accumulator_supported()
+ }
+
+ fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ self.inner.create_groups_accumulator()
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
}
/// Implementation of [`AggregateUDFImpl`] that wraps the function style
pointers
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 3ab40fe70a..d3925f2e19 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -80,7 +80,7 @@ impl WindowUDF {
///
/// See [`WindowUDFImpl`] for a more convenient way to create a
/// `WindowUDF` using trait objects
- #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl
instead")]
+ #[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl
instead")]
pub fn new(
name: &str,
signature: &Signature,
@@ -112,6 +112,14 @@ impl WindowUDF {
self.inner.clone()
}
+ /// Returns a new function that can also be invoked by each of `aliases`,
+ /// in addition to its primary name.
+ ///
+ /// The new aliases are appended after any aliases the function already
+ /// reports. If you implement [`WindowUDFImpl`] yourself, prefer returning
+ /// the aliases from [`WindowUDFImpl::aliases`] instead of wrapping.
+ pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
+ // `self` is consumed, so move the inner implementation instead of
+ // cloning the `Arc`.
+ Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner, aliases))
+ }
+
/// creates a [`Expr`] that calls the window function given
/// the `partition_by`, `order_by`, and `window_frame` definition
///
@@ -143,6 +151,11 @@ impl WindowUDF {
self.inner.name()
}
+ /// Returns the aliases (alternate names) this function can be invoked by.
+ ///
+ /// Delegates to [`WindowUDFImpl::aliases`]; by that method's convention
+ /// the primary name is not included.
+ pub fn aliases(&self) -> &[String] {
+ self.inner.aliases()
+ }
+
/// Returns this function's signature (what input types are accepted)
///
/// See [`WindowUDFImpl::signature`] for more details.
@@ -217,7 +230,7 @@ where
/// fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!() }
/// }
///
-/// // Create a new ScalarUDF from the implementation
+/// // Create a new WindowUDF from the implementation
/// let smooth_it = WindowUDF::from(SmoothIt::new());
///
/// // Call the function `add_one(col)`
@@ -245,6 +258,60 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
+
+ /// Returns any aliases (alternate names) for this function.
+ ///
+ /// When the function is registered, it is also inserted into the registry
+ /// under each of these names, so queries can call it by any of them.
+ ///
+ /// Note: `aliases` should only include names other than [`Self::name`].
+ /// Defaults to `[]` (no aliases)
+ fn aliases(&self) -> &[String] {
+ &[]
+ }
+}
+
+/// A [`WindowUDFImpl`] wrapper that attaches extra aliases (alternate names)
+/// to another implementation and forwards every other call to it unchanged.
+///
+/// Prefer implementing [`WindowUDFImpl::aliases`] directly where possible.
+#[derive(Debug)]
+struct AliasedWindowUDFImpl {
+ inner: Arc<dyn WindowUDFImpl>,
+ aliases: Vec<String>,
+}
+
+impl AliasedWindowUDFImpl {
+ /// Builds the wrapper. The resulting alias list is `inner`'s aliases
+ /// followed by `new_aliases`, in iteration order.
+ pub fn new(
+ inner: Arc<dyn WindowUDFImpl>,
+ new_aliases: impl IntoIterator<Item = &'static str>,
+ ) -> Self {
+ let aliases = inner
+ .aliases()
+ .iter()
+ .cloned()
+ .chain(new_aliases.into_iter().map(String::from))
+ .collect();
+ Self { inner, aliases }
+ }
+}
+
+impl WindowUDFImpl for AliasedWindowUDFImpl {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ self.inner.name()
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+
+ fn signature(&self) -> &Signature {
+ self.inner.signature()
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ self.inner.return_type(arg_types)
+ }
+
+ fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+ self.inner.partition_evaluator()
+ }
}
/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers