This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7290000d51d7 [SPARK-48035][SQL] Fix try_add/try_multiply being semantically equal to add/multiply
7290000d51d7 is described below

commit 7290000d51d72ad3a3fb395a7d1975c84b8f8df4
Author: Supun Nakandala <supun.nakand...@databricks.com>
AuthorDate: Tue May 7 10:02:27 2024 +0900

    [SPARK-48035][SQL] Fix try_add/try_multiply being semantically equal to add/multiply
    
    ### What changes were proposed in this pull request?
    - This PR fixes a correctness bug in commutative operator canonicalization 
where we currently do not take into account the evaluation mode during operand 
reordering.
    - As a result, the following comparison will incorrectly evaluate to true:
    ```
    val l1 = Literal(1)
    val l2 = Literal(2)
    val l3 = Literal(3)
    val expr1 = Add(Add(l1, l2), l3)
    val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3)
    expr1.semanticEquals(expr2)
    ```
    - To fix the issue, we now reorder commutative operands only if all 
operators have the same evaluation mode.
    
    ### Why are the changes needed?
    - To fix a correctness bug.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    - Added unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46307 from db-scnakandala/db-scnakandala/master.
    
    Authored-by: Supun Nakandala <supun.nakand...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/expressions/Expression.scala      | 14 ++++++++++++
 .../sql/catalyst/expressions/arithmetic.scala      | 23 ++++++++++++--------
 .../catalyst/expressions/CanonicalizeSuite.scala   | 25 ++++++++++++++++++++++
 3 files changed, 53 insertions(+), 9 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index de15ec43c4f3..2759f5a29c79 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1378,6 +1378,20 @@ trait CommutativeExpression extends Expression {
       }
     reorderResult
   }
+
+  /**
+   * Helper method to collect the evaluation mode of the commutative 
expressions. This is
+   * used by the canonicalized methods of [[Add]] and [[Multiply]] operators 
to ensure that
+   * all operands have the same evaluation mode before reordering the operands.
+   */
+  protected def collectEvalModes(
+      e: Expression,
+      f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]]
+  ): Seq[EvalMode.Value] = e match {
+    case c: CommutativeExpression if f.isDefinedAt(c) =>
+      f(c) ++ c.children.flatMap(collectEvalModes(_, f))
+    case _ => Nil
+  }
 }
 
 /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 9eecf81684ce..91c10a53af8a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -452,13 +452,14 @@ case class Add(
     copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    // TODO: do not reorder consecutive `Add`s with different `evalMode`
-    val reorderResult = buildCanonicalizedPlan(
+    val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => 
Seq(evalMode)})
+    lazy val reorderResult = buildCanonicalizedPlan(
       { case Add(l, r, _) => Seq(l, r) },
       { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
       Some(evalMode)
     )
-    if (resolved && reorderResult.resolved && reorderResult.dataType == 
dataType) {
+    if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved 
&&
+      reorderResult.dataType == dataType) {
       reorderResult
     } else {
       // SPARK-40903: Avoid reordering decimal Add for canonicalization if the 
result data type is
@@ -608,12 +609,16 @@ case class Multiply(
     newLeft: Expression, newRight: Expression): Multiply = copy(left = 
newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    // TODO: do not reorder consecutive `Multiply`s with different `evalMode`
-    buildCanonicalizedPlan(
-      { case Multiply(l, r, _) => Seq(l, r) },
-      { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
-      Some(evalMode)
-    )
+    val evalModes = collectEvalModes(this, {case Multiply(_, _, evalMode) => 
Seq(evalMode)})
+    if (evalModes.forall(_ == evalMode)) {
+      buildCanonicalizedPlan(
+        { case Multiply(l, r, _) => Seq(l, r) },
+        { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
+        Some(evalMode)
+      )
+    } else {
+      withCanonicalizedChildren
+    }
   }
 }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index 3366d99dd75e..7e545d332105 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -454,4 +454,29 @@ class CanonicalizeSuite extends SparkFunSuite {
     // different.
     assert(common3.canonicalized != common4.canonicalized)
   }
+
+  test("SPARK-48035: Add/Multiply operator canonicalization should take into 
account the" +
+    "evaluation mode of the operands before operand reordering") {
+    Seq(1, 10) map { multiCommutativeOpOptThreshold =>
+        val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+        SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
+          multiCommutativeOpOptThreshold.toString)
+        try {
+          val l1 = Literal(1)
+          val l2 = Literal(2)
+          val l3 = Literal(3)
+
+          val expr1 = Add(Add(l1, l2), l3)
+          val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3)
+          assert(!expr1.semanticEquals(expr2))
+
+          val expr3 = Multiply(Multiply(l1, l2), l3)
+          val expr4 = Multiply(Multiply(l2, l1, EvalMode.TRY), l3)
+          assert(!expr3.semanticEquals(expr4))
+        } finally {
+          SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
+            default.toString)
+        }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to