This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 77e52fd76b30 [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results
77e52fd76b30 is described below

commit 77e52fd76b3055e070ddc1d147e1d8c9e2b6be09
Author: Kazuyuki Tanimura <ktanim...@apple.com>
AuthorDate: Tue Nov 7 09:06:00 2023 -0800

    [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results
    
    ### What changes were proposed in this pull request?
    This PR fixes inaccurate Decimal multiplication and division results.
    
    ### Why are the changes needed?
    Decimal multiplication and division results may be inaccurate due to rounding issues.
    #### Multiplication:
    ```
    scala> sql("select  -14120025096157587712113961295153.858047 * 
-0.4652").show(truncate=false)
    +----------------------------------------------------+
    |(-14120025096157587712113961295153.858047 * -0.4652)|
    +----------------------------------------------------+
    |6568635674732509803675414794505.574764              |
    +----------------------------------------------------+
    ```
    The correct answer is `6568635674732509803675414794505.574763`
    
    Please note that the last digit is `3` instead of `4`, as the exact product shows:
    
    ```
    scala> java.math.BigDecimal("-14120025096157587712113961295153.858047").multiply(java.math.BigDecimal("-0.4652"))
    val res21: java.math.BigDecimal = 6568635674732509803675414794505.5747634644
    ```
    Since the fractional part `.574763` is followed by `4644`, it should not be rounded up.
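
    The old code effectively rounded twice: once HALF_UP inside Decimal's MathContext, and once more when TypeCoercion casts the product to the result type (`decimal(38, 6)` here). A minimal sketch of that double rounding, modeled with plain `java.math.BigDecimal` rather than the patched code path:

    ```
    import java.math.RoundingMode

    val exact = new java.math.BigDecimal("6568635674732509803675414794505.5747634644")

    // Two-step rounding (modeling the old behavior): HALF_UP to 38 significant
    // digits (scale 7 here), then HALF_UP again to the result scale of 6.
    exact.setScale(7, RoundingMode.HALF_UP).setScale(6, RoundingMode.HALF_UP)
    // => 6568635674732509803675414794505.574764  (rounded up twice, off by one)

    // A single rounding of the exact product gives the right answer.
    exact.setScale(6, RoundingMode.HALF_UP)
    // => 6568635674732509803675414794505.574763
    ```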
    
    #### Division:
    ```
    scala> sql("select -0.172787979 / 
533704665545018957788294905796.5").show(truncate=false)
    +-------------------------------------------------+
    |(-0.172787979 / 533704665545018957788294905796.5)|
    +-------------------------------------------------+
    |-3.237521E-31                                    |
    +-------------------------------------------------+
    ```
    The correct answer is `-3.237520E-31`
    
    Please note that the last digit is `0` instead of `1`, as a higher-precision quotient shows:
    
    ```
    scala> java.math.BigDecimal("-0.172787979").divide(java.math.BigDecimal("533704665545018957788294905796.5"), 100, java.math.RoundingMode.DOWN)
    val res22: java.math.BigDecimal = -3.237520489418037889998826491401059986665344697406144511563561222578738E-31
    ```
    Since the fractional part `.237520` is followed by `4894...`, it should not be rounded up.
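
    Again the old code rounded twice: once at scale 38 when `/` divided with HALF_UP, and once more when TypeCoercion casts the quotient to the result type (`decimal(38, 37)` here). The same sketch, modeled with plain `java.math.BigDecimal`:

    ```
    import java.math.RoundingMode

    val exact = new java.math.BigDecimal(
      "-3.237520489418037889998826491401059986665344697406144511563561222578738E-31")

    // Two-step rounding (modeling the old behavior): HALF_UP at scale 38,
    // then HALF_UP again at the result scale of 37.
    exact.setScale(38, RoundingMode.HALF_UP).setScale(37, RoundingMode.HALF_UP)
    // => -3.237521E-31  (rounded up twice, off by one)

    // A single rounding at the final scale gives the right answer.
    exact.setScale(37, RoundingMode.HALF_UP)
    // => -3.237520E-31
    ```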
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users will see correct Decimal multiplication and division results.
    Directly multiplying and dividing with `org.apache.spark.sql.types.Decimal()` (not via SQL) will return at most 39 digits instead of at most 38, and will round down instead of half-up.
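
    For example, a sketch of direct `Decimal` arithmetic with this patch applied (the exact product below has 41 significant digits):

    ```
    import org.apache.spark.sql.types.Decimal

    val a = Decimal("-14120025096157587712113961295153.858047")
    val b = Decimal("-0.4652")

    // Exact product: 6568635674732509803675414794505.5747634644 (41 digits)
    // Before this patch: rounded HALF_UP to 38 digits -> ...5747635
    // After this patch:  rounded DOWN to 39 digits    -> ...57476346
    a * b
    ```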
    
    ### How was this patch tested?
    Test added
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43678 from kazuyukitanimura/SPARK-45786.
    
    Authored-by: Kazuyuki Tanimura <ktanim...@apple.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 5ef3a846f52ab90cb7183953cff3080449d0b57b)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala |   8 +-
 .../expressions/ArithmeticExpressionSuite.scala    | 107 +++++++++++++++++++++
 .../ansi/decimalArithmeticOperations.sql.out       |  14 +--
 3 files changed, 120 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 2c0b6677541f..baf0dc9cfbaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -497,7 +497,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def / (that: Decimal): Decimal =
    if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
-      DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
+      DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
 
   def % (that: Decimal): Decimal =
     if (that.isZero) null
@@ -545,7 +545,11 @@ object Decimal {
 
  val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
 
-  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+  // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
+  // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
+  // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
+  // the last extra digit.
+  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)
 
   private[sql] val ZERO = Decimal(0)
   private[sql] val ONE = Decimal(1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index e21793ab506c..568dcd10d116 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.RoundingMode
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
@@ -225,6 +226,112 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
+    // Some known cases
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
+        Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
+      ),
+      Decimal(BigDecimal("6568635674732509803675414794505.574763"))
+    )
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
+        Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
+      ),
+      Decimal(BigDecimal("1367249507675382200.164877854336665327"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
+        Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
+      ),
+      Decimal(BigDecimal("-3.237520E-31"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
+        Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
+      ),
+      Decimal(BigDecimal("7.21642358550E-25"))
+    )
+
+    // Random tests
+    val rand = scala.util.Random
+    def makeNum(p: Int, s: Int): String = {
+      val int1 = rand.nextLong()
+      val int2 = rand.nextLong().abs
+      val frac1 = rand.nextLong().abs
+      val frac2 = rand.nextLong().abs
+      s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + 
s"$frac1$frac2".take(s)
+    }
+
+    (0 until 100).foreach { _ =>
+      val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
+      val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
+      val p2 = rand.nextInt(38) + 1
+      val s2 = rand.nextInt(p2 + 1)
+
+      val n1 = makeNum(p1, s1)
+      val n2 = makeNum(p2, s2)
+
+      val mulActual = Multiply(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
+
+      val divActual = Divide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val divExact = new java.math.BigDecimal(n1)
+        .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
+
+      val remActual = Remainder(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
+
+      val quotActual = IntegralDivide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val quotExact =
+        new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
+
+      Seq(true, false).foreach { allowPrecLoss =>
+        withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
+          val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
+          val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
+          val mulExpected =
+            if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
+          checkEvaluation(mulActual, mulExpected)
+
+          val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
+          val divExpected =
+            if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
+          checkEvaluation(divActual, divExpected)
+
+          val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
+          val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
+          val remExpected =
+            if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
+          checkEvaluation(remActual, remExpected)
+
+          val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
+          val quotExpected =
+            if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
+          checkEvaluation(quotActual, quotExpected.toLong)
+        }
+      }
+    }
+  }
+
  private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
     testFunc(_.toDouble)
     testFunc(Decimal(_))
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
index 699c916fd8fd..9593291fae21 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
@@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"1000000000000000000000000000000000000.00000000000000000000000000000000000000"
+    "value" : 
"1000000000000000000000000000000000000.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"10123456789012345678901234567890123456.00000000000000000000000000000000000000"
+    "value" : 
"10123456789012345678901234567890123456.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"101234567890123456789012345678901234.56000000000000000000000000000000000000"
+    "value" : 
"101234567890123456789012345678901234.560000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"10123456789012345678901234567890123.45600000000000000000000000000000000000"
+    "value" : 
"10123456789012345678901234567890123.456000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"1012345678901234567890123456789012.34560000000000000000000000000000000000"
+    "value" : 
"1012345678901234567890123456789012.345600000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : 
"101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : 
"101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : 
"101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",

