This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new d304a652823 [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
d304a652823 is described below

commit d304a6528233798dc93d31da58be220cf6d0485e
Author: oleksii.diagiliev <oleksii.diagil...@workday.com>
AuthorDate: Fri Feb 3 10:49:42 2023 -0600

    [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
    
    This is a backport PR for https://github.com/apache/spark/pull/39099
    
    Closes #39381 from fe2s/branch-3.2-fix-decimal-scaling.
    
    Authored-by: oleksii.diagiliev <oleksii.diagil...@workday.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
 .../org/apache/spark/sql/types/DecimalSuite.scala  | 53 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 25 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index bc5fba8d0d8..503a887d690 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -388,30 +388,42 @@ final class Decimal extends Ordered[Decimal] with 
Serializable {
       if (scale < _scale) {
         // Easier case: we just need to divide our scale down
         val diff = _scale - scale
-        val pow10diff = POW_10(diff)
-        // % and / always round to 0
-        val droppedDigits = longVal % pow10diff
-        longVal /= pow10diff
-        roundMode match {
-          case ROUND_FLOOR =>
-            if (droppedDigits < 0) {
-              longVal += -1L
-            }
-          case ROUND_CEILING =>
-            if (droppedDigits > 0) {
-              longVal += 1L
-            }
-          case ROUND_HALF_UP =>
-            if (math.abs(droppedDigits) * 2 >= pow10diff) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case ROUND_HALF_EVEN =>
-            val doubled = math.abs(droppedDigits) * 2
-            if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 
0) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case _ =>
-            throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        // If diff is greater than max number of digits we store in Long, then
+        // value becomes 0. Otherwise we calculate new value dividing by power 
of 10.
+        // In both cases we apply rounding after that.
+        if (diff > MAX_LONG_DIGITS) {
+          longVal = roundMode match {
+            case ROUND_FLOOR => if (longVal < 0) -1L else 0L
+            case ROUND_CEILING => if (longVal > 0) 1L else 0L
+            case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+            case _ => sys.error(s"Not supported rounding mode: $roundMode")
+          }
+        } else {
+          val pow10diff = POW_10(diff)
+          // % and / always round to 0
+          val droppedDigits = longVal % pow10diff
+          longVal /= pow10diff
+          roundMode match {
+            case ROUND_FLOOR =>
+              if (droppedDigits < 0) {
+                longVal += -1L
+              }
+            case ROUND_CEILING =>
+              if (droppedDigits > 0) {
+                longVal += 1L
+              }
+            case ROUND_HALF_UP =>
+              if (math.abs(droppedDigits) * 2 >= pow10diff) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case ROUND_HALF_EVEN =>
+              val doubled = math.abs(droppedDigits) * 2
+              if (doubled > pow10diff || doubled == pow10diff && longVal % 2 
!= 0) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case _ =>
+              throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+          }
         }
       } else if (scale > _scale) {
         // We might be able to multiply longVal by a power of 10 and not 
overflow, but if not,
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 5433c561a03..1f4862fcbdc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
 import org.apache.spark.unsafe.types.UTF8String
 
 class DecimalSuite extends SparkFunSuite with PrivateMethodTester with 
SQLHelper {
+
+  val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, 
ROUND_CEILING, ROUND_FLOOR)
+
   /** Check that a Decimal has the given string representation, precision and 
scale */
   private def checkDecimal(d: Decimal, string: String, precision: Int, scale: 
Int): Unit = {
     assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with 
PrivateMethodTester with SQLHelper
   }
 
   test("changePrecision/toPrecision on compact decimal should respect rounding 
mode") {
-    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { 
mode =>
+    allSupportedRoundModes.foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n 
=>
         Seq("", "-").foreach { sign =>
           val bd = BigDecimal(sign + n)
@@ -314,4 +317,52 @@ class DecimalSuite extends SparkFunSuite with 
PrivateMethodTester with SQLHelper
       }
     }
   }
+
+  // 18 is a max number of digits in Decimal's compact long
+  test("SPARK-41554: decrease/increase scale by 18 and more on compact 
decimal") {
+    val unscaledNums = Seq(
+      0L, 1L, 10L, 51L, 123L, 523L,
+      // 18 digits
+      912345678901234567L,
+      112345678901234567L,
+      512345678901234567L
+    )
+    val precision = 38
+    // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
+    val scalePairs = for {
+      scale <- Seq(38, 20, 19, 18)
+      delta <- Seq(38, 20, 19, 18)
+      a = scale
+      b = scale - delta
+    } yield {
+      Seq((a, b), (-a, -b), (b, a), (-b, -a))
+    }
+
+    for {
+      unscaled <- unscaledNums
+      mode <- allSupportedRoundModes
+      (scaleFrom, scaleTo) <- scalePairs.flatten
+      sign <- Seq(1L, -1L)
+    } {
+      val unscaledWithSign = unscaled * sign
+      if (scaleFrom < 0 || scaleTo < 0) {
+        withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key 
-> "true") {
+          checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+        }
+      } else {
+        checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+      }
+    }
+
+    def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+                         roundMode: BigDecimal.RoundingMode.Value): Unit = {
+      val decimal = Decimal(unscaled, precision, scaleFrom)
+      checkCompact(decimal, true)
+      decimal.changePrecision(precision, scaleTo, roundMode)
+      val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+      assert(decimal.toBigDecimal === bd,
+        s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: 
$roundMode")
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to