This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a873f51563 Convert `StringAgg` to UDAF (#10945)
a873f51563 is described below

commit a873f5156364f4357592c4bc9117887916e606f7
Author: 张林伟 <[email protected]>
AuthorDate: Tue Jun 18 22:08:13 2024 +0800

    Convert `StringAgg` to UDAF (#10945)
    
    * Convert StringAgg to UDAF
    
    * generate proto code
    
    * Fix bug
    
    * Fix
    
    * Add license
    
    * Add doc
    
    * Fix clippy
    
    * Remove aliases field
    
    * Add StringAgg proto test
    
    * Add roundtrip_expr_api test
---
 datafusion/expr/src/aggregate_function.rs          |   8 -
 datafusion/expr/src/type_coercion/aggregates.rs    |  26 ---
 datafusion/functions-aggregate/src/lib.rs          |   2 +
 datafusion/functions-aggregate/src/string_agg.rs   | 153 +++++++++++++
 datafusion/physical-expr/src/aggregate/build_in.rs |  16 --
 datafusion/physical-expr/src/aggregate/mod.rs      |   1 -
 .../physical-expr/src/aggregate/string_agg.rs      | 246 ---------------------
 datafusion/physical-expr/src/expressions/mod.rs    |   1 -
 datafusion/proto/proto/datafusion.proto            |   2 +-
 datafusion/proto/src/generated/pbjson.rs           |   3 -
 datafusion/proto/src/generated/prost.rs            |   4 +-
 datafusion/proto/src/logical_plan/from_proto.rs    |   1 -
 datafusion/proto/src/logical_plan/to_proto.rs      |   4 -
 datafusion/proto/src/physical_plan/to_proto.rs     |   5 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |   2 +
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  23 +-
 datafusion/sqllogictest/test_files/aggregate.slt   |  16 ++
 17 files changed, 192 insertions(+), 321 deletions(-)

diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index a7fbf26feb..1cde1c5050 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -51,8 +51,6 @@ pub enum AggregateFunction {
     BoolAnd,
     /// Bool Or
     BoolOr,
-    /// String aggregation
-    StringAgg,
 }
 
 impl AggregateFunction {
@@ -68,7 +66,6 @@ impl AggregateFunction {
             Grouping => "GROUPING",
             BoolAnd => "BOOL_AND",
             BoolOr => "BOOL_OR",
-            StringAgg => "STRING_AGG",
         }
     }
 }
@@ -92,7 +89,6 @@ impl FromStr for AggregateFunction {
             "min" => AggregateFunction::Min,
             "array_agg" => AggregateFunction::ArrayAgg,
             "nth_value" => AggregateFunction::NthValue,
-            "string_agg" => AggregateFunction::StringAgg,
             // statistical
             "corr" => AggregateFunction::Correlation,
             // other
@@ -146,7 +142,6 @@ impl AggregateFunction {
             )))),
             AggregateFunction::Grouping => Ok(DataType::Int32),
             AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
-            AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
         }
     }
 }
@@ -195,9 +190,6 @@ impl AggregateFunction {
             AggregateFunction::Correlation => {
                 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
             }
-            AggregateFunction::StringAgg => {
-                Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
-            }
         }
     }
 }
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index a216c98899..abe6d8b182 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -145,23 +145,6 @@ pub fn coerce_types(
         }
         AggregateFunction::NthValue => Ok(input_types.to_vec()),
         AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
-        AggregateFunction::StringAgg => {
-            if !is_string_agg_supported_arg_type(&input_types[0]) {
-                return plan_err!(
-                    "The function {:?} does not support inputs of type {:?}",
-                    agg_fun,
-                    input_types[0]
-                );
-            }
-            if !is_string_agg_supported_arg_type(&input_types[1]) {
-                return plan_err!(
-                    "The function {:?} does not support inputs of type {:?}",
-                    agg_fun,
-                    input_types[1]
-                );
-            }
-            Ok(vec![LargeUtf8, input_types[1].clone()])
-        }
     }
 }
 
@@ -391,15 +374,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
     arg_type.is_integer()
 }
 
-/// Return `true` if `arg_type` is of a [`DataType`] that the
-/// [`AggregateFunction::StringAgg`] aggregation can operate on.
-pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
-    )
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 990303bd1d..20a8d2c159 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -70,6 +70,7 @@ pub mod approx_median;
 pub mod approx_percentile_cont;
 pub mod approx_percentile_cont_with_weight;
 pub mod bit_and_or_xor;
