cloud-fan commented on code in PR #37649: URL: https://github.com/apache/spark/pull/37649#discussion_r954423097
########## sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala: ########## @@ -394,48 +394,49 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision == this.precision && scale == this.scale) { return true } + var lv = longVal DecimalType.checkNegativeScale(scale) - // First, update our longVal if we can, or transfer over to using a BigDecimal + // First, update our lv if we can, or transfer over to using a BigDecimal if (decimalVal.eq(null)) { if (scale < _scale) { // Easier case: we just need to divide our scale down val diff = _scale - scale val pow10diff = POW_10(diff) // % and / always round to 0 - val droppedDigits = longVal % pow10diff - longVal /= pow10diff + val droppedDigits = lv % pow10diff + lv /= pow10diff roundMode match { case ROUND_FLOOR => if (droppedDigits < 0) { - longVal += -1L + lv += -1L } case ROUND_CEILING => if (droppedDigits > 0) { - longVal += 1L + lv += 1L } case ROUND_HALF_UP => if (math.abs(droppedDigits) * 2 >= pow10diff) { - longVal += (if (droppedDigits < 0) -1L else 1L) + lv += (if (droppedDigits < 0) -1L else 1L) } case ROUND_HALF_EVEN => val doubled = math.abs(droppedDigits) * 2 - if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { - longVal += (if (droppedDigits < 0) -1L else 1L) + if (doubled > pow10diff || doubled == pow10diff && lv % 2 != 0) { + lv += (if (droppedDigits < 0) -1L else 1L) } case _ => throw QueryExecutionErrors.unsupportedRoundingMode(roundMode) } } else if (scale > _scale) { - // We might be able to multiply longVal by a power of 10 and not overflow, but if not, + // We might be able to multiply lv by a power of 10 and not overflow, but if not, // switch to using a BigDecimal val diff = scale - _scale val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) - if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { - // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS - longVal *= POW_10(diff) + if (diff <= 
MAX_LONG_DIGITS && lv > -p && lv < p) { + // Multiplying lv by POW_10(diff) will still keep it below MAX_LONG_DIGITS + lv *= POW_10(diff) } else { // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = BigDecimal(longVal, _scale) + decimalVal = BigDecimal(lv, _scale) Review Comment: Shall we avoid updating `decimalVal` as well if we eventually return false? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org