This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new abb0c1f62b UDAF and UDWF support aliases (#9489)
abb0c1f62b is described below

commit abb0c1f62bf622bd0e40769560cf0804dac2ecbf
Author: 张林伟 <[email protected]>
AuthorDate: Tue Mar 12 02:55:41 2024 +0800

    UDAF and UDWF support aliases (#9489)
    
    * UDAF and UDWF support aliases
    
    * Add tests for udaf and udwf aliases
    
    * Fix clippy lint
---
 datafusion/core/src/execution/context/mod.rs       | 22 ++++++-
 .../tests/user_defined/user_defined_aggregates.rs  | 37 +++++++++++
 .../user_defined/user_defined_window_functions.rs  | 39 ++++++++++++
 datafusion/execution/src/task.rs                   |  6 ++
 datafusion/expr/src/udaf.rs                        | 71 ++++++++++++++++++++++
 datafusion/expr/src/udwf.rs                        | 71 +++++++++++++++++++++-
 6 files changed, 242 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index dc4e39d37c..d64dd7c896 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2155,10 +2155,16 @@ impl FunctionRegistry for SessionState {
         &mut self,
         udaf: Arc<AggregateUDF>,
     ) -> Result<Option<Arc<AggregateUDF>>> {
+        udaf.aliases().iter().for_each(|alias| {
+            self.aggregate_functions.insert(alias.clone(), udaf.clone());
+        });
         Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
     }
 
     fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> 
Result<Option<Arc<WindowUDF>>> {
+        udwf.aliases().iter().for_each(|alias| {
+            self.window_functions.insert(alias.clone(), udwf.clone());
+        });
         Ok(self.window_functions.insert(udwf.name().into(), udwf))
     }
 
@@ -2173,11 +2179,23 @@ impl FunctionRegistry for SessionState {
     }
 
     fn deregister_udaf(&mut self, name: &str) -> 
Result<Option<Arc<AggregateUDF>>> {
-        Ok(self.aggregate_functions.remove(name))
+        let udaf = self.aggregate_functions.remove(name);
+        if let Some(udaf) = &udaf {
+            for alias in udaf.aliases() {
+                self.aggregate_functions.remove(alias);
+            }
+        }
+        Ok(udaf)
     }
 
     fn deregister_udwf(&mut self, name: &str) -> 
Result<Option<Arc<WindowUDF>>> {
-        Ok(self.window_functions.remove(name))
+        let udwf = self.window_functions.remove(name);
+        if let Some(udwf) = &udwf {
+            for alias in udwf.aliases() {
+                self.window_functions.remove(alias);
+            }
+        }
+        Ok(udwf)
     }
 }
 
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 9e231d25f2..3f40c55a3e 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -27,6 +27,7 @@ use std::sync::{
 };
 
 use datafusion::datasource::MemTable;
+use datafusion::test_util::plan_and_collect;
 use datafusion::{
     arrow::{
         array::{ArrayRef, Float64Array, TimestampNanosecondArray},
@@ -320,6 +321,42 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_user_defined_functions_with_alias() -> Result<()> {
+    let ctx = SessionContext::new();
+    let arr = Int32Array::from(vec![1]);
+    let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+    ctx.register_batch("t", batch).unwrap();
+
+    let my_avg = create_udaf(
+        "dummy",
+        vec![DataType::Float64],
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+        Arc::new(vec![DataType::UInt64, DataType::Float64]),
+    )
+    .with_aliases(vec!["dummy_alias"]);
+
+    ctx.register_udaf(my_avg);
+
+    let expected = [
+        "+------------+",
+        "| dummy(t.i) |",
+        "+------------+",
+        "| 1.0        |",
+        "+------------+",
+    ];
+
+    let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
+    assert_batches_eq!(expected, &result);
+
+    let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM 
t").await?;
+    assert_batches_eq!(expected, &alias_result);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_groups_accumulator() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
index cfd74f8861..3c607301fc 100644
--- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
@@ -41,6 +41,10 @@ const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
      odd_counter(val) OVER (PARTITION BY x ORDER BY y) \
      from t ORDER BY x, y";
 
+const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \
+     odd_counter_alias(val) OVER (PARTITION BY x ORDER BY y) \
+     from t ORDER BY x, y";
+
 /// A query with a window function evaluated over a moving window
 const BOUNDED_WINDOW_QUERY:  &str  =
     "SELECT x, y, val, \
@@ -118,6 +122,35 @@ async fn test_deregister_udwf() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_udwf_with_alias() {
+    let test_state = TestState::new();
+    let TestContext { ctx, .. } = TestContext::new(test_state);
+
+    let expected = vec![
+        
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+        "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y 
ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |",
+        
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+        "| 1 | a | 0   | 1                                                     
                                                                |",
+        "| 1 | b | 1   | 1                                                     
                                                                |",
+        "| 1 | c | 2   | 1                                                     
                                                                |",
+        "| 2 | d | 3   | 2                                                     
                                                                |",
+        "| 2 | e | 4   | 2                                                     
                                                                |",
+        "| 2 | f | 5   | 2                                                     
                                                                |",
+        "| 2 | g | 6   | 2                                                     
                                                                |",
+        "| 2 | h | 6   | 2                                                     
                                                                |",
+        "| 2 | i | 6   | 2                                                     
                                                                |",
+        "| 2 | j | 6   | 2                                                     
                                                                |",
+        
"+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(
+        expected,
+        &execute(&ctx, UNBOUNDED_WINDOW_QUERY_WITH_ALIAS)
+            .await
+            .unwrap()
+    );
+}
+
 /// Basic user defined window function with bounded window
 #[tokio::test]
 async fn test_udwf_bounded_window_ignores_frame() {
@@ -491,6 +524,7 @@ impl OddCounter {
             signature: Signature,
             return_type: DataType,
             test_state: Arc<TestState>,
+            aliases: Vec<String>,
         }
 
         impl SimpleWindowUDF {
@@ -502,6 +536,7 @@ impl OddCounter {
                     signature,
                     return_type,
                     test_state,
+                    aliases: vec!["odd_counter_alias".to_string()],
                 }
             }
         }
@@ -526,6 +561,10 @@ impl OddCounter {
             fn partition_evaluator(&self) -> Result<Box<dyn 
PartitionEvaluator>> {
                 Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
             }
+
+            fn aliases(&self) -> &[String] {
+                &self.aliases
+            }
         }
 
         ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs
index b39b4a0032..cae410655d 100644
--- a/datafusion/execution/src/task.rs
+++ b/datafusion/execution/src/task.rs
@@ -207,9 +207,15 @@ impl FunctionRegistry for TaskContext {
         &mut self,
         udaf: Arc<AggregateUDF>,
     ) -> Result<Option<Arc<AggregateUDF>>> {
+        udaf.aliases().iter().for_each(|alias| {
+            self.aggregate_functions.insert(alias.clone(), udaf.clone());
+        });
         Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
     }
     fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> 
Result<Option<Arc<WindowUDF>>> {
+        udwf.aliases().iter().for_each(|alias| {
+            self.window_functions.insert(alias.clone(), udwf.clone());
+        });
         Ok(self.window_functions.insert(udwf.name().into(), udwf))
     }
     fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> 
Result<Option<Arc<ScalarUDF>>> {
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e56723063e..c46dd9cd3a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -118,6 +118,14 @@ impl AggregateUDF {
         self.inner.clone()
     }
 
+    /// Adds additional names that can be used to invoke this function, in
+    /// addition to `name`
+    ///
+    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
+    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) 
-> Self {
+        Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), 
aliases))
+    }
+
     /// creates an [`Expr`] that calls the aggregate function.
     ///
     /// This utility allows using the UDAF without requiring access to
