This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e35e29a0517d [SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align 
error class
e35e29a0517d is described below

commit e35e29a0517db930e12fe801f0f0ab1a31c3b23e
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Fri Feb 2 20:33:31 2024 +0300

    [SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align error class
    
    ### What changes were proposed in this pull request?
    The pr aims to:
    - simplify `UnaryMinus` & `Abs`
    - convert error-class `_LEGACY_ERROR_TEMP_2043` to `ARITHMETIC_OVERFLOW`, 
and remove it.
    
    ### Why are the changes needed?
    1. When the data type in `UnaryMinus` and `Abs` is `ByteType` or 
`ShortType`, if `an overflow exception` occurs, the corresponding error class 
is: `_LEGACY_ERROR_TEMP_2043`.
    But when the data type is `IntegerType` or `LongType`, if `an overflow 
exception` occurs, its corresponding error class is: `ARITHMETIC_OVERFLOW`. We 
should unify it.
    
    2. In the `codegen` logic of `UnaryMinus` and `Abs`, there is a difference 
between the logic of generating code when the data type is `ByteType` or 
`ShortType` and when the data type is `IntegerType` or `LongType`. We can unify 
it and simplify the code.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes,
    
    ### How was this patch tested?
    - Update existing UTs.
    - Pass GA.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44942 from panbingkun/UnaryMinus_improve.
    
    Authored-by: panbingkun <panbing...@baidu.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |  5 ---
 .../sql/catalyst/expressions/arithmetic.scala      | 45 ++++++++--------------
 .../spark/sql/errors/QueryExecutionErrors.scala    |  8 ----
 .../org/apache/spark/sql/types/numerics.scala      |  6 +--
 .../expressions/ArithmeticExpressionSuite.scala    | 27 ++++++++-----
 5 files changed, 36 insertions(+), 55 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index 136825ab374d..6d88f5ee511c 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -5747,11 +5747,6 @@
       "<message>. If necessary set <ansiConfig> to false to bypass this error."
     ]
   },
-  "_LEGACY_ERROR_TEMP_2043" : {
-    "message" : [
-      "- <sqlValue> caused overflow."
-    ]
-  },
   "_LEGACY_ERROR_TEMP_2045" : {
     "message" : [
       "Unsupported table change: <message>"
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 9f1b42ad84d3..0f95ae821ab0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -60,23 +60,15 @@ case class UnaryMinus(
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
     case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
-    case ByteType | ShortType if failOnError =>
+    case ByteType | ShortType | IntegerType | LongType if failOnError =>
+      val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
+      val refDataType = ctx.addReferenceObj("refDataType", dataType, 
dataType.getClass.getName)
       nullSafeCodeGen(ctx, ev, eval => {
         val javaBoxedType = CodeGenerator.boxedType(dataType)
-        val javaType = CodeGenerator.javaType(dataType)
-        val originValue = ctx.freshName("origin")
         s"""
-           |$javaType $originValue = ($javaType)($eval);
-           |if ($originValue == $javaBoxedType.MIN_VALUE) {
-           |  throw 
QueryExecutionErrors.unaryMinusCauseOverflowError($originValue);
-           |}
-           |${ev.value} = ($javaType)(-($originValue));
-           """.stripMargin
-      })
-    case IntegerType | LongType if failOnError =>
-      val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
-      nullSafeCodeGen(ctx, ev, eval => {
-        s"${ev.value} = $mathUtils.negateExact($eval);"
+           |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
+           |  $refDataType, $failOnError).negate($eval);
+         """.stripMargin
       })
     case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
       val originValue = ctx.freshName("origin")
@@ -181,23 +173,16 @@ case class Abs(child: Expression, failOnError: Boolean = 
SQLConf.get.ansiEnabled
     case _: DecimalType =>
       defineCodeGen(ctx, ev, c => s"$c.abs()")
 
-    case ByteType | ShortType if failOnError =>
-      val javaBoxedType = CodeGenerator.boxedType(dataType)
-      val javaType = CodeGenerator.javaType(dataType)
-      nullSafeCodeGen(ctx, ev, eval =>
+    case ByteType | ShortType | IntegerType | LongType if failOnError =>
+      val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
+      val refDataType = ctx.addReferenceObj("refDataType", dataType, 
dataType.getClass.getName)
+      nullSafeCodeGen(ctx, ev, eval => {
+        val javaBoxedType = CodeGenerator.boxedType(dataType)
         s"""
-          |if ($eval == $javaBoxedType.MIN_VALUE) {
-          |  throw QueryExecutionErrors.unaryMinusCauseOverflowError($eval);
-          |} else if ($eval < 0) {
-          |  ${ev.value} = ($javaType)-$eval;
-          |} else {
-          |  ${ev.value} = $eval;
-          |}
-          |""".stripMargin)
-
-    case IntegerType | LongType if failOnError =>
-      val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
-      defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")
+           |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
+           |  $refDataType, $failOnError).abs($eval);
+         """.stripMargin
+      })
 
     case _: AnsiIntervalType =>
       val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index b09885c904a5..9ff076c5fd50 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -601,14 +601,6 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
       summary = "")
   }
 
