[ https://issues.apache.org/jira/browse/SPARK-35184?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Xiao Jin updated SPARK-35184:
-----------------------------
Description: 
I found a strange error while writing a PySpark UDAF. After calling groupBy and agg, I want to filter some rows out of the resulting dataframe, but the filter does not work. My sample code is below.
{code:java}
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType, col
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
... def mean_udf(v):
...     return v.mean()
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).show()
{code}
The code above raises the exception printed below.
{code:java}
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/spark/python/pyspark/sql/dataframe.py", line 378, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/opt/spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o3717.showString.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:
Exchange hashpartitioning(id#1726L, 200)
+- *(1) Filter (mean_udf(v#1727) > 5.0)
   +- Scan ExistingRDD[id#1726L,v#1727]
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:119)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:391)
    at org.apache.spark.sql.execution.SortExec.inputRDDs(SortExec.scala:121)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:627)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.python.AggregateInPandasExec.doExecute(AggregateInPandasExec.scala:80)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:247)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:339)
    at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
    at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
    at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
    at sun.reflect.GeneratedMethodAccessor139.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: mean_udf(input[1, double, true])
    at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
    at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
    at scala.Option.getOrElse(Option.scala:121)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.nullSafeCodeGen(Expression.scala:525)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.defineCodeGen(Expression.scala:508)
    at org.apache.spark.sql.catalyst.expressions.BinaryComparison.doGenCode(predicates.scala:563)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
    at scala.Option.getOrElse(Option.scala:121)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
    at org.apache.spark.sql.execution.FilterExec.org$apache$spark$sql$execution$FilterExec$$genPredicate$1(basicPhysicalOperators.scala:139)
    at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:179)
    at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:163)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.immutable.List.map(List.scala:296)
    at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:163)
    at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:189)
    at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:374)
    at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:403)
    at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
    at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
    at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:374)
    at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:125)
    at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
    at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
    at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:85)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:544)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:598)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:92)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:128)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    ... 48 more
{code}
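For contrast, the failure seems specific to predicates that reference the UDAF output. A filter that references only the grouping column should work (a hedged sketch on the same dataframe, not part of the original report), presumably because after alias replacement the pushed-down predicate no longer contains the PythonUDF and can be evaluated against the Aggregate's child.
{code:java}
>>> # Filtering on the grouping key instead of the UDAF result: the pushed-down
>>> # predicate (id > 1) contains no PythonUDF, so no exception is expected.
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("id") > 1).show()
{code}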
Looking at the Optimized Logical Plan for the failing query, I found the Optimizer had already pushed the Filter down through the PushDownPredicates rule.
{code:java}
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
== Parsed Logical Plan ==
'Filter ('mean > 5)
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false
== Analyzed Logical Plan ==
id: bigint, mean: double
Filter (mean#79 > cast(5 as double))
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false
== Optimized Logical Plan ==
Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
+- Filter (mean_udf(v#1) > 5.0)
   +- LogicalRDD [id#0L, v#1], false
== Physical Plan ==
!AggregateInPandas [id#0L], [mean_udf(v#1)], [id#0L, mean_udf(v)#78 AS mean#79]
+- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#0L, 200)
      +- *(1) Filter (mean_udf(v#1) > 5.0)
         +- Scan ExistingRDD[id#0L,v#1]
{code}
Compared with the built-in mean function, the Filter node is not pushed down through the PushDownPredicates rule.
{code:java}
>>> from pyspark.sql import functions as F
>>> df.groupby("id").agg(F.mean(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
== Parsed Logical Plan ==
'Filter ('mean > 5)
+- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
   +- LogicalRDD [id#0L, v#1], false
== Analyzed Logical Plan ==
id: bigint, mean: double
Filter (mean#7 > cast(5 as double))
+- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
   +- LogicalRDD [id#0L, v#1], false
== Optimized Logical Plan ==
Filter (isnotnull(mean#7) && (mean#7 > 5.0))
+- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
   +- LogicalRDD [id#0L, v#1], false
== Physical Plan ==
*(2) Filter (isnotnull(mean#7) && (mean#7 > 5.0))
+- *(2) HashAggregate(keys=[id#0L], functions=[avg(v#1)], output=[id#0L, mean#7])
   +- Exchange hashpartitioning(id#0L, 200)
      +- *(1) HashAggregate(keys=[id#0L], functions=[partial_avg(v#1)], output=[id#0L, sum#15, count#16L])
         +- Scan ExistingRDD[id#0L,v#1]
{code}
The code in the PushPredicateThroughNonJoin rule that matches our case is below.
{code:java}
case filter @ Filter(condition, aggregate: Aggregate)
    if aggregate.aggregateExpressions.forall(_.deterministic)
      && aggregate.groupingExpressions.nonEmpty =>
  val aliasMap = getAliasMap(aggregate)

  // For each filter, expand the alias and check if the filter can be evaluated using
  // attributes produced by the aggregate operator's child operator.
  val (candidates, nonDeterministic) =
    splitConjunctivePredicates(condition).partition(_.deterministic)

  val (pushDown, rest) = candidates.partition { cond =>
    val replaced = replaceAlias(cond, aliasMap)
    cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
  }

  val stayUp = rest ++ nonDeterministic

  if (pushDown.nonEmpty) {
    val pushDownPredicate = pushDown.reduce(And)
    val replaced = replaceAlias(pushDownPredicate, aliasMap)
    val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
    // If there is no more filter to stay up, just eliminate the filter.
    // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
    if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
  } else {
    filter
  }
{code}
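To make the subset test concrete, here is a toy model of the partition check above (plain Python of my own, for illustration only, not Spark code). It mirrors cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) for the two queries, following the reading of PythonUDF and Average that I lay out below.
{code:java}
# Toy model (illustration only): the Aggregate child outputs {id, v}.
agg_child_output = {"id", "v"}

# Pandas UDAF case: after replaceAlias, mean#79 > 5 becomes mean_udf(v#1) > 5.0,
# and PythonUDF.references collapse to the references of its input, i.e. {v}.
replaced_refs_pandas_udaf = {"v"}

# Built-in avg case: the references of Average are its aggregation buffer
# attributes (sum, count), which are not produced by the child.
replaced_refs_builtin_avg = {"sum", "count"}

def can_push_down(replaced_refs, child_output):
    # cond.references.nonEmpty && replaced.references.subsetOf(child.outputSet)
    return bool(replaced_refs) and replaced_refs.issubset(child_output)

print(can_push_down(replaced_refs_pandas_udaf, agg_child_output))  # True  -> Filter is pushed below the Aggregate
print(can_push_down(replaced_refs_builtin_avg, agg_child_output))  # False -> Filter stays above the Aggregate
{code}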
It's easy to infer that, when I use a Python UDAF, the references of the condition end up being a subset of the output of the Aggregate node's child, because a Python UDAF in Catalyst is actually a PythonUDF expression, whose references are just the references of its input expressions.
{code:java}
case class PythonUDF(
    name: String,
    func: PythonFunction,
    dataType: DataType,
    children: Seq[Expression],
    evalType: Int,
    udfDeterministic: Boolean,
    resultId: ExprId = NamedExpression.newExprId)
{code}
But the official mean function in Catalyst is the Average expression, which is a DeclarativeAggregate with multiple aggBufferAttributes, meaning the references of Average are its sum (of sumDataType) and count (of LongType) buffer attributes.
{code:java}
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
  ...
  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil
  ...
}
{code}
{code:java}
case class AggregateExpression(
    aggregateFunction: AggregateFunction,
    mode: AggregateMode,
    isDistinct: Boolean,
    filter: Option[Expression],
    resultId: ExprId)
  extends Expression
  with Unevaluable {
  ...
  @transient
  override lazy val references: AttributeSet = {
    val aggAttributes = mode match {
      case Partial | Complete => aggregateFunction.references
      case PartialMerge | Final => AttributeSet(aggregateFunction.inputAggBufferAttributes)
    }
    aggAttributes ++ filterAttributes
  }
  ...
}
{code}
So the references of the PythonUDF are a subset of the Aggregate child's output, while those of Average are not.
I think the root cause of the problem is that Catalyst does not treat the Pandas UDAF as a real AggregateFunction, so the Pandas UDAF gets optimized like a normal UDF. Maybe it's time to redesign the definition of the Pandas UDAF so it can get on the right track?
PS: All the speculation above is only a guess.
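Until this is fixed, one possible workaround (an untested sketch; the rule class name is my assumption and differs across versions) might be to exclude the predicate-pushdown rule via spark.sql.optimizer.excludedRules, so the Filter stays above the Aggregate. Note this disables predicate pushdown for every query in the session, which can hurt performance elsewhere.
{code:java}
>>> # Untested workaround sketch: exclude the pushdown rule for this session. The
>>> # class is named PushDownPredicate in Spark 2.4 and PushDownPredicates in 3.x
>>> # (assumption; check the Optimizer source for your version).
>>> spark.conf.set(
...     "spark.sql.optimizer.excludedRules",
...     "org.apache.spark.sql.catalyst.optimizer.PushDownPredicate")
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).show()
{code}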
> Filtering a dataframe after groupBy and user-define-aggregate-function in
> Pyspark will cause java.lang.UnsupportedOperationException
> ------------------------------------------------------------------------------------------------------------------------------------
>
>                 Key: SPARK-35184
>                 URL: https://issues.apache.org/jira/browse/SPARK-35184
>             Project: Spark
>          Issue Type: Bug
>          Components: Optimizer
>    Affects Versions: 2.4.0
>            Reporter: Xiao Jin
>            Priority: Major