Kris Mok created SPARK-42851:
--------------------------------

             Summary: EquivalentExpressions methods need to be consistently guarded by supportedExpression
                 Key: SPARK-42851
                 URL: https://issues.apache.org/jira/browse/SPARK-42851
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.3.2, 3.4.0
            Reporter: Kris Mok


SPARK-41468 tried to fix a bug but introduced a new regression. Its change to 
{{EquivalentExpressions}} added a {{supportedExpression()}} guard to the 
{{addExprTree()}} and {{getExprState()}} methods, but didn't add the same guard 
to the other "add" entry point -- {{addExpr()}}.

As a result, callers that add a single expression to CSE via {{addExpr()}} may 
succeed, but a later lookup of the same expression via {{getExprState()}} 
returns {{None}} because the expression fails the guard.
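
The mismatch can be illustrated with a toy model. This is only a sketch, not the real {{EquivalentExpressions}} code: {{ToyExpr}}, its {{supported}} flag and {{ToyEquivalentExpressions}} are hypothetical names, and the map keyed by a string stands in for the real equivalence map.
{code:scala}
import scala.collection.mutable

object GuardMismatchDemo {
  // Hypothetical stand-in for a Catalyst expression. `supported` plays the
  // role of whatever supportedExpression() would decide for it.
  final case class ToyExpr(sql: String, supported: Boolean)

  class ToyEquivalentExpressions {
    private val seen = mutable.HashMap.empty[String, ToyExpr]

    private def supportedExpression(e: ToyExpr): Boolean = e.supported

    // Unguarded "add": returns true if an equivalent expression was already recorded.
    def addExpr(e: ToyExpr): Boolean = {
      val existed = seen.contains(e.sql)
      if (!existed) seen(e.sql) = e
      existed
    }

    // Guarded "get": refuses to look up anything the guard rejects.
    def getExprState(e: ToyExpr): Option[ToyExpr] =
      if (supportedExpression(e)) seen.get(e.sql) else None
  }

  def main(args: Array[String]): Unit = {
    val cse = new ToyEquivalentExpressions
    // Assumption: this expression is one that the guard rejects.
    val e = ToyExpr("max(transform(array(id), x -> x))", supported = false)
    println(cse.addExpr(e))      // false: the first occurrence is recorded
    println(cse.addExpr(e))      // true: the second occurrence counts as a duplicate
    println(cse.getExprState(e)) // None: the guarded lookup cannot see the entry
  }
}
{code}
The caller is told the expression is a duplicate, yet cannot retrieve the entry it is a duplicate of.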

We need to make sure the "add" and "get" methods are consistent. It could be 
done by one of:
1. Adding the same {{supportedExpression()}} guard to {{addExpr()}}, or
2. Removing the guard from {{getExprState()}}, relying solely on the guard on 
the "add" path to make sure only intended state is added.
(Alternatively, other refactorings could fuse the guard into the individual 
methods to make it more efficient.)

Both directions have trade-offs. {{addExpr()}} used to allow more expressions 
to be CSE'd, some of them potentially incorrectly, so making it more 
restrictive may cause performance regressions for the cases that happened to 
work.
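
The two directions can be compared in a toy model. This is a sketch under stated assumptions, not Spark code: {{ToyExpr}}, {{ToyEquivalentExpressions}} and the {{guardAdd}} / {{guardGet}} switches are hypothetical and exist only to contrast the behaviours.
{code:scala}
import scala.collection.mutable

object GuardOptionsDemo {
  // Hypothetical stand-in for an expression; `supported` mimics the guard's verdict.
  final case class ToyExpr(sql: String, supported: Boolean)

  // guardAdd / guardGet are illustrative switches, not real Spark parameters.
  class ToyEquivalentExpressions(guardAdd: Boolean, guardGet: Boolean) {
    private val seen = mutable.HashMap.empty[String, ToyExpr]

    def addExpr(e: ToyExpr): Boolean =
      if (guardAdd && !e.supported) false // option 1: reject on the add path
      else {
        val existed = seen.contains(e.sql)
        if (!existed) seen(e.sql) = e
        existed
      }

    def getExprState(e: ToyExpr): Option[ToyExpr] =
      if (guardGet && !e.supported) None else seen.get(e.sql)
  }

  // Adds the same unsupported expression twice, then reports
  // (second add reported a duplicate, lookup succeeded).
  def run(cse: ToyEquivalentExpressions): (Boolean, Boolean) = {
    val e = ToyExpr("f(x)", supported = false)
    cse.addExpr(e)
    val duplicate = cse.addExpr(e)
    (duplicate, cse.getExprState(e).isDefined)
  }

  def main(args: Array[String]): Unit = {
    // Mismatch described in this ticket: add is unguarded, get is guarded.
    println(run(new ToyEquivalentExpressions(guardAdd = false, guardGet = true)))  // (true,false)
    // Option 1: guard both paths. Consistent, but the expression is never deduplicated.
    println(run(new ToyEquivalentExpressions(guardAdd = true, guardGet = true)))   // (false,false)
    // Option 2: drop the guard on get. Consistent, and deduplication is kept.
    println(run(new ToyEquivalentExpressions(guardAdd = false, guardGet = false))) // (true,true)
  }
}
{code}
Both options make the two flags agree; they differ in whether an expression the guard rejects is still treated as a common subexpression.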

Example:
{code:sql}
select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) 
from range(2)
{code}

Running this query on the Spark 3.2 branch returns the correct value:
{code}
scala> spark.sql("select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)").collect
res0: Array[org.apache.spark.sql.Row] = Array([WrappedArray(1),WrappedArray(1)])
{code}
Here, {{max(transform(array(id), x -> x))}} is an {{AggregateExpression}} that 
was (potentially unsafely) recognized by {{addExpr()}} as a common 
subexpression, and {{getExprState()}} did no extra guarding. So during physical 
planning, {{PhysicalAggregation}} CSE'd this expression in both the aggregate 
expression list and the result expression list.
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}

Running the same query on current master triggers an error when binding the 
result expression to the aggregate expression in the Aggregate operators (for a 
WSCG-enabled operator like {{HashAggregateExec}}, the same error would show up 
during codegen):
{code}
ERROR TaskSetManager: Task 0 in stage 2.0 failed 1 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in 
stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 
16) (ip-10-110-16-93.us-west-2.compute.internal executor driver): 
java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), 
lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in 
[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, 
false)))#3]
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:512)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:512)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:517)
        at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1249)
        at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1248)
        at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:532)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:517)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:488)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:456)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
        at scala.collection.immutable.List.map(List.scala:297)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
        at org.apache.spark.sql.execution.aggregate.AggregationIterator.generateResultProjection(AggregationIterator.scala:246)
        at org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:296)
        at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.<init>(SortBasedAggregationIterator.scala:49)
        at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1(SortAggregateExec.scala:79)
        at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1$adapted(SortAggregateExec.scala:59)
...
{code}
Note that the aggregate expressions are deduplicated in 
{{PhysicalAggregation}}, but the result expressions could not be deduplicated 
consistently because of the bug described in this ticket.
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=38]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}
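
How a deduplicated aggregate list combined with a failed lookup leads to the "Couldn't find ...#4 in [...#3]" error can be sketched as follows. This is a toy model, not the actual {{PhysicalAggregation}} or {{BindReferences}} code: {{ToyAgg}}, {{resultId}} and {{unresolved}} are hypothetical names, and the lookup is assumed to always fail the guard for this expression.
{code:scala}
import scala.collection.mutable

object BindFailureDemo {
  // Toy aggregate expression: `sql` is its canonical form, `resultId` mimics
  // the expression ID of its result attribute (#3 and #4 in the error above).
  final case class ToyAgg(sql: String, resultId: Int)

  // Returns the result IDs that cannot be found in the aggregate output.
  def unresolved(resultExprs: Seq[ToyAgg]): Seq[Int] = {
    // Step 1: deduplicate through an unguarded "add"; only first occurrences survive.
    val seen = mutable.HashMap.empty[String, ToyAgg]
    val aggregates = resultExprs.filter { a =>
      val existed = seen.contains(a.sql)
      if (!existed) seen(a.sql) = a
      !existed
    }

    // Step 2: rewrite result expressions through a guarded "get".
    // Assumption: the guard rejects the expression, so the lookup yields None
    // and the result expression keeps its own, non-deduplicated ID.
    def guardedGet(a: ToyAgg): Option[ToyAgg] = None
    val rewritten = resultExprs.map(a => guardedGet(a).getOrElse(a))

    // Step 3: binding looks each result ID up in the aggregate output.
    val available = aggregates.map(_.resultId).toSet
    rewritten.map(_.resultId).filterNot(available)
  }

  def main(args: Array[String]): Unit = {
    val sql = "max(transform(array(id), x -> x))"
    val missing = unresolved(Seq(ToyAgg(sql, resultId = 3), ToyAgg(sql, resultId = 4)))
    println(s"unresolved result IDs: $missing") // contains 4, which only #3 could satisfy
  }
}
{code}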

Fixing it via method 1 is more correct than method 2 in terms of avoiding 
incorrect CSE:
{code:diff}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 330d66a21b..12def60042 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -40,7 +40,11 @@ class EquivalentExpressions {
    * Returns true if there was already a matching expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    updateExprInMap(expr, equivalenceMap)
+    if (supportedExpression(expr)) {
+      updateExprInMap(expr, equivalenceMap)
+    } else {
+      false
+    }
   }
 
   /**
{code}
With this change the query runs correctly again, but the aggregate expression 
is no longer CSE'd. This now happens consistently for both the aggregate 
expressions and the result expressions:
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), partial_max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}
For this particular case, the CSE that used to take place was actually safe, 
so losing it here means a performance regression.