@@ -139,6 +147,11 @@ impl AggregateUDF {
         self.inner.name()
     }
 
+    /// Returns the aliases for this function.
+    pub fn aliases(&self) -> &[String] {
+        self.inner.aliases()
+    }
+
     /// Returns this function's signature (what input types are accepted)
     ///
     /// See [`AggregateUDFImpl::signature`] for more details.
@@ -277,6 +290,64 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} 
yet")
     }
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+}
+
+/// AggregateUDF that adds an alias to the underlying function. It is better to
+/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
+#[derive(Debug)]
+struct AliasedAggregateUDFImpl {
+    inner: Arc<dyn AggregateUDFImpl>,
+    aliases: Vec<String>,
+}
+
+impl AliasedAggregateUDFImpl {
+    pub fn new(
+        inner: Arc<dyn AggregateUDFImpl>,
+        new_aliases: impl IntoIterator<Item = &'static str>,
+    ) -> Self {
+        let mut aliases = inner.aliases().to_vec();
+        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
+
+        Self { inner, aliases }
+    }
+}
+
+impl AggregateUDFImpl for AliasedAggregateUDFImpl {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.inner.name()
+    }
+
+    fn signature(&self) -> &Signature {
+        self.inner.signature()
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        self.inner.return_type(arg_types)
+    }
+
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        self.inner.accumulator(arg)
+    }
+
+    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
+        self.inner.state_type(return_type)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
 }
 
 /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 3ab40fe70a..d3925f2e19 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -80,7 +80,7 @@ impl WindowUDF {
     ///
     /// See [`WindowUDFImpl`] for a more convenient way to create a
     /// `WindowUDF` using trait objects
-    #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl 
instead")]
+    #[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl 
instead")]
     pub fn new(
         name: &str,
         signature: &Signature,
@@ -112,6 +112,14 @@ impl WindowUDF {
         self.inner.clone()
     }
 
+    /// Adds additional names that can be used to invoke this function, in
+    /// addition to `name`
+    ///
+    /// If you implement [`WindowUDFImpl`] directly you should return aliases directly.
+    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) 
-> Self {
+        Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), 
aliases))
+    }
+
     /// creates a [`Expr`] that calls the window function given
     /// the `partition_by`, `order_by`, and `window_frame` definition
     ///