-  def unaryMinusCauseOverflowError(originValue: Int): SparkArithmeticException 
= {
-    new SparkArithmeticException(
-      errorClass = "_LEGACY_ERROR_TEMP_2043",
-      messageParameters = Map("sqlValue" -> toSQLValue(originValue, 
IntegerType)),
-      context = Array.empty,
-      summary = "")
-  }
-
   def binaryArithmeticCauseOverflowError(
       eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = {
     new SparkArithmeticException(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index c3d893d82fce..45b6cb44e5fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
 import scala.math.Numeric._
 
 import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
 
 private[sql] object ByteExactNumeric extends ByteIsIntegral with 
Ordering.ByteOrdering {
@@ -50,7 +50,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral 
with Ordering.ByteOr
 
   override def negate(x: Byte): Byte = {
     if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow 
can happen
-      throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
+      throw ExecutionErrors.arithmeticOverflowError("byte overflow")
     }
     (-x).toByte
   }
@@ -84,7 +84,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral 
with Ordering.Shor
 
   override def negate(x: Short): Short = {
    if (x == Short.MinValue) { // if and only if x is Short.MinValue, overflow 
can happen
-      throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
+      throw ExecutionErrors.arithmeticOverflowError("short overflow")
     }
     (-x).toShort
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7a80188d445d..89f0b95f5c18 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLConf
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
@@ -116,14 +117,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite 
with ExpressionEvalHelper
       checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
     }
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-      checkExceptionInExpression[ArithmeticException](
-        UnaryMinus(Literal(Long.MinValue)), "overflow")
-      checkExceptionInExpression[ArithmeticException](
-        UnaryMinus(Literal(Int.MinValue)), "overflow")
-      checkExceptionInExpression[ArithmeticException](
-        UnaryMinus(Literal(Short.MinValue)), "overflow")
-      checkExceptionInExpression[ArithmeticException](
-        UnaryMinus(Literal(Byte.MinValue)), "overflow")
+      checkErrorInExpression[SparkArithmeticException](
+        UnaryMinus(Literal(Long.MinValue)), "ARITHMETIC_OVERFLOW",
+        Map("message" -> "long overflow", "alternative" -> "",
+          "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+      checkErrorInExpression[SparkArithmeticException](
+        UnaryMinus(Literal(Int.MinValue)), "ARITHMETIC_OVERFLOW",
+        Map("message" -> "integer overflow", "alternative" -> "",
+          "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+      checkErrorInExpression[SparkArithmeticException](
+        UnaryMinus(Literal(Short.MinValue)), "ARITHMETIC_OVERFLOW",
+        Map("message" -> "short overflow", "alternative" -> "",
+          "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+      checkErrorInExpression[SparkArithmeticException](
+        UnaryMinus(Literal(Byte.MinValue)), "ARITHMETIC_OVERFLOW",
+        Map("message" -> "byte overflow", "alternative" -> "",
+          "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
       checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
       checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
       checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to