This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ca14149546 [SPARK-40903][SQL] Avoid reordering decimal Add for 
canonicalization if data type is changed
1ca14149546 is described below

commit 1ca14149546ed475d2903eed17a6a47bb098937e
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Tue Oct 25 22:39:13 2022 -0700

    [SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization if 
data type is changed
    
    ### What changes were proposed in this pull request?
    
    Avoid reordering the operands of `Add` during canonicalization if it is of 
decimal type and the reordering would change the result data type.
    Expressions are canonicalized for comparisons and explanations. For a 
non-decimal `Add` expression, the operands can be sorted by hashcode, and the 
result is supposed to be the same.
    However, for an `Add` expression of decimal type, the behavior is different: 
given decimal(p1, s1) and another decimal(p2, s2), the integral part of the 
result has `max(p1-s1, p2-s2) + 1` digits and the fractional part (the scale) 
has `max(s1, s2)` digits. Thus the result data type is 
`decimal(max(p1-s1, p2-s2) + 1 + max(s1, s2), max(s1, s2))`.
    Because the result type depends on how the operands are grouped, the order 
matters:
    
    For `(decimal(12,5) + decimal(12,6)) + decimal(3, 2)`, the first add 
`decimal(12,5) + decimal(12,6)` results in `decimal(14, 6)`, and then 
`decimal(14, 6) + decimal(3, 2)`  results in `decimal(15, 6)`
    For `(decimal(12, 6) + decimal(3,2)) + decimal(12, 5)`, the first add 
`decimal(12, 6) + decimal(3,2)` results in `decimal(13, 6)`, and then 
`decimal(13, 6) + decimal(12, 5)` results in `decimal(14, 6)`
    
    In the following query:
    ```
    create table foo(a decimal(12, 5), b decimal(12, 6)) using orc
    select sum(coalesce(a+b+ 1.75, a)) from foo
    ```
    At first `coalesce(a+b+ 1.75, a)` is resolved as `coalesce(a+b+ 1.75, 
cast(a as decimal(15, 6)))`. In the canonicalized version, the expression 
becomes `coalesce(1.75+b+a, cast(a as decimal(15, 6)))`. As explained above, 
`1.75+b+a` is of type decimal(14, 6), which is different from the type of 
`cast(a as decimal(15, 6))`. Thus the following error will happen:
    ```
    java.lang.IllegalArgumentException: requirement failed: All input types 
must be the same except nullable, containsNull, valueContainsNull flags. The 
input types found are
            DecimalType(14,6)
            DecimalType(15,6)
            at scala.Predef$.require(Predef.scala:281)
            at 
org.apache.spark.sql.catalyst.expressions.ComplexTypeMergingExpression.dataTypeCheck(Expression.scala:1149)
            at 
org.apache.spark.sql.catalyst.expressions.ComplexTypeMergingExpression.dataTypeCheck$(Expression.scala:1143)
    ```
    This PR is to fix the bug.
    ### Why are the changes needed?
    
    Bug fix
    ### Does this PR introduce _any_ user-facing change?
    
    No
    ### How was this patch tested?
    
    A new test case
    
    Closes #38379 from gengliangwang/fixDecimalAdd.
    
    Lead-authored-by: Gengliang Wang <gengli...@apache.org>
    Co-authored-by: Gengliang Wang <ltn...@gmail.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/expressions/Expression.scala  |  8 +++++++-
 .../spark/sql/catalyst/expressions/arithmetic.scala  | 10 +++++++++-
 .../sql/catalyst/expressions/CanonicalizeSuite.scala | 20 ++++++++++++++++++--
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala   |  8 ++++++++
 4 files changed, 42 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6df03aa8e84..6d8c2e83ef7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -242,7 +242,13 @@ abstract class Expression extends TreeNode[Expression] {
    * This means that the lazy `cannonicalized` is called and computed only on 
the root of the
    * adjacent expressions.
    */
-  lazy val canonicalized: Expression = {
+  lazy val canonicalized: Expression = withCanonicalizedChildren
+
+  /**
+   * The default process of canonicalization. It is a one pass, bottom-up 
expression tree
+   * computation based on canonicalizing children before canonicalizing the 
current node.
+   */
+  final protected def withCanonicalizedChildren: Expression = {
     val canonicalizedChildren = children.map(_.canonicalized)
     withNewChildren(canonicalizedChildren)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 3e8ec94c33c..4d99c3b02a0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -479,7 +479,15 @@ case class Add(
 
   override lazy val canonicalized: Expression = {
     // TODO: do not reorder consecutive `Add`s with different `evalMode`
-    orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, 
evalMode))
+    val reorderResult =
+      orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, 
evalMode))
+    if (resolved && reorderResult.resolved && reorderResult.dataType == 
dataType) {
+      reorderResult
+    } else {
+      // SPARK-40903: Avoid reordering decimal Add for canonicalization if the 
result data type is
+      // changed, which may cause data checking error within 
ComplexTypeMergingExpression.
+      withCanonicalizedChildren
+    }
   }
 }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index 43b7f35f7bb..057fb98c239 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Range
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, 
StructField, StructType}
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, 
LongType, StringType, StructField, StructType}
 
 class CanonicalizeSuite extends SparkFunSuite {
 
@@ -187,7 +187,23 @@ class CanonicalizeSuite extends SparkFunSuite {
   test("SPARK-40362: Commutative operator under BinaryComparison") {
     Seq(EqualTo, EqualNullSafe, GreaterThan, LessThan, GreaterThanOrEqual, 
LessThanOrEqual)
       .foreach { bc =>
-        assert(bc(Add($"a", $"b"), Literal(10)).semanticEquals(bc(Add($"b", 
$"a"), Literal(10))))
+        assert(bc(Multiply($"a", $"b"), Literal(10)).semanticEquals(
+          bc(Multiply($"b", $"a"), Literal(10))))
       }
   }
+
+  test("SPARK-40903: Only reorder decimal Add when the result data type is not 
changed") {
+    val d = Decimal(1.2)
+    val literal1 = Literal.create(d, DecimalType(2, 1))
+    val literal2 = Literal.create(d, DecimalType(2, 1))
+    val literal3 = Literal.create(d, DecimalType(3, 2))
+    assert(Add(literal1, literal2).semanticEquals(Add(literal2, literal1)))
+    assert(Add(Add(literal1, literal2), literal3).semanticEquals(
+      Add(Add(literal3, literal2), literal1)))
+
+    val literal4 = Literal.create(d, DecimalType(12, 5))
+    val literal5 = Literal.create(d, DecimalType(12, 6))
+    assert(!Add(Add(literal4, literal5), literal1).semanticEquals(
+      Add(Add(literal1, literal5), literal4)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 030e68d227a..dd3ad0f4d6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4518,6 +4518,14 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-40903: Don't reorder Add for canonicalize if it is decimal 
type") {
+    val tableName = "decimalTable"
+    withTable(tableName) {
+      sql(s"create table $tableName(a decimal(12, 5), b decimal(12, 6)) using 
orc")
+      checkAnswer(sql(s"select sum(coalesce(a + b + 1.75, a)) from 
$tableName"), Row(null))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to