+pub mod string_agg;
 
 use crate::approx_percentile_cont::approx_percentile_cont_udaf;
 use 
crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
@@ -138,6 +139,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
         approx_distinct::approx_distinct_udaf(),
         approx_percentile_cont_udaf(),
         approx_percentile_cont_with_weight_udaf(),
+        string_agg::string_agg_udaf(),
         bit_and_or_xor::bit_and_udaf(),
         bit_and_or_xor::bit_or_udaf(),
         bit_and_or_xor::bit_xor_udaf(),
diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs
new file mode 100644
index 0000000000..371cc8fb97
--- /dev/null
+++ b/datafusion/functions-aggregate/src/string_agg.rs
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the 
`string_agg` function
+
+use arrow::array::ArrayRef;
+use arrow_schema::DataType;
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::Result;
+use datafusion_common::{not_impl_err, ScalarValue};
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility,
+};
+use std::any::Any;
+
+make_udaf_expr_and_func!(
+    StringAgg,
+    string_agg,
+    expr delimiter,
+    "Concatenates the values of string expressions and places separator values 
between them",
+    string_agg_udaf
+);
+
+/// STRING_AGG aggregate expression
+#[derive(Debug)]
+pub struct StringAgg {
+    signature: Signature,
+}
+
+impl StringAgg {
+    /// Create a new StringAgg aggregate function
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Exact(vec![DataType::LargeUtf8, 
DataType::Utf8]),
+                    TypeSignature::Exact(vec![DataType::LargeUtf8, 
DataType::LargeUtf8]),
+                    TypeSignature::Exact(vec![DataType::LargeUtf8, 
DataType::Null]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl Default for StringAgg {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AggregateUDFImpl for StringAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "string_agg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::LargeUtf8)
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        match &acc_args.input_exprs[1] {
+            Expr::Literal(ScalarValue::Utf8(Some(delimiter)))
+            | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => {
+                Ok(Box::new(StringAggAccumulator::new(delimiter)))
+            }
+            Expr::Literal(ScalarValue::Utf8(None))
+            | Expr::Literal(ScalarValue::LargeUtf8(None))
+            | Expr::Literal(ScalarValue::Null) => {
+                Ok(Box::new(StringAggAccumulator::new("")))
+            }
+            _ => not_impl_err!(
+                "StringAgg not supported for delimiter {}",
+                &acc_args.input_exprs[1]
+            ),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct StringAggAccumulator {
+    values: Option<String>,
+    delimiter: String,
+}
+
+impl StringAggAccumulator {
+    pub fn new(delimiter: &str) -> Self {
+        Self {
+            values: None,
+            delimiter: delimiter.to_string(),
+        }
+    }
+}
+
+impl Accumulator for StringAggAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
+            .iter()
+            .filter_map(|v| v.as_ref().map(ToString::to_string))
+            .collect();
+        if !string_array.is_empty() {
+            let s = string_array.join(self.delimiter.as_str());
+            let v = self.values.get_or_insert("".to_string());
+            if !v.is_empty() {
+                v.push_str(self.delimiter.as_str());
+            }
+            v.push_str(s.as_str());
+        }
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update_batch(values)?;
+        Ok(())
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.evaluate()?])
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::LargeUtf8(self.values.clone()))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
+            + self.delimiter.capacity()
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 6c01decdbf..1dfe9ffd69 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -155,22 +155,6 @@ pub fn create_aggregate_expr(
                 ordering_req.to_vec(),
             ))
         }
-        (AggregateFunction::StringAgg, false) => {
-            if !ordering_req.is_empty() {
-                return not_impl_err!(
-                    "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations 
are not available"
-                );
-            }
-            Arc::new(expressions::StringAgg::new(
-                input_phy_exprs[0].clone(),
-                input_phy_exprs[1].clone(),
-                name,
-                data_type,
-            ))
-        }
-        (AggregateFunction::StringAgg, true) => {
-            return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not 
available");
-        }
     })
 }
 
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 0b1f5f5774..87c7deccc2 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -26,7 +26,6 @@ pub(crate) mod correlation;
 pub(crate) mod covariance;
 pub(crate) mod grouping;
 pub(crate) mod nth_value;
-pub(crate) mod string_agg;
 #[macro_use]
 pub(crate) mod min_max;
 pub(crate) mod groups_accumulator;
diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs
deleted file mode 100644
index dc0ffc5579..0000000000
--- a/datafusion/physical-expr/src/aggregate/string_agg.rs
+++ /dev/null
@@ -1,246 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the 
`string_agg` function
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{format_state_name, Literal};
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field};
-use datafusion_common::cast::as_generic_string_array;
-use datafusion_common::{not_impl_err, Result, ScalarValue};
-use datafusion_expr::Accumulator;
-use std::any::Any;
-use std::sync::Arc;
-
-/// STRING_AGG aggregate expression
-#[derive(Debug)]
-pub struct StringAgg {
-    name: String,
-    data_type: DataType,
-    expr: Arc<dyn PhysicalExpr>,
-    delimiter: Arc<dyn PhysicalExpr>,
-    nullable: bool,
-}
-
-impl StringAgg {
-    /// Create a new StringAgg aggregate function
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        delimiter: Arc<dyn PhysicalExpr>,
-        name: impl Into<String>,
-        data_type: DataType,
-    ) -> Self {
-        Self {
-            name: name.into(),
-            data_type,
-            delimiter,
-            expr,
-            nullable: true,
-        }
-    }
-}
-
-impl AggregateExpr for StringAgg {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(
-            &self.name,
-            self.data_type.clone(),
-            self.nullable,
-        ))
-    }
-
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        if let Some(delimiter) = 
self.delimiter.as_any().downcast_ref::<Literal>() {
-            match delimiter.value() {
-                ScalarValue::Utf8(Some(delimiter))
-                | ScalarValue::LargeUtf8(Some(delimiter)) => {
-                    return Ok(Box::new(StringAggAccumulator::new(delimiter)));
-                }
-                ScalarValue::Null => {
-                    return Ok(Box::new(StringAggAccumulator::new("")));
-                }
-                _ => return not_impl_err!("StringAgg not supported for {}", 
self.name),
-            }
-        }
-        not_impl_err!("StringAgg not supported for {}", self.name)
-    }
-
-    fn state_fields(&self) -> Result<Vec<Field>> {
-        Ok(vec![Field::new(
-            format_state_name(&self.name, "string_agg"),
-            self.data_type.clone(),
-            self.nullable,
-        )])
-    }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.expr.clone(), self.delimiter.clone()]
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-}
-
-impl PartialEq<dyn Any> for StringAgg {
-    fn eq(&self, other: &dyn Any) -> bool {
-        down_cast_any_ref(other)
-            .downcast_ref::<Self>()
-            .map(|x| {
-                self.name == x.name
-                    && self.data_type == x.data_type
-                    && self.expr.eq(&x.expr)
-                    && self.delimiter.eq(&x.delimiter)
-            })
-            .unwrap_or(false)
-    }
-}
-
-#[derive(Debug)]
-pub(crate) struct StringAggAccumulator {
-    values: Option<String>,
-    delimiter: String,
-}
-
-impl StringAggAccumulator {
-    pub fn new(delimiter: &str) -> Self {
-        Self {
-            values: None,
-            delimiter: delimiter.to_string(),
-        }
-    }
-}
-
-impl Accumulator for StringAggAccumulator {
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
-            .iter()
-            .filter_map(|v| v.as_ref().map(ToString::to_string))
-            .collect();
-        if !string_array.is_empty() {
-            let s = string_array.join(self.delimiter.as_str());
-            let v = self.values.get_or_insert("".to_string());
-            if !v.is_empty() {
-                v.push_str(self.delimiter.as_str());
-            }
-            v.push_str(s.as_str());
-        }
-        Ok(())
-    }
-
-    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        self.update_batch(values)?;
-        Ok(())
-    }
-
-    fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.evaluate()?])
-    }
-
-    fn evaluate(&mut self) -> Result<ScalarValue> {
-        Ok(ScalarValue::LargeUtf8(self.values.clone()))
-    }
-
-    fn size(&self) -> usize {
-        std::mem::size_of_val(self)
-            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
-            + self.delimiter.capacity()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::expressions::tests::aggregate;
-    use crate::expressions::{col, create_aggregate_expr, try_cast};
-    use arrow::datatypes::*;
-    use arrow::record_batch::RecordBatch;
-    use arrow_array::LargeStringArray;
-    use arrow_array::StringArray;
-    use datafusion_expr::type_coercion::aggregates::coerce_types;
-    use datafusion_expr::AggregateFunction;
-
-    fn assert_string_aggregate(
-        array: ArrayRef,
-        function: AggregateFunction,
-        distinct: bool,
-        expected: ScalarValue,
-        delimiter: String,
-    ) {
-        let data_type = array.data_type();
-        let sig = function.signature();
-        let coerced =
-            coerce_types(&function, &[data_type.clone(), DataType::Utf8], 
&sig).unwrap();
-
-        let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), 
true)]);
-        let batch =
-            RecordBatch::try_new(Arc::new(input_schema.clone()), 
vec![array]).unwrap();
-
-        let input = try_cast(
-            col("a", &input_schema).unwrap(),
-            &input_schema,
-            coerced[0].clone(),
-        )
-        .unwrap();
-
-        let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter)));
-        let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), 
true)]);
-        let agg = create_aggregate_expr(
-            &function,
-            distinct,
-            &[input, delimiter],
-            &[],
-            &schema,
-            "agg",
-            false,
-        )
-        .unwrap();
-
-        let result = aggregate(&batch, agg).unwrap();
-        assert_eq!(expected, result);
-    }
-
-    #[test]
-    fn string_agg_utf8() {
-        let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", 
"o"]));
-        assert_string_aggregate(
-            a,
-            AggregateFunction::StringAgg,
-            false,
-            ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())),
-            ",".to_owned(),
-        );
-    }
-
-    #[test]
-    fn string_agg_largeutf8() {
-        let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", 
"l", "o"]));
-        assert_string_aggregate(
-            a,
-            AggregateFunction::StringAgg,
-            false,
-            ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())),
-            "|".to_owned(),
-        );
-    }
-}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index bffaafd7da..3226104040 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -47,7 +47,6 @@ pub use crate::aggregate::grouping::Grouping;
 pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator};
 pub use crate::aggregate::nth_value::NthValueAgg;
 pub use crate::aggregate::stats::StatsType;