@@ -143,6 +151,11 @@ impl WindowUDF {
         self.inner.name()
     }
 
+    /// Returns the aliases for this function.
+    pub fn aliases(&self) -> &[String] {
+        self.inner.aliases()
+    }
+
     /// Returns this function's signature (what input types are accepted)
     ///
     /// See [`WindowUDFImpl::signature`] for more details.
@@ -217,7 +230,7 @@ where
 ///    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { unimplemented!() }
 /// }
 ///
-/// // Create a new ScalarUDF from the implementation
+/// // Create a new WindowUDF from the implementation
 /// let smooth_it = WindowUDF::from(SmoothIt::new());
 ///
 /// // Call the function `add_one(col)`
@@ -245,6 +258,60 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
 
     /// Invoke the function, returning the [`PartitionEvaluator`] instance
     fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+}
+
+/// WindowUDF that adds an alias to the underlying function. It is better to
+/// implement [`WindowUDFImpl`], which supports aliases, directly if possible.
+#[derive(Debug)]
+struct AliasedWindowUDFImpl {
+    inner: Arc<dyn WindowUDFImpl>,
+    aliases: Vec<String>,
+}
+
+impl AliasedWindowUDFImpl {
+    pub fn new(
+        inner: Arc<dyn WindowUDFImpl>,
+        new_aliases: impl IntoIterator<Item = &'static str>,
+    ) -> Self {
+        let mut aliases = inner.aliases().to_vec();
+        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
+
+        Self { inner, aliases }
+    }
+}
+
+impl WindowUDFImpl for AliasedWindowUDFImpl {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.inner.name()
+    }
+
+    fn signature(&self) -> &Signature {
+        self.inner.signature()
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        self.inner.return_type(arg_types)
+    }
+
+    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+        self.inner.partition_evaluator()
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
 }
 
 /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers

Reply via email to