This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 808186835077 [SPARK-48035][SQL][FOLLOWUP] Fix try_add/try_multiply being semantic equal to add/multiply
808186835077 is described below

commit 808186835077cf50f10262c633f19de4ccc09d9d
Author: Supun Nakandala <supun.nakand...@databricks.com>
AuthorDate: Tue May 7 09:17:01 2024 -0700

    [SPARK-48035][SQL][FOLLOWUP] Fix try_add/try_multiply being semantic equal to add/multiply
    
    ### What changes were proposed in this pull request?
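    - This is a follow-up to the previous PR: https://github.com/apache/spark/pull/46307.
    - With these changes, the evalMode check is done while collecting operands (in `collectOperands`): the operand-extraction partial function passed to `buildCanonicalizedPlan` now only matches `Add`/`Multiply` nodes whose eval mode equals the current one. The separate `collectEvalModes` helper is removed.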
    
    ### Why are the changes needed?
    - Better code quality and readability.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    - Existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    - No
    
    Closes #46414 from db-scnakandala/db-scnakandala/master.
    
    Authored-by: Supun Nakandala <supun.nakand...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/Expression.scala      | 14 -------------
 .../sql/catalyst/expressions/arithmetic.scala      | 23 ++++++++--------------
 2 files changed, 8 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2759f5a29c79..de15ec43c4f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1378,20 +1378,6 @@ trait CommutativeExpression extends Expression {
       }
     reorderResult
   }
-
-  /**
-   * Helper method to collect the evaluation mode of the commutative 
expressions. This is
-   * used by the canonicalized methods of [[Add]] and [[Multiply]] operators 
to ensure that
-   * all operands have the same evaluation mode before reordering the operands.
-   */
-  protected def collectEvalModes(
-      e: Expression,
-      f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]]
-  ): Seq[EvalMode.Value] = e match {
-    case c: CommutativeExpression if f.isDefinedAt(c) =>
-      f(c) ++ c.children.flatMap(collectEvalModes(_, f))
-    case _ => Nil
-  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 91c10a53af8a..a085a4e3a8a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -452,14 +452,12 @@ case class Add(
     copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => 
Seq(evalMode)})
-    lazy val reorderResult = buildCanonicalizedPlan(
-      { case Add(l, r, _) => Seq(l, r) },
+    val reorderResult = buildCanonicalizedPlan(
+      { case Add(l, r, em) if em == evalMode => Seq(l, r) },
       { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
       Some(evalMode)
     )
-    if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved 
&&
-      reorderResult.dataType == dataType) {
+    if (resolved && reorderResult.resolved && reorderResult.dataType == 
dataType) {
       reorderResult
     } else {
       // SPARK-40903: Avoid reordering decimal Add for canonicalization if the 
result data type is
@@ -609,16 +607,11 @@ case class Multiply(
     newLeft: Expression, newRight: Expression): Multiply = copy(left = 
newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    val evalModes = collectEvalModes(this, {case Multiply(_, _, evalMode) => 
Seq(evalMode)})
-    if (evalModes.forall(_ == evalMode)) {
-      buildCanonicalizedPlan(
-        { case Multiply(l, r, _) => Seq(l, r) },
-        { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
-        Some(evalMode)
-      )
-    } else {
-      withCanonicalizedChildren
-    }
+    buildCanonicalizedPlan(
+      { case Multiply(l, r, em) if em == evalMode => Seq(l, r) },
+      { case (l: Expression, r: Expression) => Multiply(l, r, evalMode) },
+      Some(evalMode)
+    )
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to