This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 42721120f3c [SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on integer overflow
42721120f3c is described below

commit 42721120f3c7206a9fc22db5d0bb7cf40f0cacfd
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Fri Jan 13 09:40:36 2023 -0800

    [SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on integer overflow
    
    ### What changes were proposed in this pull request?
    
    In ANSI SQL mode, Round/Bround should return an error on integer overflow.
    Note this PR is for the integer type only. Once it is merged, I will create a follow-up PR for the remaining integral types: byte, short, and long.
    Also, the functions ceil and floor accept decimal-type input, so there is no need to change them.
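    Conceptually, the integer path now uses BigDecimal's exact narrowing conversion (toIntExact / intValueExact), wrapped in MathUtils.withOverflow, instead of the silently truncating toInt. A minimal sketch of that difference in plain Scala, outside of Spark:
    
    ```scala
    import scala.math.BigDecimal.RoundingMode
    
    // round(2147483647, -1): the exact result is 2147483650, which does not fit in an Int
    val rounded = BigDecimal(2147483647).setScale(-1, RoundingMode.HALF_UP)
    
    rounded.toInt      // silently wraps to -2147483646 (the old, non-ANSI behaviour)
    rounded.toIntExact // throws java.lang.ArithmeticException: Overflow, which ANSI mode now reports
    ```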
    
    ### Why are the changes needed?
    
    In ANSI SQL mode, integer overflow should cause an error instead of returning an unreasonable result.
    For example, `round(2147483647, -1)` should return an error instead of returning `-2147483646`: the mathematically correct result, 2147483650, exceeds Int.MaxValue and wraps around in 32-bit arithmetic.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, in ANSI SQL mode, the SQL functions round and bround will return an error on integer overflow.
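    For illustration, a spark-shell sketch of the before/after behaviour (the error class and message correspond to the ARITHMETIC_OVERFLOW entries in the generated test output below):
    
    ```scala
    spark.conf.set("spark.sql.ansi.enabled", true)
    spark.sql("SELECT round(2147483647, -1)").show()
    // throws org.apache.spark.SparkArithmeticException: [ARITHMETIC_OVERFLOW] Overflow. ...
    
    spark.conf.set("spark.sql.ansi.enabled", false)
    spark.sql("SELECT round(2147483647, -1)").show()
    // shows -2147483646, the legacy wrap-around result
    ```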
    
    ### How was this patch tested?
    
    UT
    
    Closes #39546 from gengliangwang/fixRound.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../sql/catalyst/expressions/mathExpressions.scala |  60 +++++--
 .../apache/spark/sql/catalyst/util/MathUtils.scala |  12 +-
 .../expressions/MathExpressionsSuite.scala         |  15 +-
 .../catalyst/util/PhysicalAggregationSuite.scala   |   2 +-
 .../test/resources/sql-tests/inputs/ansi/math.sql  |   1 +
 .../src/test/resources/sql-tests/inputs/math.sql   |  17 ++
 .../resources/sql-tests/results/ansi/math.sql.out  | 175 +++++++++++++++++++++
 .../test/resources/sql-tests/results/math.sql.out  | 111 +++++++++++++
 8 files changed, 381 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 9ffc148180a..50a1194c2f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -26,8 +26,10 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1447,11 +1449,13 @@ case class Logarithm(left: Expression, right: Expression)
  */
 abstract class RoundBase(child: Expression, scale: Expression,
     mode: BigDecimal.RoundingMode.Value, modeStr: String)
-  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+  extends BinaryExpression with Serializable with ImplicitCastInputTypes with SupportQueryContext {
 
   override def left: Expression = child
   override def right: Expression = scale
 
+  protected def ansiEnabled: Boolean = false
+
   // round of Decimal would eval to null if it fails to `changePrecision`
   override def nullable: Boolean = true
 
@@ -1501,6 +1505,14 @@ abstract class RoundBase(child: Expression, scale: Expression,
   private lazy val scaleV: Any = scale.eval(EmptyRow)
   protected lazy val _scale: Int = scaleV.asInstanceOf[Int]
 
+  override def initQueryContext(): Option[SQLQueryContext] = {
+    if (ansiEnabled) {
+      Some(origin.context)
+    } else {
+      None
+    }
+  }
+
   override def eval(input: InternalRow): Any = {
     if (scaleV == null) { // if scale is null, no need to eval its child at all
       null
@@ -1529,6 +1541,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
       case ShortType =>
         BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
+      case IntegerType if ansiEnabled =>
+        MathUtils.withOverflow(
+          f = BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toIntExact,
+          context = getContextOrNull)
       case IntegerType =>
         BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
       case LongType =>
@@ -1584,9 +1600,19 @@ abstract class RoundBase(child: Expression, scale: Expression,
         }
       case IntegerType =>
         if (_scale < 0) {
-          s"""
-          ${ev.value} = new java.math.BigDecimal(${ce.value}).
-            setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
+          if (ansiEnabled) {
+            val errorContext = getContextOrNullCode(ctx)
+            val evalCode = s"""
+              |${ev.value} = new java.math.BigDecimal(${ce.value}).
+              |setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValueExact();
+              |""".stripMargin
+            MathUtils.withOverflowCode(evalCode, errorContext)
+          } else {
+            s"""
+               |${ev.value} = new java.math.BigDecimal(${ce.value}).
+               |setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();
+               |""".stripMargin
+          }
         } else {
           s"${ev.value} = ${ce.value};"
         }
@@ -1648,9 +1674,17 @@ abstract class RoundBase(child: Expression, scale: Expression,
   since = "1.5.0",
   group = "math_funcs")
 // scalastyle:on line.size.limit
-case class Round(child: Expression, scale: Expression)
+case class Round(
+    child: Expression,
+    scale: Expression,
+    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") {
-  def this(child: Expression) = this(child, Literal(0))
+  def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled)
+
+  def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled)
+
+  override def flatArguments: Iterator[Any] = Iterator(child, scale)
+
   override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round =
     copy(child = newLeft, scale = newRight)
 }
@@ -1673,9 +1707,17 @@ case class Round(child: Expression, scale: Expression)
   since = "2.0.0",
   group = "math_funcs")
 // scalastyle:on line.size.limit
-case class BRound(child: Expression, scale: Expression)
+case class BRound(
+    child: Expression,
+    scale: Expression,
+    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") {
-  def this(child: Expression) = this(child, Literal(0))
+  def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled)
+
+  def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled)
+
+  override def flatArguments: Iterator[Any] = Iterator(child, scale)
+
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index e79e483076d..b285b1df572 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -75,7 +75,7 @@ object MathUtils {
 
   def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
 
-  private def withOverflow[A](
+  def withOverflow[A](
       f: => A,
       hint: String = "",
       context: SQLQueryContext = null): A = {
@@ -86,4 +86,14 @@ object MathUtils {
         throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint, context)
     }
   }
+
+  def withOverflowCode(evalCode: String, context: String): String = {
+    s"""
+       |try {
+       |  $evalCode
+       |} catch (ArithmeticException e) {
+       |  throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", $context);
+       |}
+       |""".stripMargin
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index c78d72e7a98..92b683a7106 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit
 
 import com.google.common.math.LongMath
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkArithmeticException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.implicitCast
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -838,6 +838,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200))
   }
 
+  test("SPARK-42045: integer overflow in round/bround") {
+    val input = 2147483647
+    val scale = -1
+    Seq(Round(input, scale, ansiEnabled = true),
+      BRound(input, scale, ansiEnabled = true)).foreach { expr =>
+      checkExceptionInExpression[SparkArithmeticException](expr, "Overflow")
+    }
+    Seq(Round(input, scale, ansiEnabled = false),
+      BRound(input, scale, ansiEnabled = false)).foreach { expr =>
+      checkEvaluation(expr, -2147483646)
+    }
+  }
+
   test("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") {
     checkEvaluation(Signum(Literal(Period.ZERO)), 0.0)
     checkEvaluation(Signum(Literal(Period.ofYears(10))), 1.0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
index cf9c9490fab..c0db9c61388 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
@@ -48,7 +48,7 @@ class PhysicalAggregationSuite extends PlanTest {
 
     // Verify that Round's scale parameter is a Literal.
     resultExpressions(1) match {
-      case Alias(Round(_, _: Literal), _) =>
+      case Alias(Round(_, _: Literal, _), _) =>
       case other => fail("unexpected result expression: " + other)
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
new file mode 100644
index 00000000000..5ee19c28ca6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
@@ -0,0 +1 @@
+--IMPORT math.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql b/sql/core/src/test/resources/sql-tests/inputs/math.sql
new file mode 100644
index 00000000000..df7210c4595
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -0,0 +1,17 @@
+-- Round with integer input
+SELECT round(525, 1);
+SELECT round(525, 0);
+SELECT round(525, -1);
+SELECT round(525, -2);
+SELECT round(525, -3);
+SELECT round(2147483647, -1);
+SELECT round(-2147483647, -1);
+
+-- BRound with integer input
+SELECT bround(525, 1);
+SELECT bround(525, 0);
+SELECT bround(525, -1);
+SELECT bround(525, -2);
+SELECT bround(525, -3);
+SELECT bround(2147483647, -1);
+SELECT bround(-2147483647, -1);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
new file mode 100644
index 00000000000..e7866b59047
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -0,0 +1,175 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT round(525, 1)
+-- !query schema
+struct<round(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, 0)
+-- !query schema
+struct<round(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, -1)
+-- !query schema
+struct<round(525, -1):int>
+-- !query output
+530
+
+
+-- !query
+SELECT round(525, -2)
+-- !query schema
+struct<round(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT round(525, -3)
+-- !query schema
+struct<round(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT round(2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 28,
+    "fragment" : "round(2147483647, -1)"
+  } ]
+}
+
+
+-- !query
+SELECT round(-2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 29,
+    "fragment" : "round(-2147483647, -1)"
+  } ]
+}
+
+
+-- !query
+SELECT bround(525, 1)
+-- !query schema
+struct<bround(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, 0)
+-- !query schema
+struct<bround(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, -1)
+-- !query schema
+struct<bround(525, -1):int>
+-- !query output
+520
+
+
+-- !query
+SELECT bround(525, -2)
+-- !query schema
+struct<bround(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT bround(525, -3)
+-- !query schema
+struct<bround(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT bround(2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 29,
+    "fragment" : "bround(2147483647, -1)"
+  } ]
+}
+
+
+-- !query
+SELECT bround(-2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 30,
+    "fragment" : "bround(-2147483647, -1)"
+  } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out b/sql/core/src/test/resources/sql-tests/results/math.sql.out
new file mode 100644
index 00000000000..693ce3e8cbf
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out
@@ -0,0 +1,111 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT round(525, 1)
+-- !query schema
+struct<round(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, 0)
+-- !query schema
+struct<round(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, -1)
+-- !query schema
+struct<round(525, -1):int>
+-- !query output
+530
+
+
+-- !query
+SELECT round(525, -2)
+-- !query schema
+struct<round(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT round(525, -3)
+-- !query schema
+struct<round(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT round(2147483647, -1)
+-- !query schema
+struct<round(2147483647, -1):int>
+-- !query output
+-2147483646
+
+
+-- !query
+SELECT round(-2147483647, -1)
+-- !query schema
+struct<round(-2147483647, -1):int>
+-- !query output
+2147483646
+
+
+-- !query
+SELECT bround(525, 1)
+-- !query schema
+struct<bround(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, 0)
+-- !query schema
+struct<bround(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, -1)
+-- !query schema
+struct<bround(525, -1):int>
+-- !query output
+520
+
+
+-- !query
+SELECT bround(525, -2)
+-- !query schema
+struct<bround(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT bround(525, -3)
+-- !query schema
+struct<bround(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT bround(2147483647, -1)
+-- !query schema
+struct<bround(2147483647, -1):int>
+-- !query output
+-2147483646
+
+
+-- !query
+SELECT bround(-2147483647, -1)
+-- !query schema
+struct<bround(-2147483647, -1):int>
+-- !query output
+2147483646

