This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fbcb26d8e99 [SPARK-48016][SQL] Fix a bug in try_divide function when with decimals
3fbcb26d8e99 is described below

commit 3fbcb26d8e992c65a2778b96da4142e234786e53
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Mon Apr 29 16:40:56 2024 -0700

    [SPARK-48016][SQL] Fix a bug in try_divide function when with decimals
    
    ### What changes were proposed in this pull request?
    
     Currently, the following query throws a DIVIDE_BY_ZERO error instead of returning null:
     ```
    SELECT try_divide(1, decimal(0));
    ```
    
    This is caused by the rule `DecimalPrecision`:
    ```
    case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
      (left, right) match {
     ...
        case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
            l.dataType.isInstanceOf[IntegralType] &&
            literalPickMinimumPrecision =>
          b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
    ```
    The node produced by the above `makeCopy` call carries `ANSI` as its `evalMode` instead of `TRY`.
    This PR fixes the bug by replacing the `makeCopy` calls with `withNewChildren`, which only swaps the children and leaves the other fields, including `evalMode`, untouched.
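    
    As a minimal, self-contained illustration of the difference (a toy sketch, not Spark code: `ToyDivide`, `makeCopyLike` and `withNewChildrenLike` only mimic how `Divide(left, right, evalMode)` behaves under `makeCopy` vs `withNewChildren`):
    ```
    object EvalMode extends Enumeration { val LEGACY, ANSI, TRY = Value }

    // Stand-in for Divide(left, right, evalMode = <derived from the SQL conf>).
    case class ToyDivide(left: String, right: String,
                         evalMode: EvalMode.Value = EvalMode.ANSI) {
      // Mimics makeCopy(Array(newLeft, newRight)): the node is rebuilt from the
      // new children alone, so evalMode falls back to the constructor default.
      def makeCopyLike(newLeft: String, newRight: String): ToyDivide =
        ToyDivide(newLeft, newRight)

      // Mimics withNewChildren(Seq(newLeft, newRight)): the children are swapped,
      // every other field (here evalMode) is preserved.
      def withNewChildrenLike(newLeft: String, newRight: String): ToyDivide =
        copy(left = newLeft, right = newRight)
    }

    val tryDivide = ToyDivide("1", "decimal(0)", EvalMode.TRY)
    tryDivide.makeCopyLike("cast(1 as decimal(10,0))", "decimal(0)").evalMode
    // => ANSI: TRY is silently dropped, so the division raises DIVIDE_BY_ZERO
    tryDivide.withNewChildrenLike("cast(1 as decimal(10,0))", "decimal(0)").evalMode
    // => TRY: the mode survives, so the division returns null
    ```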
    
    ### Why are the changes needed?
    
    Bug fix in try_* functions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes a long-standing bug in the try_divide function.
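    
    For example, after this change `try_divide(1, decimal(0))` returns NULL instead of failing, as shown in the updated golden files below. A minimal check of the expected behavior (a sketch, assuming an active `SparkSession` named `spark`):
    ```
    val row = spark.sql("SELECT try_divide(1, decimal(0))").head()
    assert(row.isNullAt(0))  // NULL instead of a DIVIDE_BY_ZERO error
    ```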
    
    ### How was this patch tested?
    
    New UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46286 from gengliangwang/avoidMakeCopy.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../sql/catalyst/analysis/DecimalPrecision.scala   | 14 ++---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 10 ++--
 .../analyzer-results/ansi/try_arithmetic.sql.out   | 56 +++++++++++++++++++
 .../analyzer-results/try_arithmetic.sql.out        | 56 +++++++++++++++++++
 .../resources/sql-tests/inputs/try_arithmetic.sql  |  8 +++
 .../sql-tests/results/ansi/try_arithmetic.sql.out  | 64 ++++++++++++++++++++++
 .../sql-tests/results/try_arithmetic.sql.out       | 64 ++++++++++++++++++++++
 7 files changed, 260 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 9ad8368d007e..6524ff9b2c57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -92,7 +92,7 @@ object DecimalPrecision extends TypeCoercionRule {
       val resultType = widerDecimalType(p1, s1, p2, s2)
       val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
       val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
-      b.makeCopy(Array(newE1, newE2))
+      b.withNewChildren(Seq(newE1, newE2))
   }
 
   /**
@@ -211,21 +211,21 @@ object DecimalPrecision extends TypeCoercionRule {
         case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
             l.dataType.isInstanceOf[IntegralType] &&
             literalPickMinimumPrecision =>
-          b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
+          b.withNewChildren(Seq(Cast(l, DataTypeUtils.fromLiteral(l)), r))
         case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
             r.dataType.isInstanceOf[IntegralType] &&
             literalPickMinimumPrecision =>
-          b.makeCopy(Array(l, Cast(r, DataTypeUtils.fromLiteral(r))))
+          b.withNewChildren(Seq(l, Cast(r, DataTypeUtils.fromLiteral(r))))
         // Promote integers inside a binary expression with fixed-precision decimals to decimals,
         // and fixed-precision decimals in an expression with floats / doubles to doubles
         case (l @ IntegralTypeExpression(), r @ DecimalExpression(_, _)) =>
-          b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
+          b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
         case (l @ DecimalExpression(_, _), r @ IntegralTypeExpression()) =>
-          b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
+          b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
         case (l, r @ DecimalExpression(_, _)) if isFloat(l.dataType) =>
-          b.makeCopy(Array(l, Cast(r, DoubleType)))
+          b.withNewChildren(Seq(l, Cast(r, DoubleType)))
         case (l @ DecimalExpression(_, _), r) if isFloat(r.dataType) =>
-          b.makeCopy(Array(Cast(l, DoubleType), r))
+          b.withNewChildren(Seq(Cast(l, DoubleType), r))
         case _ => b
       }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 506314effde3..936bb22baa46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -1111,22 +1111,22 @@ object TypeCoercion extends TypeCoercionBase {
 
       case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
         if !isIntervalType(right.dataType) =>
-        a.makeCopy(Array(Cast(left, DoubleType), right))
+        a.withNewChildren(Seq(Cast(left, DoubleType), right))
       case a @ BinaryArithmetic(left, right @ StringTypeExpression())
         if !isIntervalType(left.dataType) =>
-        a.makeCopy(Array(left, Cast(right, DoubleType)))
+        a.withNewChildren(Seq(left, Cast(right, DoubleType)))
 
       // For equality between string and timestamp we cast the string to a timestamp
       // so that things like rounding of subsecond precision does not affect the comparison.
       case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
-        p.makeCopy(Array(Cast(left, TimestampType), right))
+        p.withNewChildren(Seq(Cast(left, TimestampType), right))
       case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
-        p.makeCopy(Array(left, Cast(right, TimestampType)))
+        p.withNewChildren(Seq(left, Cast(right, TimestampType)))
 
       case p @ BinaryComparison(left, right)
           if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
         val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
-        p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
+        p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
     }
   }
 
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out
index ef17f6b50b90..30654d1d71e2 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out
@@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_add(2147483647, decimal(1))
+-- !query analysis
+Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_add(2147483647, "1")
+-- !query analysis
+Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#xL]
++- OneRowRelation
+
+
 -- !query
 SELECT try_add(-2147483648, -1)
 -- !query analysis
@@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_divide(1, decimal(0))
+-- !query analysis
+Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_divide(1, "0")
+-- !query analysis
+Project [try_divide(1, 0) AS try_divide(1, 0)#x]
++- OneRowRelation
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query analysis
@@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_subtract(2147483647, decimal(-1))
+-- !query analysis
+Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_subtract(2147483647, "-1")
+-- !query analysis
+Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#xL]
++- OneRowRelation
+
+
 -- !query
 SELECT try_subtract(-2147483648, 1)
 -- !query analysis
@@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_multiply(2147483647, decimal(-2))
+-- !query analysis
+Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_multiply(2147483647, "-2")
+-- !query analysis
+Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#xL]
++- OneRowRelation
+
+
 -- !query
 SELECT try_multiply(-2147483648, 2)
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out
index ef17f6b50b90..caf997f6ccbb 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out
@@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_add(2147483647, decimal(1))
+-- !query analysis
+Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_add(2147483647, "1")
+-- !query analysis
+Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
++- OneRowRelation
+
+
 -- !query
 SELECT try_add(-2147483648, -1)
 -- !query analysis
@@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_divide(1, decimal(0))
+-- !query analysis
+Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_divide(1, "0")
+-- !query analysis
+Project [try_divide(1, 0) AS try_divide(1, 0)#x]
++- OneRowRelation
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query analysis
@@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_subtract(2147483647, decimal(-1))
+-- !query analysis
+Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_subtract(2147483647, "-1")
+-- !query analysis
+Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
++- OneRowRelation
+
+
 -- !query
 SELECT try_subtract(-2147483648, 1)
 -- !query analysis
@@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
 +- OneRowRelation
 
 
+-- !query
+SELECT try_multiply(2147483647, decimal(-2))
+-- !query analysis
+Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT try_multiply(2147483647, "-2")
+-- !query analysis
+Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
++- OneRowRelation
+
+
 -- !query
 SELECT try_multiply(-2147483648, 2)
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
index 55907b6701e5..943865b68d39 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
@@ -1,6 +1,8 @@
 -- Numeric + Numeric
 SELECT try_add(1, 1);
 SELECT try_add(2147483647, 1);
+SELECT try_add(2147483647, decimal(1));
+SELECT try_add(2147483647, "1");
 SELECT try_add(-2147483648, -1);
 SELECT try_add(9223372036854775807L, 1);
 SELECT try_add(-9223372036854775808L, -1);
@@ -38,6 +40,8 @@ SELECT try_divide(0, 0);
 SELECT try_divide(1, (2147483647 + 1));
 SELECT try_divide(1L, (9223372036854775807L + 1L));
 SELECT try_divide(1, 1.0 / 0.0);
+SELECT try_divide(1, decimal(0));
+SELECT try_divide(1, "0");
 
 -- Interval / Numeric
 SELECT try_divide(interval 2 year, 2);
@@ -50,6 +54,8 @@ SELECT try_divide(interval 106751991 day, 0.5);
 -- Numeric - Numeric
 SELECT try_subtract(1, 1);
 SELECT try_subtract(2147483647, -1);
+SELECT try_subtract(2147483647, decimal(-1));
+SELECT try_subtract(2147483647, "-1");
 SELECT try_subtract(-2147483648, 1);
 SELECT try_subtract(9223372036854775807L, -1);
 SELECT try_subtract(-9223372036854775808L, 1);
@@ -66,6 +72,8 @@ SELECT try_subtract(interval 106751991 day, interval -3 day);
 -- Numeric * Numeric
 SELECT try_multiply(2, 3);
 SELECT try_multiply(2147483647, -2);
+SELECT try_multiply(2147483647, decimal(-2));
+SELECT try_multiply(2147483647, "-2");
 SELECT try_multiply(-2147483648, 2);
 SELECT try_multiply(9223372036854775807L, 2);
 SELECT try_multiply(-9223372036854775808L, -2);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
index adb6550e8083..acf6e70a50de 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
@@ -15,6 +15,22 @@ struct<try_add(2147483647, 1):int>
 NULL
 
 
+-- !query
+SELECT try_add(2147483647, decimal(1))
+-- !query schema
+struct<try_add(2147483647, 1):decimal(11,0)>
+-- !query output
+2147483648
+
+
+-- !query
+SELECT try_add(2147483647, "1")
+-- !query schema
+struct<try_add(2147483647, 1):bigint>
+-- !query output
+2147483648
+
+
 -- !query
 SELECT try_add(-2147483648, -1)
 -- !query schema
@@ -341,6 +357,22 @@ org.apache.spark.SparkArithmeticException
 }
 
 
+-- !query
+SELECT try_divide(1, decimal(0))
+-- !query schema
+struct<try_divide(1, 0):decimal(12,11)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_divide(1, "0")
+-- !query schema
+struct<try_divide(1, 0):double>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -405,6 +437,22 @@ struct<try_subtract(2147483647, -1):int>
 NULL
 
 
+-- !query
+SELECT try_subtract(2147483647, decimal(-1))
+-- !query schema
+struct<try_subtract(2147483647, -1):decimal(11,0)>
+-- !query output
+2147483648
+
+
+-- !query
+SELECT try_subtract(2147483647, "-1")
+-- !query schema
+struct<try_subtract(2147483647, -1):bigint>
+-- !query output
+2147483648
+
+
 -- !query
 SELECT try_subtract(-2147483648, 1)
 -- !query schema
@@ -547,6 +595,22 @@ struct<try_multiply(2147483647, -2):int>
 NULL
 
 
+-- !query
+SELECT try_multiply(2147483647, decimal(-2))
+-- !query schema
+struct<try_multiply(2147483647, -2):decimal(21,0)>
+-- !query output
+-4294967294
+
+
+-- !query
+SELECT try_multiply(2147483647, "-2")
+-- !query schema
+struct<try_multiply(2147483647, -2):bigint>
+-- !query output
+-4294967294
+
+
 -- !query
 SELECT try_multiply(-2147483648, 2)
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
index fa83652da0ed..b12680c2a675 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
@@ -15,6 +15,22 @@ struct<try_add(2147483647, 1):int>
 NULL
 
 
+-- !query
+SELECT try_add(2147483647, decimal(1))
+-- !query schema
+struct<try_add(2147483647, 1):decimal(11,0)>
+-- !query output
+2147483648
+
+
+-- !query
+SELECT try_add(2147483647, "1")
+-- !query schema
+struct<try_add(2147483647, 1):double>
+-- !query output
+2.147483648E9
+
+
 -- !query
 SELECT try_add(-2147483648, -1)
 -- !query schema
@@ -249,6 +265,22 @@ struct<try_divide(1, (1.0 / 0.0)):decimal(16,9)>
 NULL
 
 
+-- !query
+SELECT try_divide(1, decimal(0))
+-- !query schema
+struct<try_divide(1, 0):decimal(12,11)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_divide(1, "0")
+-- !query schema
+struct<try_divide(1, 0):double>
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -313,6 +345,22 @@ struct<try_subtract(2147483647, -1):int>
 NULL
 
 
+-- !query
+SELECT try_subtract(2147483647, decimal(-1))
+-- !query schema
+struct<try_subtract(2147483647, -1):decimal(11,0)>
+-- !query output
+2147483648
+
+
+-- !query
+SELECT try_subtract(2147483647, "-1")
+-- !query schema
+struct<try_subtract(2147483647, -1):double>
+-- !query output
+2.147483648E9
+
+
 -- !query
 SELECT try_subtract(-2147483648, 1)
 -- !query schema
@@ -409,6 +457,22 @@ struct<try_multiply(2147483647, -2):int>
 NULL
 
 
+-- !query
+SELECT try_multiply(2147483647, decimal(-2))
+-- !query schema
+struct<try_multiply(2147483647, -2):decimal(21,0)>
+-- !query output
+-4294967294
+
+
+-- !query
+SELECT try_multiply(2147483647, "-2")
+-- !query schema
+struct<try_multiply(2147483647, -2):double>
+-- !query output
+-4.294967294E9
+
+
 -- !query
 SELECT try_multiply(-2147483648, 2)
 -- !query schema

