[ https://issues.apache.org/jira/browse/SPARK-42851?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17702658#comment-17702658 ]

Apache Spark commented on SPARK-42851:
--------------------------------------

User 'peter-toth' has created a pull request for this issue:
https://github.com/apache/spark/pull/40488

> EquivalentExpressions methods need to be consistently guarded by supportedExpression
> -------------------------------------------------------------------------------------
>
>                 Key: SPARK-42851
>                 URL: https://issues.apache.org/jira/browse/SPARK-42851
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 3.3.2, 3.4.0
>            Reporter: Kris Mok
>            Priority: Major
>
> SPARK-41468 tried to fix a bug but introduced a new regression. Its change to 
> {{EquivalentExpressions}} added a {{supportedExpression()}} guard to the 
> {{addExprTree()}} and {{getExprState()}} methods, but didn't add the same 
> guard to the other "add" entry point -- {{addExpr()}}.
> As such, callers that add single expressions for CSE via {{addExpr()}} may succeed, but 
> retrieval via {{getExprState()}} then inconsistently returns {{None}}, because the 
> expression fails the guard.
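> To see the inconsistency in isolation, here is a minimal standalone Scala model of the guard 
> asymmetry (illustrative only, not Spark's actual {{EquivalentExpressions}} code; {{supported}} 
> stands in for {{supportedExpression()}}):
> {code:scala}
> import scala.collection.mutable
> 
> // Toy model: "add" is unguarded (like addExpr() before the fix), while
> // "get" is guarded (like getExprState() after SPARK-41468).
> object GuardAsymmetryDemo {
>   private val counts = mutable.Map.empty[String, Int]
> 
>   // Stand-in for supportedExpression(): reject lambda-containing expressions.
>   private def supported(expr: String): Boolean = !expr.contains("lambda")
> 
>   // Unguarded add: returns true if a matching expression was already added.
>   def addExpr(expr: String): Boolean = {
>     val seen = counts.contains(expr)
>     counts(expr) = counts.getOrElse(expr, 0) + 1
>     seen
>   }
> 
>   // Guarded get: hides state that addExpr() happily recorded.
>   def getExprState(expr: String): Option[Int] =
>     if (supported(expr)) counts.get(expr) else None
> 
>   def main(args: Array[String]): Unit = {
>     val e = "max(transform(array(id), lambda x -> x))"
>     println(addExpr(e))      // false: first occurrence, state recorded
>     println(addExpr(e))      // true: recognized as a duplicate
>     println(getExprState(e)) // None: the guard makes the recorded state invisible
>   }
> }
> {code}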
> We need to make sure the "add" and "get" methods are consistent. It could be 
> done by one of:
> 1. Adding the same {{supportedExpression()}} guard to {{addExpr()}}, or
> 2. Removing the guard from {{getExprState()}}, relying solely on the guard on the "add" path 
> to ensure that only intended state is added (see the sketch below).
> (or some other refactoring that fuses the guard into the various methods more efficiently)
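> For reference, method 2 would be roughly the mirror image of the method-1 diff shown further 
> below -- removing the guard from {{getExprState()}} instead. A sketch, assuming the 
> guarded-lookup shape that SPARK-41468 gave this method:
> {code:diff}
>    def getExprState(e: Expression): Option[ExpressionStats] = {
> -    if (supportedExpression(e)) {
> -      equivalenceMap.get(ExpressionEquals(e))
> -    } else {
> -      None
> -    }
> +    equivalenceMap.get(ExpressionEquals(e))
>    }
> {code}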
> There are pros and cons to the two directions above: because {{addExpr()}} used to allow more 
> expressions (potentially incorrectly) to get CSE'd, making it more restrictive may cause 
> performance regressions for the cases that happened to work.
> Example:
> {code:sql}
> select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) 
> from range(2)
> {code}
> Running this query on the Spark 3.2 branch returns the correct value:
> {code}
> scala> spark.sql("select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)").collect
> res0: Array[org.apache.spark.sql.Row] = Array([WrappedArray(1),WrappedArray(1)])
> {code}
> Here, {{max(transform(array(id), x -> x))}} is an {{AggregateExpression}} that {{addExpr()}} 
> (potentially unsafely) recognized as a common subexpression, and on 3.2 {{getExprState()}} does 
> no extra guarding, so during physical planning {{PhysicalAggregation}} CSE's this expression in 
> both the aggregate expression list and the result expression list (the relevant code pattern is 
> sketched after the plan below):
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
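> The relevant dedup/rewrite pair in {{PhysicalAggregation}} looks roughly like this (paraphrased, 
> not the exact source; {{resultExpressions}} is the logical {{Aggregate}}'s output list). Note how 
> collection goes through {{addExpr()}} while rewriting goes through {{getExprState()}}:
> {code:scala}
> val equivalentAggregateExpressions = new EquivalentExpressions
> 
> // 1. Collect distinct aggregate expressions: addExpr() returns true for a
> //    duplicate, so only the first occurrence is kept.
> val aggregateExpressions = resultExpressions.flatMap { expr =>
>   expr.collect {
>     case a: AggregateExpression if !equivalentAggregateExpressions.addExpr(a) => a
>   }
> }
> 
> // 2. Rewrite result expressions to refer to the deduplicated aggregate's
> //    output attribute. If getExprState() returns None here (the guard!), the
> //    un-deduplicated instance leaks through, and its attribute is missing
> //    from the child operator's output -- the binding error shown below.
> val rewrittenResultExpressions = resultExpressions.map { expr =>
>   expr.transformDown {
>     case ae: AggregateExpression =>
>       equivalentAggregateExpressions.getExprState(ae).map(_.expr)
>         .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
>   }
> }
> {code}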
> Running the same query on current master triggers an error when binding the 
> result expression to the aggregate expression in the Aggregate operators (for 
> a WSCG-enabled operator like {{HashAggregateExec}}, the same error would show 
> up during codegen):
> {code}
> ERROR TaskSetManager: Task 0 in stage 2.0 failed 1 times; aborting job
> org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 16) (ip-10-110-16-93.us-west-2.compute.internal executor driver): java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3]
>       at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
>       at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:512)
>       at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:512)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:517)
>       at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1249)
>       at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1248)
>       at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:532)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:517)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:488)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:456)
>       at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
>       at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
>       at scala.collection.immutable.List.map(List.scala:297)
>       at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
>       at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
>       at org.apache.spark.sql.execution.aggregate.AggregationIterator.generateResultProjection(AggregationIterator.scala:246)
>       at org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:296)
>       at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.<init>(SortBasedAggregationIterator.scala:49)
>       at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1(SortAggregateExec.scala:79)
>       at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1$adapted(SortAggregateExec.scala:59)
> ...
> {code}
> Note that the aggregate expressions are deduplicated in {{PhysicalAggregation}}, but the result 
> expressions could not be deduplicated consistently, due to the bug described in this ticket:
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=38]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
> Fixing it via method 1 is more correct than method 2 in terms of avoiding potentially incorrect 
> CSE. With the following change:
> {code:diff}
> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> index 330d66a21b..12def60042 100644
> --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
> @@ -40,7 +40,11 @@ class EquivalentExpressions {
>     * Returns true if there was already a matching expression.
>     */
>    def addExpr(expr: Expression): Boolean = {
> -    updateExprInMap(expr, equivalenceMap)
> +    if (supportedExpression(expr)) {
> +      updateExprInMap(expr, equivalenceMap)
> +    } else {
> +      false
> +    }
>    }
>  
>    /**
> {code}
> the query runs correctly again, but this time the aggregate expression is no longer CSE'd, 
> consistently for both the aggregate expression list and the result expression list:
> {code}
> AdaptiveSparkPlan isFinalPlan=false
> +- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
>    +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
>       +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), partial_max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
>          +- Range (0, 2, step=1, splits=16)
> {code}
> And for this particular case, the CSE that used to take place was actually safe, so losing it 
> here means a performance regression.


