cloud-fan commented on code in PR #36698: URL: https://github.com/apache/spark/pull/36698#discussion_r885776567
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala: ########## @@ -208,6 +210,79 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild) } +/** + * Child classes should override the `decimalType` method to report the result data type. + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced, down to a minimum of 6, + * in order to prevent the truncation of the integer part of the decimals. + * + * Rounds the decimal to the given scale and checks whether the decimal can fit in the provided + * precision or not. If it cannot and `nullOnOverflow` is `true`, it returns `null`; otherwise an + * `ArithmeticException` is thrown. + */ +trait DecimalArithmeticSupport extends BinaryArithmetic { + protected val nullOnOverflow: Boolean = !failOnError + protected val allowPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss + + override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match { + case (_: DecimalType, _: DecimalType) => + // We allow evaluating decimal types with different precision and scale, and adjust the + // precision and scale before returning the result. + TypeCheckResult.TypeCheckSuccess + case _ => super.checkInputDataTypes() + } + + /** Name of the function for this expression on a [[Decimal]] type.
*/ + protected def decimalMethod: String + protected def decimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType + + override def nullable: Boolean = dataType match { + case _: DecimalType => nullOnOverflow + case _ => super.nullable + } + + override def dataType: DataType = (left, right) match { + case (DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2)) => + decimalType(p1, s1, p2, s2) + case _ => super.dataType + } + + def checkOverflow(value: Decimal, decimalType: DecimalType): Decimal = { + value.toPrecision( + decimalType.precision, + decimalType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow, + queryContext) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { + case decimalType: DecimalType => + val errorContextCode = if (nullOnOverflow) { + "\"\"" + } else { + ctx.addReferenceObj("errCtx", queryContext) + } + val isNull = if (nullOnOverflow) { Review Comment: ```suggestion val updateisNull = if (nullOnOverflow) { ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org