-pub use crate::aggregate::string_agg::StringAgg;
 pub use crate::window::cume_dist::{cume_dist, CumeDist};
 pub use crate::window::lead_lag::{lag, lead, WindowShift};
 pub use crate::window::nth_value::NthValue;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index ae4445eaa8..6375df721a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -505,7 +505,7 @@ enum AggregateFunction {
   // REGR_SXX = 32;
   // REGR_SYY = 33;
   // REGR_SXY = 34;
-  STRING_AGG = 35;
+  // STRING_AGG = 35;
   NTH_VALUE_AGG = 36;
 }
 
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 243c75435f..5c483f70d1 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction {
             Self::Grouping => "GROUPING",
             Self::BoolAnd => "BOOL_AND",
             Self::BoolOr => "BOOL_OR",
-            Self::StringAgg => "STRING_AGG",
             Self::NthValueAgg => "NTH_VALUE_AGG",
         };
         serializer.serialize_str(variant)
@@ -561,7 +560,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
             "GROUPING",
             "BOOL_AND",
             "BOOL_OR",
-            "STRING_AGG",
             "NTH_VALUE_AGG",
         ];
 
@@ -611,7 +609,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
                     "GROUPING" => Ok(AggregateFunction::Grouping),
                     "BOOL_AND" => Ok(AggregateFunction::BoolAnd),
                     "BOOL_OR" => Ok(AggregateFunction::BoolOr),
-                    "STRING_AGG" => Ok(AggregateFunction::StringAgg),
                     "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 1172eccb90..bc5b6be2ad 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1959,7 +1959,7 @@ pub enum AggregateFunction {
     /// REGR_SXX = 32;
     /// REGR_SYY = 33;
     /// REGR_SXY = 34;
-    StringAgg = 35,
+    /// STRING_AGG = 35;
     NthValueAgg = 36,
 }
 impl AggregateFunction {
@@ -1977,7 +1977,6 @@ impl AggregateFunction {
             AggregateFunction::Grouping => "GROUPING",
             AggregateFunction::BoolAnd => "BOOL_AND",
             AggregateFunction::BoolOr => "BOOL_OR",
-            AggregateFunction::StringAgg => "STRING_AGG",
             AggregateFunction::NthValueAgg => "NTH_VALUE_AGG",
         }
     }
@@ -1992,7 +1991,6 @@ impl AggregateFunction {
             "GROUPING" => Some(Self::Grouping),
             "BOOL_AND" => Some(Self::BoolAnd),
             "BOOL_OR" => Some(Self::BoolOr),
-            "STRING_AGG" => Some(Self::StringAgg),
             "NTH_VALUE_AGG" => Some(Self::NthValueAgg),
             _ => None,
         }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 43cc352f98..5bec655bb1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -146,7 +146,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::Correlation => Self::Correlation,
             protobuf::AggregateFunction::Grouping => Self::Grouping,
             protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
