[ https://issues.apache.org/jira/browse/SPARK-42851?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Wenchen Fan reassigned SPARK-42851:
-----------------------------------

    Assignee: Kris Mok

> EquivalentExpressions methods need to be consistently guarded by supportedExpression
> ------------------------------------------------------------------------------------
>
>                 Key: SPARK-42851
>                 URL: https://issues.apache.org/jira/browse/SPARK-42851
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 3.3.2, 3.4.0
>            Reporter: Kris Mok
>            Assignee: Kris Mok
>            Priority: Major
>
> SPARK-41468 tried to fix a bug but introduced a new regression. Its change to
> {{EquivalentExpressions}} added a {{supportedExpression()}} guard to the
> {{addExprTree()}} and {{getExprState()}} methods, but didn't add the same
> guard to the other "add" entry point, {{addExpr()}}.
> As a result, callers that add single expressions to CSE via {{addExpr()}} may
> succeed, but on retrieval via {{getExprState()}} they inconsistently get a
> {{None}} because the expression fails the guard.
> We need to make sure the "add" and "get" methods are consistent. This could be
> done by one of:
> 1. Adding the same {{supportedExpression()}} guard to {{addExpr()}}, or
> 2. Removing the guard from {{getExprState()}}, relying solely on the guard on
> the "add" path to make sure only intended state is added.
> (or other alternative refactorings that fuse the guard into various methods to
> make it more efficient)
> There are pros and cons to both directions: because {{addExpr()}} used to
> allow more (potentially incorrect) expressions to get CSE'd, making it more
> restrictive may cause performance regressions (for the cases that happened
> to work).
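The add/get mismatch described above can be illustrated with a small standalone sketch. This is a toy model under stated assumptions, not Spark's implementation: the expression types, the {{Guard.supported}} rule (reject any tree that contains a lambda variable), and the counting map are all hypothetical stand-ins for {{supportedExpression()}} and the real equivalence map.

```scala
import scala.collection.mutable

// Toy expression tree; all names here are hypothetical, not Spark classes.
sealed trait Expr
final case class Attr(name: String) extends Expr
final case class LambdaVar(name: String) extends Expr
final case class Call(fn: String, args: List[Expr]) extends Expr

object Guard {
  // Assumed stand-in for supportedExpression(): reject trees containing a lambda variable.
  def supported(e: Expr): Boolean = e match {
    case _: LambdaVar  => false
    case Call(_, args) => args.forall(supported)
    case _: Attr       => true
  }
}

final class ToyEquivalentExpressions {
  private val useCounts = mutable.Map.empty[Expr, Int]

  // Unguarded "add" entry point, modelling addExpr() as described in the ticket.
  // Returns true if an equal expression had already been added.
  def addExpr(e: Expr): Boolean = {
    val seen = useCounts.contains(e)
    useCounts(e) = useCounts.getOrElse(e, 0) + 1
    seen
  }

  // Guarded "get" entry point, modelling getExprState().
  def getExprState(e: Expr): Option[Int] =
    if (Guard.supported(e)) useCounts.get(e) else None
}

object InconsistentGuardDemo {
  // Models max(transform(array(id), x -> x)).
  private val agg: Expr =
    Call("max", List(Call("transform", List(Call("array", List(Attr("id"))), LambdaVar("x")))))

  // Returns (first add saw a duplicate, second add saw a duplicate, state on lookup).
  def run(): (Boolean, Boolean, Option[Int]) = {
    val cse = new ToyEquivalentExpressions
    val first = cse.addExpr(agg)
    val second = cse.addExpr(agg)
    (first, second, cse.getExprState(agg))
  }

  def main(args: Array[String]): Unit = println(run())
}
```

In this model the second {{addExpr}} call reports a duplicate, so a caller would deduplicate the expression, yet the later lookup finds no state for it. That is the same shape of disagreement the ticket describes between the "add" and "get" paths.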
> Example:
> {code:sql}
> select max(transform(array(id), x -> x)), max(transform(array(id), x -> x))
> from range(2)
> {code}
> Running this query on the Spark 3.2 branch returns the correct value:
> {code}
> scala> spark.sql("select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)").collect
> res0: Array[org.apache.spark.sql.Row] = Array([WrappedArray(1),WrappedArray(1)])
> {code}
> Here, {{max(transform(array(id), x -> x))}} is an {{AggregateExpression}} that was
> (potentially unsafely) recognized by {{addExpr()}} as a common subexpression,
> and {{getExprState()}} doesn't do extra guarding, so during physical
> planning, in {{PhysicalAggregation}}, this expression gets CSE'd in both the
> aggregation expression list and the result expressions list.
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
> Running the same query on current master triggers an error when binding the
> result expression to the aggregate expression in the Aggregate operators (for
> a WSCG-enabled operator like {{HashAggregateExec}}, the same error would show
> up during codegen):
> {code}
> ERROR TaskSetManager: Task 0 in stage 2.0 failed 1 times; aborting job
> org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in
> stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0
> (TID 16) (ip-10-110-16-93.us-west-2.compute.internal executor driver):
> java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L),
> lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in
> [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L,
> false)))#3]
>     at
> org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
>     at
> org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:512)
>     at
> org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:512)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:517)
>     at
> org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1249)
>     at
> org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1248)
>     at
> org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:532)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:517)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:488)
>     at
> org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:456)
>     at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
>     at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
>     at scala.collection.immutable.List.map(List.scala:297)
>     at
> org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
>     at
> org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
>     at
> org.apache.spark.sql.execution.aggregate.AggregationIterator.generateResultProjection(AggregationIterator.scala:246)
>     at
> org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:296)
>     at
> org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.<init>(SortBasedAggregationIterator.scala:49)
>     at
> org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1(SortAggregateExec.scala:79)
>     at
> org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1$adapted(SortAggregateExec.scala:59)
>     ...
> {code}
> Note that the aggregate expressions are deduplicated in
> {{PhysicalAggregation}}, but the result expressions could not be
> deduplicated consistently because of the bug described in this ticket.
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=38]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
> Fixing it via method 1 is more correct than method 2 in terms of avoiding
> incorrect CSE:
> {code:diff}
> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> index 330d66a21b..12def60042 100644
> --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> @@ -40,7 +40,11 @@ class EquivalentExpressions {
>     * Returns true if there was already a matching expression.
>     */
>    def addExpr(expr: Expression): Boolean = {
> -    updateExprInMap(expr, equivalenceMap)
> +    if (supportedExpression(expr)) {
> +      updateExprInMap(expr, equivalenceMap)
> +    } else {
> +      false
> +    }
>    }
>
>    /**
> {code}
> With this change the query runs correctly again, but this time the aggregate
> expression is no longer CSE'd, and that holds consistently for both the
> aggregate expressions and the result expressions:
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), partial_max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
> For this particular case, the CSE that used to take place was actually
> okay, so losing CSE here means a performance regression.
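The trade-off between the two fix directions can be sketched with a standalone toy model. Everything below is an illustrative assumption rather than Spark code: {{ToyCse}}, its {{guardAdd}} / {{guardGet}} switches, and the rule that any tree containing a lambda parameter is unsupported are hypothetical, and the model only covers the single-expression "add" path.

```scala
import scala.collection.mutable

// Toy expression tree; hypothetical names, not Spark classes.
sealed trait Node
final case class Column(name: String) extends Node
final case class LambdaParam(name: String) extends Node
final case class Func(name: String, children: List[Node]) extends Node

// guardAdd / guardGet select which entry points apply the (assumed) guard.
final class ToyCse(guardAdd: Boolean, guardGet: Boolean) {
  private val counts = mutable.Map.empty[Node, Int]

  // Assumed guard rule: trees containing a lambda parameter are unsupported.
  private def supported(n: Node): Boolean = n match {
    case _: LambdaParam    => false
    case Func(_, children) => children.forall(supported)
    case _: Column         => true
  }

  // Returns true if an equal expression was already recorded.
  def addExpr(n: Node): Boolean =
    if (guardAdd && !supported(n)) {
      false
    } else {
      val seen = counts.contains(n)
      counts(n) = counts.getOrElse(n, 0) + 1
      seen
    }

  def getExprState(n: Node): Option[Int] =
    if (guardGet && !supported(n)) None else counts.get(n)
}

object GuardPlacementDemo {
  // Models max(transform(array(id), x -> x)).
  private val agg: Node =
    Func("max", List(Func("transform", List(Func("array", List(Column("id"))), LambdaParam("x")))))

  // Adds the same aggregate twice; returns (second add saw a duplicate, state on lookup).
  def run(guardAdd: Boolean, guardGet: Boolean): (Boolean, Option[Int]) = {
    val cse = new ToyCse(guardAdd, guardGet)
    cse.addExpr(agg)
    val duplicate = cse.addExpr(agg)
    (duplicate, cse.getExprState(agg))
  }

  def main(args: Array[String]): Unit = {
    println(s"method 1 (guard add and get): ${run(guardAdd = true, guardGet = true)}")
    println(s"method 2 (no guard on get):   ${run(guardAdd = false, guardGet = false)}")
    println(s"reported bug (guard get only): ${run(guardAdd = false, guardGet = true)}")
  }
}
```

In this model, guarding both paths makes "add" and "get" agree that nothing is shared, which matches the plan above where both {{max(...)}} copies survive. Dropping the guard on "get" keeps the deduplication but accepts expressions the guard would otherwise reject. Guarding only "get" reproduces the disagreement reported in this ticket.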