This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a873f51563 Convert `StringAgg` to UDAF (#10945)
a873f51563 is described below
commit a873f5156364f4357592c4bc9117887916e606f7
Author: 张林伟 <[email protected]>
AuthorDate: Tue Jun 18 22:08:13 2024 +0800
Convert `StringAgg` to UDAF (#10945)
* Convert StringAgg to UDAF
* generate proto code
* Fix bug
* Fix
* Add license
* Add doc
* Fix clippy
* Remove aliases field
* Add StringAgg proto test
* Add roundtrip_expr_api test
---
datafusion/expr/src/aggregate_function.rs | 8 -
datafusion/expr/src/type_coercion/aggregates.rs | 26 ---
datafusion/functions-aggregate/src/lib.rs | 2 +
datafusion/functions-aggregate/src/string_agg.rs | 153 +++++++++++++
datafusion/physical-expr/src/aggregate/build_in.rs | 16 --
datafusion/physical-expr/src/aggregate/mod.rs | 1 -
.../physical-expr/src/aggregate/string_agg.rs | 246 ---------------------
datafusion/physical-expr/src/expressions/mod.rs | 1 -
datafusion/proto/proto/datafusion.proto | 2 +-
datafusion/proto/src/generated/pbjson.rs | 3 -
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/logical_plan/from_proto.rs | 1 -
datafusion/proto/src/logical_plan/to_proto.rs | 4 -
datafusion/proto/src/physical_plan/to_proto.rs | 5 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 2 +
.../proto/tests/cases/roundtrip_physical_plan.rs | 23 +-
datafusion/sqllogictest/test_files/aggregate.slt | 16 ++
17 files changed, 192 insertions(+), 321 deletions(-)
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index a7fbf26feb..1cde1c5050 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -51,8 +51,6 @@ pub enum AggregateFunction {
BoolAnd,
/// Bool Or
BoolOr,
- /// String aggregation
- StringAgg,
}
impl AggregateFunction {
@@ -68,7 +66,6 @@ impl AggregateFunction {
Grouping => "GROUPING",
BoolAnd => "BOOL_AND",
BoolOr => "BOOL_OR",
- StringAgg => "STRING_AGG",
}
}
}
@@ -92,7 +89,6 @@ impl FromStr for AggregateFunction {
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
- "string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
// other
@@ -146,7 +142,6 @@ impl AggregateFunction {
)))),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
- AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
}
}
}
@@ -195,9 +190,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
- AggregateFunction::StringAgg => {
- Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
- }
}
}
}
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index a216c98899..abe6d8b182 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -145,23 +145,6 @@ pub fn coerce_types(
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
- AggregateFunction::StringAgg => {
- if !is_string_agg_supported_arg_type(&input_types[0]) {
- return plan_err!(
- "The function {:?} does not support inputs of type {:?}",
- agg_fun,
- input_types[0]
- );
- }
- if !is_string_agg_supported_arg_type(&input_types[1]) {
- return plan_err!(
- "The function {:?} does not support inputs of type {:?}",
- agg_fun,
- input_types[1]
- );
- }
- Ok(vec![LargeUtf8, input_types[1].clone()])
- }
}
}
@@ -391,15 +374,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
arg_type.is_integer()
}
-/// Return `true` if `arg_type` is of a [`DataType`] that the
-/// [`AggregateFunction::StringAgg`] aggregation can operate on.
-pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
- )
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index 990303bd1d..20a8d2c159 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -70,6 +70,7 @@ pub mod approx_median;
pub mod approx_percentile_cont;
pub mod approx_percentile_cont_with_weight;
pub mod bit_and_or_xor;
+pub mod string_agg;
use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use
crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
@@ -138,6 +139,7 @@ pub fn all_default_aggregate_functions() ->
Vec<Arc<AggregateUDF>> {
approx_distinct::approx_distinct_udaf(),
approx_percentile_cont_udaf(),
approx_percentile_cont_with_weight_udaf(),
+ string_agg::string_agg_udaf(),
bit_and_or_xor::bit_and_udaf(),
bit_and_or_xor::bit_or_udaf(),
bit_and_or_xor::bit_xor_udaf(),
diff --git a/datafusion/functions-aggregate/src/string_agg.rs
b/datafusion/functions-aggregate/src/string_agg.rs
new file mode 100644
index 0000000000..371cc8fb97
--- /dev/null
+++ b/datafusion/functions-aggregate/src/string_agg.rs
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function
+
+use arrow::array::ArrayRef;
+use arrow_schema::DataType;
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::Result;
+use datafusion_common::{not_impl_err, ScalarValue};
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+ Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility,
+};
+use std::any::Any;
+
+make_udaf_expr_and_func!(
+    StringAgg,
+    string_agg,
+    expr delimiter,
+    "Concatenates the values of string expressions and places separator values between them",
+    string_agg_udaf
+);
+
+/// `STRING_AGG(expr, delimiter)` aggregate UDF.
+///
+/// Every accepted signature takes a `LargeUtf8` value argument; the
+/// delimiter may be `Utf8`, `LargeUtf8` or `NULL`.
+#[derive(Debug)]
+pub struct StringAgg {
+    /// Argument types accepted by this function (built in [`StringAgg::new`]).
+    signature: Signature,
+}
+
+impl StringAgg {
+    /// Create a new StringAgg aggregate function
+    ///
+    /// The accepted argument lists are, in order: `(LargeUtf8, Utf8)`,
+    /// `(LargeUtf8, LargeUtf8)` and `(LargeUtf8, Null)`.
+    pub fn new() -> Self {
+        let delimiter_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Null];
+        let accepted: Vec<TypeSignature> = delimiter_types
+            .into_iter()
+            .map(|delimiter_type| {
+                TypeSignature::Exact(vec![DataType::LargeUtf8, delimiter_type])
+            })
+            .collect();
+        Self {
+            signature: Signature::one_of(accepted, Volatility::Immutable),
+        }
+    }
+}
+
+impl Default for StringAgg {
+    /// Equivalent to [`StringAgg::new`].
+    fn default() -> Self {
+        StringAgg::new()
+    }
+}
+
+impl AggregateUDFImpl for StringAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// SQL-visible function name.
+    fn name(&self) -> &str {
+        "string_agg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// The result is always `LargeUtf8`, independent of the argument types.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::LargeUtf8)
+    }
+
+    /// Create the per-group accumulator.
+    ///
+    /// The delimiter (second argument) must be a literal: a non-null
+    /// `Utf8`/`LargeUtf8` literal is used verbatim, a null literal yields an
+    /// empty delimiter, and anything else returns a not-implemented error.
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let delimiter_expr = &acc_args.input_exprs[1];
+        let delimiter = match delimiter_expr {
+            Expr::Literal(ScalarValue::Utf8(Some(text)))
+            | Expr::Literal(ScalarValue::LargeUtf8(Some(text))) => text.as_str(),
+            Expr::Literal(ScalarValue::Utf8(None))
+            | Expr::Literal(ScalarValue::LargeUtf8(None))
+            | Expr::Literal(ScalarValue::Null) => "",
+            _ => {
+                return not_impl_err!(
+                    "StringAgg not supported for delimiter {}",
+                    delimiter_expr
+                )
+            }
+        };
+        Ok(Box::new(StringAggAccumulator::new(delimiter)))
+    }
+}
+
+/// Accumulator behind `STRING_AGG`: concatenates the non-null input strings,
+/// placing `delimiter` between consecutive values.
+#[derive(Debug)]
+pub(crate) struct StringAggAccumulator {
+    /// Concatenation so far. `None` until the first non-null value arrives, so
+    /// "no input yet" stays distinct from an accumulated empty string.
+    values: Option<String>,
+    /// Separator inserted between consecutive values.
+    delimiter: String,
+}
+
+impl StringAggAccumulator {
+    /// Create an accumulator that joins values with `delimiter`.
+    pub fn new(delimiter: &str) -> Self {
+        Self {
+            values: None,
+            delimiter: delimiter.to_string(),
+        }
+    }
+
+    /// Append a single non-null value.
+    ///
+    /// The delimiter is emitted whenever *any* value was seen before, even if
+    /// the text accumulated so far is empty: `['', 'b']` joined with `,` must
+    /// give `",b"`, regardless of how the rows are split into batches.
+    fn push_value(&mut self, value: &str) {
+        match self.values.as_mut() {
+            Some(acc) => {
+                acc.push_str(&self.delimiter);
+                acc.push_str(value);
+            }
+            None => self.values = Some(value.to_string()),
+        }
+    }
+}
+
+impl Accumulator for StringAggAccumulator {
+    /// Append every non-null string of `values[0]` (read as a `LargeUtf8`
+    /// array) in row order; null rows are skipped.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let string_array = as_generic_string_array::<i64>(&values[0])?;
+        for value in string_array.iter().flatten() {
+            self.push_value(value);
+        }
+        Ok(())
+    }
+
+    /// Partial states are single `LargeUtf8` values (see `state`), so merging
+    /// them is the same as appending them as ordinary inputs.
+    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update_batch(values)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.evaluate()?])
+    }
+
+    /// `NULL` when no non-null value was seen, otherwise the joined string.
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::LargeUtf8(self.values.clone()))
+    }
+
+    /// Size of the struct plus the heap buffers owned by both strings.
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.values.as_ref().map_or(0, String::capacity)
+            + self.delimiter.capacity()
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 6c01decdbf..1dfe9ffd69 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -155,22 +155,6 @@ pub fn create_aggregate_expr(
ordering_req.to_vec(),
))
}
- (AggregateFunction::StringAgg, false) => {
- if !ordering_req.is_empty() {
- return not_impl_err!(
- "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations
are not available"
- );
- }
- Arc::new(expressions::StringAgg::new(
- input_phy_exprs[0].clone(),
- input_phy_exprs[1].clone(),
- name,
- data_type,
- ))
- }
- (AggregateFunction::StringAgg, true) => {
- return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not
available");
- }
})
}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 0b1f5f5774..87c7deccc2 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -26,7 +26,6 @@ pub(crate) mod correlation;
pub(crate) mod covariance;
pub(crate) mod grouping;
pub(crate) mod nth_value;
-pub(crate) mod string_agg;
#[macro_use]
pub(crate) mod min_max;
pub(crate) mod groups_accumulator;
diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs
b/datafusion/physical-expr/src/aggregate/string_agg.rs
deleted file mode 100644
index dc0ffc5579..0000000000
--- a/datafusion/physical-expr/src/aggregate/string_agg.rs
+++ /dev/null
@@ -1,246 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the
`string_agg` function
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{format_state_name, Literal};
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field};
-use datafusion_common::cast::as_generic_string_array;
-use datafusion_common::{not_impl_err, Result, ScalarValue};
-use datafusion_expr::Accumulator;
-use std::any::Any;
-use std::sync::Arc;
-
-/// STRING_AGG aggregate expression
-#[derive(Debug)]
-pub struct StringAgg {
- name: String,
- data_type: DataType,
- expr: Arc<dyn PhysicalExpr>,
- delimiter: Arc<dyn PhysicalExpr>,
- nullable: bool,
-}
-
-impl StringAgg {
- /// Create a new StringAgg aggregate function
- pub fn new(
- expr: Arc<dyn PhysicalExpr>,
- delimiter: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Self {
- Self {
- name: name.into(),
- data_type,
- delimiter,
- expr,
- nullable: true,
- }
- }
-}
-
-impl AggregateExpr for StringAgg {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(
- &self.name,
- self.data_type.clone(),
- self.nullable,
- ))
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- if let Some(delimiter) =
self.delimiter.as_any().downcast_ref::<Literal>() {
- match delimiter.value() {
- ScalarValue::Utf8(Some(delimiter))
- | ScalarValue::LargeUtf8(Some(delimiter)) => {
- return Ok(Box::new(StringAggAccumulator::new(delimiter)));
- }
- ScalarValue::Null => {
- return Ok(Box::new(StringAggAccumulator::new("")));
- }
- _ => return not_impl_err!("StringAgg not supported for {}",
self.name),
- }
- }
- not_impl_err!("StringAgg not supported for {}", self.name)
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![Field::new(
- format_state_name(&self.name, "string_agg"),
- self.data_type.clone(),
- self.nullable,
- )])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr.clone(), self.delimiter.clone()]
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for StringAgg {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.name == x.name
- && self.data_type == x.data_type
- && self.expr.eq(&x.expr)
- && self.delimiter.eq(&x.delimiter)
- })
- .unwrap_or(false)
- }
-}
-
-#[derive(Debug)]
-pub(crate) struct StringAggAccumulator {
- values: Option<String>,
- delimiter: String,
-}
-
-impl StringAggAccumulator {
- pub fn new(delimiter: &str) -> Self {
- Self {
- values: None,
- delimiter: delimiter.to_string(),
- }
- }
-}
-
-impl Accumulator for StringAggAccumulator {
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
- .iter()
- .filter_map(|v| v.as_ref().map(ToString::to_string))
- .collect();
- if !string_array.is_empty() {
- let s = string_array.join(self.delimiter.as_str());
- let v = self.values.get_or_insert("".to_string());
- if !v.is_empty() {
- v.push_str(self.delimiter.as_str());
- }
- v.push_str(s.as_str());
- }
- Ok(())
- }
-
- fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- self.update_batch(values)?;
- Ok(())
- }
-
- fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.evaluate()?])
- }
-
- fn evaluate(&mut self) -> Result<ScalarValue> {
- Ok(ScalarValue::LargeUtf8(self.values.clone()))
- }
-
- fn size(&self) -> usize {
- std::mem::size_of_val(self)
- + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
- + self.delimiter.capacity()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::expressions::tests::aggregate;
- use crate::expressions::{col, create_aggregate_expr, try_cast};
- use arrow::datatypes::*;
- use arrow::record_batch::RecordBatch;
- use arrow_array::LargeStringArray;
- use arrow_array::StringArray;
- use datafusion_expr::type_coercion::aggregates::coerce_types;
- use datafusion_expr::AggregateFunction;
-
- fn assert_string_aggregate(
- array: ArrayRef,
- function: AggregateFunction,
- distinct: bool,
- expected: ScalarValue,
- delimiter: String,
- ) {
- let data_type = array.data_type();
- let sig = function.signature();
- let coerced =
- coerce_types(&function, &[data_type.clone(), DataType::Utf8],
&sig).unwrap();
-
- let input_schema = Schema::new(vec![Field::new("a", data_type.clone(),
true)]);
- let batch =
- RecordBatch::try_new(Arc::new(input_schema.clone()),
vec![array]).unwrap();
-
- let input = try_cast(
- col("a", &input_schema).unwrap(),
- &input_schema,
- coerced[0].clone(),
- )
- .unwrap();
-
- let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter)));
- let schema = Schema::new(vec![Field::new("a", coerced[0].clone(),
true)]);
- let agg = create_aggregate_expr(
- &function,
- distinct,
- &[input, delimiter],
- &[],
- &schema,
- "agg",
- false,
- )
- .unwrap();
-
- let result = aggregate(&batch, agg).unwrap();
- assert_eq!(expected, result);
- }
-
- #[test]
- fn string_agg_utf8() {
- let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l",
"o"]));
- assert_string_aggregate(
- a,
- AggregateFunction::StringAgg,
- false,
- ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())),
- ",".to_owned(),
- );
- }
-
- #[test]
- fn string_agg_largeutf8() {
- let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l",
"l", "o"]));
- assert_string_aggregate(
- a,
- AggregateFunction::StringAgg,
- false,
- ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())),
- "|".to_owned(),
- );
- }
-}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index bffaafd7da..3226104040 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -47,7 +47,6 @@ pub use crate::aggregate::grouping::Grouping;
pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator};
pub use crate::aggregate::nth_value::NthValueAgg;
pub use crate::aggregate::stats::StatsType;
-pub use crate::aggregate::string_agg::StringAgg;
pub use crate::window::cume_dist::{cume_dist, CumeDist};
pub use crate::window::lead_lag::{lag, lead, WindowShift};
pub use crate::window::nth_value::NthValue;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index ae4445eaa8..6375df721a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -505,7 +505,7 @@ enum AggregateFunction {
// REGR_SXX = 32;
// REGR_SYY = 33;
// REGR_SXY = 34;
- STRING_AGG = 35;
+ // STRING_AGG = 35;
NTH_VALUE_AGG = 36;
}
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 243c75435f..5c483f70d1 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction {
Self::Grouping => "GROUPING",
Self::BoolAnd => "BOOL_AND",
Self::BoolOr => "BOOL_OR",
- Self::StringAgg => "STRING_AGG",
Self::NthValueAgg => "NTH_VALUE_AGG",
};
serializer.serialize_str(variant)
@@ -561,7 +560,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"GROUPING",
"BOOL_AND",
"BOOL_OR",
- "STRING_AGG",
"NTH_VALUE_AGG",
];
@@ -611,7 +609,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"GROUPING" => Ok(AggregateFunction::Grouping),
"BOOL_AND" => Ok(AggregateFunction::BoolAnd),
"BOOL_OR" => Ok(AggregateFunction::BoolOr),
- "STRING_AGG" => Ok(AggregateFunction::StringAgg),
"NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 1172eccb90..bc5b6be2ad 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1959,7 +1959,7 @@ pub enum AggregateFunction {
/// REGR_SXX = 32;
/// REGR_SYY = 33;
/// REGR_SXY = 34;
- StringAgg = 35,
+ /// STRING_AGG = 35;
NthValueAgg = 36,
}
impl AggregateFunction {
@@ -1977,7 +1977,6 @@ impl AggregateFunction {
AggregateFunction::Grouping => "GROUPING",
AggregateFunction::BoolAnd => "BOOL_AND",
AggregateFunction::BoolOr => "BOOL_OR",
- AggregateFunction::StringAgg => "STRING_AGG",
AggregateFunction::NthValueAgg => "NTH_VALUE_AGG",
}
}
@@ -1992,7 +1991,6 @@ impl AggregateFunction {
"GROUPING" => Some(Self::Grouping),
"BOOL_AND" => Some(Self::BoolAnd),
"BOOL_OR" => Some(Self::BoolOr),
- "STRING_AGG" => Some(Self::StringAgg),
"NTH_VALUE_AGG" => Some(Self::NthValueAgg),
_ => None,
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 43cc352f98..5bec655bb1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -146,7 +146,6 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Correlation => Self::Correlation,
protobuf::AggregateFunction::Grouping => Self::Grouping,
protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
- protobuf::AggregateFunction::StringAgg => Self::StringAgg,
}
}
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 33a58daeaf..66b7c77799 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -117,7 +117,6 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Correlation => Self::Correlation,
AggregateFunction::Grouping => Self::Grouping,
AggregateFunction::NthValue => Self::NthValueAgg,
- AggregateFunction::StringAgg => Self::StringAgg,
}
}
}
@@ -387,9 +386,6 @@ pub fn serialize_expr(
AggregateFunction::NthValue => {
protobuf::AggregateFunction::NthValueAgg
}
- AggregateFunction::StringAgg => {
- protobuf::AggregateFunction::StringAgg
- }
};
let aggregate_expr = protobuf::AggregateExprNode {
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 886179bf56..ed966509b8 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -26,8 +26,7 @@ use datafusion::physical_plan::expressions::{
ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column,
Correlation,
CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr,
IsNullExpr, Literal,
Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile,
- OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr,
- WindowShift,
+ OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr,
WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr,
PlainAggregateWindowExpr};
@@ -260,8 +259,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
protobuf::AggregateFunction::Avg
} else if aggr_expr.downcast_ref::<Correlation>().is_some() {
protobuf::AggregateFunction::Correlation
- } else if aggr_expr.downcast_ref::<StringAgg>().is_some() {
- protobuf::AggregateFunction::StringAgg
} else if aggr_expr.downcast_ref::<NthValueAgg>().is_some() {
protobuf::AggregateFunction::NthValueAgg
} else {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 52696a1061..61764394ee 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -60,6 +60,7 @@ use datafusion_expr::{
WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor};
+use datafusion_functions_aggregate::string_agg::string_agg;
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
@@ -669,6 +670,7 @@ async fn roundtrip_expr_api() -> Result<()> {
bit_and(lit(2)),
bit_or(lit(2)),
bit_xor(lit(2)),
+ string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")),
];
// ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 7f66cdbf76..eb33132395 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -48,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::expressions::{
binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr,
NthValue,
- PhysicalSortExpr, StringAgg,
+ PhysicalSortExpr,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::insert::DataSinkExec;
@@ -79,6 +79,7 @@ use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
ScalarUDF,
ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame,
WindowFrameBound,
};
+use datafusion_functions_aggregate::string_agg::StringAgg;
use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
@@ -357,12 +358,20 @@ fn rountrip_aggregate() -> Result<()> {
Vec::new(),
))],
// STRING_AGG
- vec![Arc::new(StringAgg::new(
- cast(col("b", &schema)?, &schema, DataType::Utf8)?,
- lit(ScalarValue::Utf8(Some(",".to_string()))),
- "STRING_AGG(name, ',')".to_string(),
- DataType::Utf8,
- ))],
+ vec![udaf::create_aggregate_expr(
+ &AggregateUDF::new_from_impl(StringAgg::new()),
+ &[
+ cast(col("b", &schema)?, &schema, DataType::Utf8)?,
+ lit(ScalarValue::Utf8(Some(",".to_string()))),
+ ],
+ &[],
+ &[],
+ &[],
+ &schema,
+ "STRING_AGG(name, ',')",
+ false,
+ false,
+ )?],
];
for aggregates in test_cases {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 0a6def3d6f..378cab2062 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -4972,6 +4972,22 @@ CREATE TABLE float_table (
( 32768.3, arrow_cast('NAN','Float32'), 32768.3, 32768.3 ),
( 27.3, 27.3, 27.3, arrow_cast('NAN','Float64') );
+# Test string_agg with largeutf8
+statement ok
+create table string_agg_large_utf8 (c string) as values
+ (arrow_cast('a', 'LargeUtf8')),
+ (arrow_cast('b', 'LargeUtf8')),
+ (arrow_cast('c', 'LargeUtf8'))
+;
+
+query T
+SELECT STRING_AGG(c, ',') FROM string_agg_large_utf8;
+----
+a,b,c
+
+statement ok
+drop table string_agg_large_utf8;
+
query RRRRI
select min(col_f32), max(col_f32), avg(col_f32), sum(col_f32), count(col_f32)
from float_table;
----
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]