[ https://issues.apache.org/jira/browse/SPARK-35184?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17337027#comment-17337027 ]
Hyukjin Kwon commented on SPARK-35184:
--------------------------------------

[~xiaoking] Spark 2.4's last release is out, and there will be no more 2.4 releases. Please upgrade your Spark version.

> Filtering a dataframe after groupBy and user-define-aggregate-function in Pyspark will cause java.lang.UnsupportedOperationException
> -------------------------------------------------------------------------------------------------------------------------------------
>
>                 Key: SPARK-35184
>                 URL: https://issues.apache.org/jira/browse/SPARK-35184
>             Project: Spark
>          Issue Type: Bug
>          Components: Optimizer
>    Affects Versions: 2.4.0
>            Reporter: Xiao Jin
>            Priority: Major
>
> I found a strange error while writing a PySpark UDAF. After calling the groupBy and agg functions, I want to filter some rows from the resulting dataframe, but it does not work. My sample code is below.
> {code:java}
> >>> from pyspark.sql.functions import pandas_udf, PandasUDFType, col
> >>> df = spark.createDataFrame(
> ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
> ...     ("id", "v"))
> >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
> ... def mean_udf(v):
> ...     return v.mean()
> >>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).show()
> {code}
> The code above raises the exception printed below:
> {code:java}
> Traceback (most recent call last):
>   File "<stdin>", line 1, in <module>
>   File "/opt/spark/python/pyspark/sql/dataframe.py", line 378, in show
>     print(self._jdf.showString(n, 20, vertical))
>   File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
>   File "/opt/spark/python/pyspark/sql/utils.py", line 63, in deco
>     return f(*a, **kw)
>   File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
> py4j.protocol.Py4JJavaError: An error occurred while calling o3717.showString.
> : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:
> Exchange hashpartitioning(id#1726L, 200)
> +- *(1) Filter (mean_udf(v#1727) > 5.0)
>    +- Scan ExistingRDD[id#1726L,v#1727]
>     at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:119)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:391)
>     at org.apache.spark.sql.execution.SortExec.inputRDDs(SortExec.scala:121)
>     at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:627)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.python.AggregateInPandasExec.doExecute(AggregateInPandasExec.scala:80)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:247)
>     at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:339)
>     at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
>     at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
>     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
>     at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
>     at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
>     at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
>     at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
>     at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
>     at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
>     at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
>     at sun.reflect.GeneratedMethodAccessor139.invoke(Unknown Source)
>     at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
>     at java.lang.reflect.Method.invoke(Method.java:498)
>     at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
>     at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
>     at py4j.Gateway.invoke(Gateway.java:282)
>     at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
>     at py4j.commands.CallCommand.execute(CallCommand.java:79)
>     at py4j.GatewayConnection.run(GatewayConnection.java:238)
>     at java.lang.Thread.run(Thread.java:748)
> Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: mean_udf(input[1, double, true])
>     at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
>     at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
>     at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
>     at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
>     at scala.Option.getOrElse(Option.scala:121)
>     at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
>     at org.apache.spark.sql.catalyst.expressions.BinaryExpression.nullSafeCodeGen(Expression.scala:525)
>     at org.apache.spark.sql.catalyst.expressions.BinaryExpression.defineCodeGen(Expression.scala:508)
>     at org.apache.spark.sql.catalyst.expressions.BinaryComparison.doGenCode(predicates.scala:563)
>     at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
>     at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
>     at scala.Option.getOrElse(Option.scala:121)
>     at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
>     at org.apache.spark.sql.execution.FilterExec.org$apache$spark$sql$execution$FilterExec$$genPredicate$1(basicPhysicalOperators.scala:139)
>     at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:179)
>     at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:163)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.immutable.List.foreach(List.scala:392)
>     at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
>     at scala.collection.immutable.List.map(List.scala:296)
>     at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:163)
>     at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:189)
>     at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:374)
>     at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:403)
>     at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
>     at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
>     at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:374)
>     at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:125)
>     at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
>     at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
>     at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:85)
>     at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:544)
>     at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:598)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
>     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
>     at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
>     at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:92)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:128)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119)
>     at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
>     ... 48 more
> {code}
> The Optimized Logical Plan is shown below. The optimizer has already pushed the Filter down through the PushDownPredicates rule:
> {code:java}
> >>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
> == Parsed Logical Plan ==
> 'Filter ('mean > 5)
> +- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
>    +- LogicalRDD [id#0L, v#1], false
>
> == Analyzed Logical Plan ==
> id: bigint, mean: double
> Filter (mean#79 > cast(5 as double))
> +- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
>    +- LogicalRDD [id#0L, v#1], false
>
> == Optimized Logical Plan ==
> Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
> +- Filter (mean_udf(v#1) > 5.0)
>    +- LogicalRDD [id#0L, v#1], false
>
> == Physical Plan ==
> !AggregateInPandas [id#0L], [mean_udf(v#1)], [id#0L, mean_udf(v)#78 AS mean#79]
> +- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
>    +- Exchange hashpartitioning(id#0L, 200)
>       +- *(1) Filter (mean_udf(v#1) > 5.0)
>          +- Scan ExistingRDD[id#0L,v#1]
> {code}
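> A possible mitigation while staying on 2.4 (a sketch only, not verified here) is to exclude the predicate-pushdown rule for the session via spark.sql.optimizer.excludedRules, so the Filter is kept above the aggregate. The exact rule name below is my assumption for the 2.4 optimizer, and the same conf can also be set from PySpark with spark.conf.set:
> {code:java}
> // Sketch only (untested): keep the optimizer from pushing the filter below
> // the pandas UDAF aggregate by excluding the pushdown rule for this session.
> // The rule name is an assumption for the Spark 2.4 optimizer.
> import org.apache.spark.sql.SparkSession
>
> val spark = SparkSession.builder().getOrCreate()
> spark.conf.set(
>   "spark.sql.optimizer.excludedRules",
>   "org.apache.spark.sql.catalyst.optimizer.PushDownPredicate")
> {code}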
> Compare with the official mean function: there, the Filter node is not pushed down through the PushDownPredicates rule.
> {code:java}
> >>> from pyspark.sql import functions as F
> >>> df.groupby("id").agg(F.mean(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
> == Parsed Logical Plan ==
> 'Filter ('mean > 5)
> +- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
>    +- LogicalRDD [id#0L, v#1], false
>
> == Analyzed Logical Plan ==
> id: bigint, mean: double
> Filter (mean#7 > cast(5 as double))
> +- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
>    +- LogicalRDD [id#0L, v#1], false
>
> == Optimized Logical Plan ==
> Filter (isnotnull(mean#7) && (mean#7 > 5.0))
> +- Aggregate [id#0L], [id#0L, avg(v#1) AS mean#7]
>    +- LogicalRDD [id#0L, v#1], false
>
> == Physical Plan ==
> *(2) Filter (isnotnull(mean#7) && (mean#7 > 5.0))
> +- *(2) HashAggregate(keys=[id#0L], functions=[avg(v#1)], output=[id#0L, mean#7])
>    +- Exchange hashpartitioning(id#0L, 200)
>       +- *(1) HashAggregate(keys=[id#0L], functions=[partial_avg(v#1)], output=[id#0L, sum#15, count#16L])
>          +- Scan ExistingRDD[id#0L,v#1]
> {code}
>
> Update:
> I have deleted my earlier, confusing guess. I think I found the reason.
> In branch 2.4, a wrong aliasMap is created because the pandas aggregate function is not filtered out. See the code here:
> {code:java}
>       case filter @ Filter(condition, aggregate: Aggregate)
>         if aggregate.aggregateExpressions.forall(_.deterministic)
>           && aggregate.groupingExpressions.nonEmpty =>
>         // Find all the aliased expressions in the aggregate list that don't include any actual
>         // AggregateExpression, and create a map from the alias to the expression
>         val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
>           case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
>             (a.toAttribute, a.child)
>         })
>         ......
> {code}
> But in the master branch, this has been corrected by using the getAliasMap function in AliasHelper.scala:
> {code:java}
>   protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
>     // Find all the aliased expressions in the aggregate list that don't include any actual
>     // AggregateExpression or PythonUDF, and create a map from the alias to the expression
>     val aliasMap = plan.aggregateExpressions.collect {
>       case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
>         PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
>         (a.toAttribute, a)
>     }
>     AttributeMap(aliasMap)
>   }
> {code}
> So branch 2.4 does not filter out all of the aggregate functions when building the alias map.
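> For reference, here is a minimal sketch (mine, not the actual branch-2.4 source) of how the collect inside the 2.4 Filter-over-Aggregate case could look if it also skipped grouped aggregate pandas UDFs, mirroring master's getAliasMap. It assumes the PythonUDF.isGroupedAggPandasUDF helper shown above is available on 2.4:
> {code:java}
> // Sketch only: the 2.4-style alias collection from the case clause above,
> // adjusted to also skip grouped-aggregate pandas UDFs the way master's
> // AliasHelper.getAliasMap does, so the pushed-down Filter is never rewritten
> // in terms of the pandas UDF. In the real rule these imports would live at
> // the top of Optimizer.scala.
> import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, PythonUDF}
> import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
>
> val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
>   case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
>       PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
>     (a.toAttribute, a.child)
> })
> {code}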