Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/19872#discussion_r161856927 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala --- @@ -334,34 +339,51 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. - sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") } + groupingExpressions, aggExpressions, resultExpressions, child) => + + if (aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) { - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") } - aggregateOperator + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.AggUtils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.AggUtils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } + + aggregateOperator + } else if (aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF])) { + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + } else { + throw new IllegalArgumentException( --- End diff -- +1. Let's double check in https://github.com/apache/spark/pull/19872#discussion_r161507315
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org