This is an automated email from the ASF dual-hosted git repository. thinkharderdev pushed a commit to branch issue-6062 in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 15532a0123f5d0c9237ffc623d4ef6090c82b4ef Author: Dan Harris <[email protected]> AuthorDate: Wed Apr 19 13:28:13 2023 -0400 Add support for UDAF in physical plan serialization --- datafusion/core/src/physical_plan/udaf.rs | 7 ++ datafusion/proto/proto/datafusion.proto | 5 +- datafusion/proto/src/generated/pbjson.rs | 57 +++++++---- datafusion/proto/src/generated/prost.rs | 17 +++- datafusion/proto/src/physical_plan/mod.rs | 129 ++++++++++++++++++++----- datafusion/proto/src/physical_plan/to_proto.rs | 33 +++++-- 6 files changed, 196 insertions(+), 52 deletions(-) diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs index cbbb851865..07e5cc3e6d 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/core/src/physical_plan/udaf.rs @@ -65,6 +65,13 @@ pub struct AggregateFunctionExpr { name: String, } +impl AggregateFunctionExpr { + /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` + pub fn fun(&self) -> &AggregateUDF { + &self.fun + } +} + impl AggregateExpr for AggregateFunctionExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3023cbc264..7d02fda86c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1062,7 +1062,10 @@ message PhysicalScalarUdfNode { } message PhysicalAggregateExprNode { - AggregateFunction aggr_function = 1; + oneof AggregateFunction { + AggregateFunction aggr_function = 1; + string user_defined_aggr_function = 4; + } repeated PhysicalExprNode expr = 2; bool distinct = 3; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6a416d37c9..553f3f2911 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12661,27 +12661,34 @@ impl serde::Serialize for PhysicalAggregateExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.aggr_function != 0 { - len += 1; - } if !self.expr.is_empty() { len += 1; } if self.distinct { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; - if self.aggr_function != 0 { - let v = AggregateFunction::from_i32(self.aggr_function) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; - struct_ser.serialize_field("aggrFunction", &v)?; + if self.aggregate_function.is_some() { + len += 1; } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; if !self.expr.is_empty() { struct_ser.serialize_field("expr", &self.expr)?; } if self.distinct { struct_ser.serialize_field("distinct", &self.distinct)?; } + if let Some(v) = self.aggregate_function.as_ref() { + match v { + physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { + let v = AggregateFunction::from_i32(*v) + .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("aggrFunction", &v)?; + } + physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { + struct_ser.serialize_field("userDefinedAggrFunction", v)?; + } + } + } struct_ser.end() } } @@ -12692,17 +12699,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "aggr_function", - "aggrFunction", "expr", "distinct", + "aggr_function", + "aggrFunction", + "user_defined_aggr_function", + "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - AggrFunction, Expr, Distinct, + AggrFunction, + UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> @@ -12724,9 +12734,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { E: serde::de::Error, { match value { - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "expr" => Ok(GeneratedField::Expr), "distinct" => Ok(GeneratedField::Distinct), + "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), + "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12746,17 +12757,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { where V: serde::de::MapAccess<'de>, { - let mut aggr_function__ = None; let mut expr__ = None; let mut distinct__ = None; + let mut aggregate_function__ = None; while let Some(k) = map.next_key()? { match k { - GeneratedField::AggrFunction => { - if aggr_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggr_function__ = Some(map.next_value::<AggregateFunction>()? as i32); - } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); @@ -12769,12 +12774,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { } distinct__ = Some(map.next_value()?); } + GeneratedField::AggrFunction => { + if aggregate_function__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrFunction")); + } + aggregate_function__ = map.next_value::<::std::option::Option<AggregateFunction>>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); + } + GeneratedField::UserDefinedAggrFunction => { + if aggregate_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); + } + aggregate_function__ = map.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); + } } } Ok(PhysicalAggregateExprNode { - aggr_function: aggr_function__.unwrap_or_default(), expr: expr__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), + aggregate_function: aggregate_function__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8ec16070ee..fd3cdc1292 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1487,12 +1487,25 @@ pub struct PhysicalScalarUdfNode { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalAggregateExprNode { - #[prost(enumeration = "AggregateFunction", tag = "1")] - pub aggr_function: i32, #[prost(message, repeated, tag = "2")] pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>, #[prost(bool, tag = "3")] pub distinct: bool, + #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] + pub aggregate_function: ::core::option::Option< + physical_aggregate_expr_node::AggregateFunction, + >, +} +/// Nested message and enum types in `PhysicalAggregateExprNode`. +pub mod physical_aggregate_expr_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum AggregateFunction { + #[prost(enumeration = "super::AggregateFunction", tag = "1")] + AggrFunction(i32), + #[prost(string, tag = "4")] + UserDefinedAggrFunction(::prost::alloc::string::String), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 381073dec0..2c35428e86 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -45,7 +45,7 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge use datafusion::physical_plan::union::UnionExec; use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ - AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; use datafusion_common::{DataFusionError, Result}; use prost::bytes::BufMut; @@ -56,6 +56,7 @@ use crate::common::{csv_delimiter_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_protobuf_file_scan_config, }; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; @@ -427,29 +428,38 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { - let aggr_function = - protobuf::AggregateFunction::from_i32( - agg_node.aggr_function, - ) - .ok_or_else( - || { - proto_error(format!( - "Received an unknown aggregate function: {}", - agg_node.aggr_function - )) - }, - )?; - let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter() .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); - Ok(create_aggregate_expr( - &aggr_function.into(), - agg_node.distinct, - input_phy_expr.as_slice(), - &physical_schema, - name.to_string(), - )?) + agg_node.aggregate_function.as_ref().map(|func| { + match func { + AggregateFunction::AggrFunction(i) => { + let aggr_function = protobuf::AggregateFunction::from_i32(*i) + .ok_or_else( + || { + proto_error(format!( + "Received an unknown aggregate function: {}", + i + )) + }, + )?; + + create_aggregate_expr( + &aggr_function.into(), + agg_node.distinct, + input_phy_expr.as_slice(), + &physical_schema, + name.to_string(), + ) + } + AggregateFunction::UserDefinedAggrFunction(udaf_name) => { + let agg_udf = registry.udaf(udaf_name)?; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + } + } + }).transpose()?.ok_or_else(|| { + proto_error("Invalid AggregateExpr, missing aggregate_function") + }) } _ => Err(DataFusionError::Internal( "Invalid aggregate expression for AggregateExec" @@ -1238,9 +1248,9 @@ mod roundtrip_tests { use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_plan::aggregates::PhysicalGroupBy; use datafusion::physical_plan::expressions::{like, BinaryExpr, GetIndexedFieldExpr}; - use datafusion::physical_plan::functions; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::projection::ProjectionExec; + use datafusion::physical_plan::{functions, udaf}; use datafusion::{ arrow::{ compute::kernels::sort::SortOptions, @@ -1264,6 +1274,10 @@ mod roundtrip_tests { scalar::ScalarValue, }; use datafusion_common::Result; + use datafusion_expr::{ + Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, + Signature, StateTypeFunction, + }; fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> { let ctx = SessionContext::new(); @@ -1419,6 +1433,77 @@ mod roundtrip_tests { )?)) } + #[test] + fn roundtrip_aggregate_udaf() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + #[derive(Debug)] + struct Example; + impl Accumulator for Example { + fn state(&self) -> Result<Vec<ScalarValue>> { + Ok(vec![ScalarValue::Int64(Some(0))]) + } + + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn evaluate(&self) -> Result<ScalarValue> { + Ok(ScalarValue::Int64(Some(0))) + } + + fn size(&self) -> usize { + 0 + } + } + + let rt_func: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Int64))); + let accumulator: AccumulatorFunctionImplementation = + Arc::new(|_| Ok(Box::new(Example))); + let st_func: StateTypeFunction = + Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64]))); + + let udaf = AggregateUDF::new( + "example", + &Signature::exact(vec![DataType::Int64], Volatility::Immutable), + &rt_func, + &accumulator, + &st_func, + ); + + let ctx = SessionContext::new(); + ctx.register_udaf(udaf.clone()); + + let groups: Vec<(Arc<dyn PhysicalExpr>, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![udaf::create_aggregate_expr( + &udaf, + &[col("b", &schema)?], + &schema, + "example_agg", + )?]; + + roundtrip_test_with_context( + Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + Arc::new(EmptyExec::new(false, schema.clone())), + schema, + )?), + ctx, + ) + } + #[test] fn roundtrip_filter_with_not_and_in_list() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e18932575c..9495c841be 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -42,11 +42,12 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; use crate::protobuf; -use crate::protobuf::{PhysicalSortExprNode, ScalarValue}; +use crate::protobuf::{physical_aggregate_expr_node, PhysicalSortExprNode, ScalarValue}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr}; use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_plan::joins::utils::JoinSide; +use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion_common::{DataFusionError, Result}; impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode { @@ -56,6 +57,12 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode { use datafusion::physical_plan::expressions; use protobuf::AggregateFunction; + let expressions: Vec<protobuf::PhysicalExprNode> = a + .expressions() + .iter() + .map(|e| e.clone().try_into()) + .collect::<Result<Vec<_>>>()?; + let mut distinct = false; let aggr_function = if a.as_any().downcast_ref::<Avg>().is_some() { Ok(AggregateFunction::Avg.into()) @@ -131,19 +138,31 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode { { Ok(AggregateFunction::ApproxMedian.into()) } else { + if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() { + return Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( + protobuf::PhysicalAggregateExprNode { + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), + expr: expressions, + distinct, + }, + )), + }); + } + Err(DataFusionError::NotImplemented(format!( "Aggregate function not supported: {a:?}" ))) }?; - let expressions: Vec<protobuf::PhysicalExprNode> = a - .expressions() - .iter() - .map(|e| e.clone().try_into()) - .collect::<Result<Vec<_>>>()?; + Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggr_function, + aggregate_function: Some( + physical_aggregate_expr_node::AggregateFunction::AggrFunction( + aggr_function, + ), + ), expr: expressions, distinct, },