-            protobuf::AggregateFunction::StringAgg => Self::StringAgg,
         }
     }
 }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 33a58daeaf..66b7c77799 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -117,7 +117,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Correlation => Self::Correlation,
             AggregateFunction::Grouping => Self::Grouping,
             AggregateFunction::NthValue => Self::NthValueAgg,
-            AggregateFunction::StringAgg => Self::StringAgg,
         }
     }
 }
@@ -387,9 +386,6 @@ pub fn serialize_expr(
                     AggregateFunction::NthValue => {
                         protobuf::AggregateFunction::NthValueAgg
                     }
-                    AggregateFunction::StringAgg => {
-                        protobuf::AggregateFunction::StringAgg
-                    }
                 };
 
                 let aggregate_expr = protobuf::AggregateExprNode {
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 886179bf56..ed966509b8 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -26,8 +26,7 @@ use datafusion::physical_plan::expressions::{
     ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, 
Correlation,
     CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, 
IsNullExpr, Literal,
     Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile,
-    OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr,
-    WindowShift,
+    OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, 
WindowShift,
 };
 use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 use datafusion::physical_plan::windows::{BuiltInWindowExpr, 
PlainAggregateWindowExpr};
@@ -260,8 +259,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
         protobuf::AggregateFunction::Avg
     } else if aggr_expr.downcast_ref::<Correlation>().is_some() {
         protobuf::AggregateFunction::Correlation
-    } else if aggr_expr.downcast_ref::<StringAgg>().is_some() {
-        protobuf::AggregateFunction::StringAgg
     } else if aggr_expr.downcast_ref::<NthValueAgg>().is_some() {
         protobuf::AggregateFunction::NthValueAgg
     } else {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 52696a1061..61764394ee 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -60,6 +60,7 @@ use datafusion_expr::{
     WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
 };
 use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor};
+use datafusion_functions_aggregate::string_agg::string_agg;
 use datafusion_proto::bytes::{
     logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
     logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
@@ -669,6 +670,7 @@ async fn roundtrip_expr_api() -> Result<()> {
         bit_and(lit(2)),
         bit_or(lit(2)),
         bit_xor(lit(2)),
+        string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")),
     ];
 
     // ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 7f66cdbf76..eb33132395 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -48,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec;
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::expressions::{
     binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, 
NthValue,
-    PhysicalSortExpr, StringAgg,
+    PhysicalSortExpr,
 };
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::insert::DataSinkExec;
@@ -79,6 +79,7 @@ use datafusion_expr::{
     Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, 
ScalarUDF,
     ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, 
WindowFrameBound,
 };
+use datafusion_functions_aggregate::string_agg::StringAgg;
 use datafusion_proto::physical_plan::{
     AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
 };
@@ -357,12 +358,20 @@ fn rountrip_aggregate() -> Result<()> {
             Vec::new(),
         ))],
         // STRING_AGG
-        vec![Arc::new(StringAgg::new(
-            cast(col("b", &schema)?, &schema, DataType::Utf8)?,
-            lit(ScalarValue::Utf8(Some(",".to_string()))),
-            "STRING_AGG(name, ',')".to_string(),
-            DataType::Utf8,
-        ))],
+        vec![udaf::create_aggregate_expr(
+            &AggregateUDF::new_from_impl(StringAgg::new()),
+            &[
+                cast(col("b", &schema)?, &schema, DataType::Utf8)?,
+                lit(ScalarValue::Utf8(Some(",".to_string()))),
+            ],
+            &[],
+            &[],
+            &[],
+            &schema,
+            "STRING_AGG(name, ',')",
+            false,
+            false,
+        )?],
     ];
 
     for aggregates in test_cases {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 0a6def3d6f..378cab2062 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -4972,6 +4972,22 @@ CREATE TABLE float_table (
 ( 32768.3, arrow_cast('NAN','Float32'), 32768.3, 32768.3 ),
 ( 27.3,    27.3,                        27.3,    arrow_cast('NAN','Float64') );
 
+# Test string_agg with largeutf8
+statement ok
+create table string_agg_large_utf8 (c string) as values 
+  (arrow_cast('a', 'LargeUtf8')),
+  (arrow_cast('b', 'LargeUtf8')),
+  (arrow_cast('c', 'LargeUtf8'))
+;
+
+query T
+SELECT STRING_AGG(c, ',') FROM string_agg_large_utf8;
+----
+a,b,c
+
+statement ok
+drop table string_agg_large_utf8;
+
 query RRRRI
 select min(col_f32), max(col_f32), avg(col_f32), sum(col_f32), count(col_f32) 
from float_table;
 ----


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to