liukun4515 commented on code in PR #3768:
URL: https://github.com/apache/arrow-datafusion/pull/3768#discussion_r992870967
##########
datafusion/optimizer/src/type_coercion.rs:
##########
@@ -596,6 +659,123 @@ mod test {
Ok(())
}
+ #[test]
+ fn agg_udaf() -> Result<()> {
+ let empty = empty();
+ let my_avg = create_udaf(
+ "MY_AVG",
+ DataType::Float64,
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ Arc::new(|_|
Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+ Arc::new(vec![DataType::UInt64, DataType::Float64]),
+ );
+ let udaf = Expr::AggregateUDF {
+ fun: Arc::new(my_avg),
+ args: vec![lit(10i64)],
+ filter: None,
+ };
+ let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty, None)?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config)?;
+ assert_eq!(
+ "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation",
+ &format!("{:?}", plan)
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn agg_udaf_invalid_input() -> Result<()> {
+ let empty = empty();
+ let return_type: ReturnTypeFunction =
+ Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
+ let state_type: StateTypeFunction =
+ Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64,
DataType::Float64])));
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(|_|
Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
+ let my_avg = AggregateUDF::new(
+ "MY_AVG",
+ &Signature::uniform(1, vec![DataType::Float64],
Volatility::Immutable),
+ &return_type,
+ &accumulator,
+ &state_type,
+ );
+ let udaf = Expr::AggregateUDF {
+ fun: Arc::new(my_avg),
+ args: vec![lit("10")],
+ filter: None,
+ };
+ let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty, None)?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config);
+ assert!(plan.is_err());
+ assert_eq!(
+ "Plan(\"Coercion from [Utf8] to the signature Uniform(1,
[Float64]) failed.\")",
+ &format!("{:?}", plan.err().unwrap())
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn agg_function_case() -> Result<()> {
+ let empty = empty();
+ let fun: AggregateFunction = AggregateFunction::Avg;
+ let agg_expr = Expr::AggregateFunction {
+ fun,
+ args: vec![lit(12i64)],
+ distinct: false,
+ filter: None,
+ };
+ let plan =
+ LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty,
None)?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config)?;
+ assert_eq!(
+ "Projection: AVG(Int64(12))\n EmptyRelation",
Review Comment:
Yes, you're not missing anything.
You can take a look at the `type_coercion::aggregates::coerce_types` function, which
just checks the input data types and doesn't perform any coercion for the function.
That is why the expected plan keeps `AVG(Int64(12))` without a `CAST`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]