This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e35e29a0517d [SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align error class e35e29a0517d is described below commit e35e29a0517db930e12fe801f0f0ab1a31c3b23e Author: panbingkun <panbing...@baidu.com> AuthorDate: Fri Feb 2 20:33:31 2024 +0300 [SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align error class ### What changes were proposed in this pull request? The pr aims to: - simplify `UnaryMinus` & `Abs` - convert error-class `_LEGACY_ERROR_TEMP_2043` to `ARITHMETIC_OVERFLOW`, and remove it. ### Why are the changes needed? 1.When the data type in `UnaryMinus` and `Abs` is `ByteType` or `ShortType`, if `an overflow exception` occurs, the corresponding error class is: `_LEGACY_ERROR_TEMP_2043` But when the data type is `IntegerType` or `LongType`, if `an overflow exception` occurs, its corresponding error class is: ARITHMETIC_OVERFLOW, We should unify it. 2.In the `codegen` logic of `UnaryMinus` and `Abs`, there is a difference between the logic of generating code when the data type is `ByteType` or `ShortType` and when the data type is `IntegerType` or `LongType`. We can unify it and simplify the code. ### Does this PR introduce _any_ user-facing change? Yes, ### How was this patch tested? - Update existed UT. - Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44942 from panbingkun/UnaryMinus_improve. 
Authored-by: panbingkun <panbing...@baidu.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../src/main/resources/error/error-classes.json | 5 --- .../sql/catalyst/expressions/arithmetic.scala | 45 ++++++++-------------- .../spark/sql/errors/QueryExecutionErrors.scala | 8 ---- .../org/apache/spark/sql/types/numerics.scala | 6 +-- .../expressions/ArithmeticExpressionSuite.scala | 27 ++++++++----- 5 files changed, 36 insertions(+), 55 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 136825ab374d..6d88f5ee511c 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -5747,11 +5747,6 @@ "<message>. If necessary set <ansiConfig> to false to bypass this error." ] }, - "_LEGACY_ERROR_TEMP_2043" : { - "message" : [ - "- <sqlValue> caused overflow." - ] - }, "_LEGACY_ERROR_TEMP_2045" : { "message" : [ "Unsupported table change: <message>" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9f1b42ad84d3..0f95ae821ab0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -60,23 +60,15 @@ case class UnaryMinus( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") - case ByteType | ShortType if failOnError => + case ByteType | ShortType | IntegerType | LongType if failOnError => + val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$") + val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName) nullSafeCodeGen(ctx, ev, eval => { val javaBoxedType = 
CodeGenerator.boxedType(dataType) - val javaType = CodeGenerator.javaType(dataType) - val originValue = ctx.freshName("origin") s""" - |$javaType $originValue = ($javaType)($eval); - |if ($originValue == $javaBoxedType.MIN_VALUE) { - | throw QueryExecutionErrors.unaryMinusCauseOverflowError($originValue); - |} - |${ev.value} = ($javaType)(-($originValue)); - """.stripMargin - }) - case IntegerType | LongType if failOnError => - val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") - nullSafeCodeGen(ctx, ev, eval => { - s"${ev.value} = $mathUtils.negateExact($eval);" + |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric( + | $refDataType, $failOnError).negate($eval); + """.stripMargin }) case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -181,23 +173,16 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") - case ByteType | ShortType if failOnError => - val javaBoxedType = CodeGenerator.boxedType(dataType) - val javaType = CodeGenerator.javaType(dataType) - nullSafeCodeGen(ctx, ev, eval => + case ByteType | ShortType | IntegerType | LongType if failOnError => + val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$") + val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName) + nullSafeCodeGen(ctx, ev, eval => { + val javaBoxedType = CodeGenerator.boxedType(dataType) s""" - |if ($eval == $javaBoxedType.MIN_VALUE) { - | throw QueryExecutionErrors.unaryMinusCauseOverflowError($eval); - |} else if ($eval < 0) { - | ${ev.value} = ($javaType)-$eval; - |} else { - | ${ev.value} = $eval; - |} - |""".stripMargin) - - case IntegerType | LongType if failOnError => - val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$c < 0 ? 
$mathUtils.negateExact($c) : $c") + |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric( + | $refDataType, $failOnError).abs($eval); + """.stripMargin + }) case _: AnsiIntervalType => val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index b09885c904a5..9ff076c5fd50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -601,14 +601,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = "") } - def unaryMinusCauseOverflowError(originValue: Int): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "_LEGACY_ERROR_TEMP_2043", - messageParameters = Map("sqlValue" -> toSQLValue(originValue, IntegerType)), - context = Array.empty, - summary = "") - } - def binaryArithmeticCauseOverflowError( eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = { new SparkArithmeticException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index c3d893d82fce..45b6cb44e5fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Numeric._ import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors} import org.apache.spark.sql.types.Decimal.DecimalIsConflicted private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { @@ -50,7 +50,7 @@ 
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr override def negate(x: Byte): Byte = { if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen - throw QueryExecutionErrors.unaryMinusCauseOverflowError(x) + throw ExecutionErrors.arithmeticOverflowError("byte overflow") } (-x).toByte } @@ -84,7 +84,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor override def negate(x: Short): Short = { if (x == Short.MinValue) { // if and only if x is Short.MinValue, overflow can happen - throw QueryExecutionErrors.unaryMinusCauseOverflowError(x) + throw ExecutionErrors.arithmeticOverflowError("short overflow") } (-x).toShort } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 7a80188d445d..89f0b95f5c18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.errors.DataTypeErrors.toSQLConf +import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -116,14 +117,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { -
checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal(Long.MinValue)), "overflow") - checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal(Int.MinValue)), "overflow") - checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal(Short.MinValue)), "overflow") - checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal(Byte.MinValue)), "overflow") + checkErrorInExpression[SparkArithmeticException]( + UnaryMinus(Literal(Long.MinValue)), "ARITHMETIC_OVERFLOW", + Map("message" -> "long overflow", "alternative" -> "", + "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY))) + checkErrorInExpression[SparkArithmeticException]( + UnaryMinus(Literal(Int.MinValue)), "ARITHMETIC_OVERFLOW", + Map("message" -> "integer overflow", "alternative" -> "", + "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY))) + checkErrorInExpression[SparkArithmeticException]( + UnaryMinus(Literal(Short.MinValue)), "ARITHMETIC_OVERFLOW", + Map("message" -> "short overflow", "alternative" -> "", + "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY))) + checkErrorInExpression[SparkArithmeticException]( + UnaryMinus(Literal(Byte.MinValue)), "ARITHMETIC_OVERFLOW", + Map("message" -> "byte overflow", "alternative" -> "", + "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY))) checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort) checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort) checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org