This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 42721120f3c [SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on integer overflow 42721120f3c is described below commit 42721120f3c7206a9fc22db5d0bb7cf40f0cacfd Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Fri Jan 13 09:40:36 2023 -0800 [SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on integer overflow ### What changes were proposed in this pull request? In ANSI SQL mode, Round/Bround should return an error on integer overflow. Note that this PR is for the integer type only. Once it is merged, I will create one follow-up PR for the remaining integral types: byte, short, and long. Also, the functions ceil and floor accept decimal type input, so there is no need to change them. ### Why are the changes needed? In ANSI SQL mode, integer overflow should cause an error instead of returning an unreasonable result. For example, `round(2147483647, -1)` should return an error instead of returning `-2147483646`. ### Does this PR introduce _any_ user-facing change? Yes, in ANSI SQL mode, the SQL functions Round and Bround will return an error on integer overflow. ### How was this patch tested? Unit tests (UT). Closes #39546 from gengliangwang/fixRound. 
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../sql/catalyst/expressions/mathExpressions.scala | 60 +++++-- .../apache/spark/sql/catalyst/util/MathUtils.scala | 12 +- .../expressions/MathExpressionsSuite.scala | 15 +- .../catalyst/util/PhysicalAggregationSuite.scala | 2 +- .../test/resources/sql-tests/inputs/ansi/math.sql | 1 + .../src/test/resources/sql-tests/inputs/math.sql | 17 ++ .../resources/sql-tests/results/ansi/math.sql.out | 175 +++++++++++++++++++++ .../test/resources/sql-tests/results/math.sql.out | 111 +++++++++++++ 8 files changed, 381 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 9ffc148180a..50a1194c2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -26,8 +26,10 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1447,11 +1449,13 @@ case class Logarithm(left: Expression, right: Expression) */ abstract class RoundBase(child: Expression, scale: Expression, mode: BigDecimal.RoundingMode.Value, modeStr: String) - 
extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with Serializable with ImplicitCastInputTypes with SupportQueryContext { override def left: Expression = child override def right: Expression = scale + protected def ansiEnabled: Boolean = false + // round of Decimal would eval to null if it fails to `changePrecision` override def nullable: Boolean = true @@ -1501,6 +1505,14 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] + override def initQueryContext(): Option[SQLQueryContext] = { + if (ansiEnabled) { + Some(origin.context) + } else { + None + } + } + override def eval(input: InternalRow): Any = { if (scaleV == null) { // if scale is null, no need to eval its child at all null @@ -1529,6 +1541,10 @@ abstract class RoundBase(child: Expression, scale: Expression, BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort + case IntegerType if ansiEnabled => + MathUtils.withOverflow( + f = BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toIntExact, + context = getContextOrNull) case IntegerType => BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt case LongType => @@ -1584,9 +1600,19 @@ abstract class RoundBase(child: Expression, scale: Expression, } case IntegerType => if (_scale < 0) { - s""" - ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();""" + if (ansiEnabled) { + val errorContext = getContextOrNullCode(ctx) + val evalCode = s""" + |${ev.value} = new java.math.BigDecimal(${ce.value}). 
+ |setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValueExact(); + |""".stripMargin + MathUtils.withOverflowCode(evalCode, errorContext) + } else { + s""" + |${ev.value} = new java.math.BigDecimal(${ce.value}). + |setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue(); + |""".stripMargin + } } else { s"${ev.value} = ${ce.value};" } @@ -1648,9 +1674,17 @@ abstract class RoundBase(child: Expression, scale: Expression, since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Round(child: Expression, scale: Expression) +case class Round( + child: Expression, + scale: Expression, + override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") { - def this(child: Expression) = this(child, Literal(0)) + def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled) + + def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled) + + override def flatArguments: Iterator[Any] = Iterator(child, scale) + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round = copy(child = newLeft, scale = newRight) } @@ -1673,9 +1707,17 @@ case class Round(child: Expression, scale: Expression) since = "2.0.0", group = "math_funcs") // scalastyle:on line.size.limit -case class BRound(child: Expression, scale: Expression) +case class BRound( + child: Expression, + scale: Expression, + override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") { - def this(child: Expression) = this(child, Literal(0)) + def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled) + + def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled) + + override def flatArguments: Iterator[Any] = Iterator(child, scale) + override protected def withNewChildrenInternal( 
newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index e79e483076d..b285b1df572 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -75,7 +75,7 @@ object MathUtils { def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b)) - private def withOverflow[A]( + def withOverflow[A]( f: => A, hint: String = "", context: SQLQueryContext = null): A = { @@ -86,4 +86,14 @@ object MathUtils { throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint, context) } } + + def withOverflowCode(evalCode: String, context: String): String = { + s""" + |try { + | $evalCode + |} catch (ArithmeticException e) { + | throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", $context); + |} + |""".stripMargin + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index c78d72e7a98..92b683a7106 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit import com.google.common.math.LongMath -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkArithmeticException, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCoercion.implicitCast import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -838,6 +838,19 @@ class MathExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200)) } + test("SPARK-42045: integer overflow in round/bround") { + val input = 2147483647 + val scale = -1 + Seq(Round(input, scale, ansiEnabled = true), + BRound(input, scale, ansiEnabled = true)).foreach { expr => + checkExceptionInExpression[SparkArithmeticException](expr, "Overflow") + } + Seq(Round(input, scale, ansiEnabled = false), + BRound(input, scale, ansiEnabled = false)).foreach { expr => + checkEvaluation(expr, -2147483646) + } + } + test("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") { checkEvaluation(Signum(Literal(Period.ZERO)), 0.0) checkEvaluation(Signum(Literal(Period.ofYears(10))), 1.0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala index cf9c9490fab..c0db9c61388 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala @@ -48,7 +48,7 @@ class PhysicalAggregationSuite extends PlanTest { // Verify that Round's scale parameter is a Literal. 
resultExpressions(1) match { - case Alias(Round(_, _: Literal), _) => + case Alias(Round(_, _: Literal, _), _) => case other => fail("unexpected result expression: " + other) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql new file mode 100644 index 00000000000..5ee19c28ca6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql @@ -0,0 +1 @@ +--IMPORT math.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql b/sql/core/src/test/resources/sql-tests/inputs/math.sql new file mode 100644 index 00000000000..df7210c4595 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql @@ -0,0 +1,17 @@ +-- Round with integer input +SELECT round(525, 1); +SELECT round(525, 0); +SELECT round(525, -1); +SELECT round(525, -2); +SELECT round(525, -3); +SELECT round(2147483647, -1); +SELECT round(-2147483647, -1); + +-- BRound with integer input +SELECT bround(525, 1); +SELECT bround(525, 0); +SELECT bround(525, -1); +SELECT bround(525, -2); +SELECT bround(525, -3); +SELECT bround(2147483647, -1); +SELECT bround(-2147483647, -1); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out new file mode 100644 index 00000000000..e7866b59047 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out @@ -0,0 +1,175 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT round(525, 1) +-- !query schema +struct<round(525, 1):int> +-- !query output +525 + + +-- !query +SELECT round(525, 0) +-- !query schema +struct<round(525, 0):int> +-- !query output +525 + + +-- !query +SELECT round(525, -1) +-- !query schema +struct<round(525, -1):int> +-- !query output +530 + + +-- !query +SELECT round(525, -2) +-- !query schema +struct<round(525, -2):int> +-- !query output +500 + + +-- !query +SELECT round(525, 
-3) +-- !query schema +struct<round(525, -3):int> +-- !query output +1000 + + +-- !query +SELECT round(2147483647, -1) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "alternative" : "", + "config" : "\"spark.sql.ansi.enabled\"", + "message" : "Overflow" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 28, + "fragment" : "round(2147483647, -1)" + } ] +} + + +-- !query +SELECT round(-2147483647, -1) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "alternative" : "", + "config" : "\"spark.sql.ansi.enabled\"", + "message" : "Overflow" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 29, + "fragment" : "round(-2147483647, -1)" + } ] +} + + +-- !query +SELECT bround(525, 1) +-- !query schema +struct<bround(525, 1):int> +-- !query output +525 + + +-- !query +SELECT bround(525, 0) +-- !query schema +struct<bround(525, 0):int> +-- !query output +525 + + +-- !query +SELECT bround(525, -1) +-- !query schema +struct<bround(525, -1):int> +-- !query output +520 + + +-- !query +SELECT bround(525, -2) +-- !query schema +struct<bround(525, -2):int> +-- !query output +500 + + +-- !query +SELECT bround(525, -3) +-- !query schema +struct<bround(525, -3):int> +-- !query output +1000 + + +-- !query +SELECT bround(2147483647, -1) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "alternative" : "", + "config" : "\"spark.sql.ansi.enabled\"", + "message" : "Overflow" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 29, + 
"fragment" : "bround(2147483647, -1)" + } ] +} + + +-- !query +SELECT bround(-2147483647, -1) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "ARITHMETIC_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "alternative" : "", + "config" : "\"spark.sql.ansi.enabled\"", + "message" : "Overflow" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "bround(-2147483647, -1)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out b/sql/core/src/test/resources/sql-tests/results/math.sql.out new file mode 100644 index 00000000000..693ce3e8cbf --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out @@ -0,0 +1,111 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT round(525, 1) +-- !query schema +struct<round(525, 1):int> +-- !query output +525 + + +-- !query +SELECT round(525, 0) +-- !query schema +struct<round(525, 0):int> +-- !query output +525 + + +-- !query +SELECT round(525, -1) +-- !query schema +struct<round(525, -1):int> +-- !query output +530 + + +-- !query +SELECT round(525, -2) +-- !query schema +struct<round(525, -2):int> +-- !query output +500 + + +-- !query +SELECT round(525, -3) +-- !query schema +struct<round(525, -3):int> +-- !query output +1000 + + +-- !query +SELECT round(2147483647, -1) +-- !query schema +struct<round(2147483647, -1):int> +-- !query output +-2147483646 + + +-- !query +SELECT round(-2147483647, -1) +-- !query schema +struct<round(-2147483647, -1):int> +-- !query output +2147483646 + + +-- !query +SELECT bround(525, 1) +-- !query schema +struct<bround(525, 1):int> +-- !query output +525 + + +-- !query +SELECT bround(525, 0) +-- !query schema +struct<bround(525, 0):int> +-- !query output +525 + + +-- !query +SELECT bround(525, -1) +-- !query schema +struct<bround(525, -1):int> +-- !query output +520 + + +-- !query +SELECT 
bround(525, -2) +-- !query schema +struct<bround(525, -2):int> +-- !query output +500 + + +-- !query +SELECT bround(525, -3) +-- !query schema +struct<bround(525, -3):int> +-- !query output +1000 + + +-- !query +SELECT bround(2147483647, -1) +-- !query schema +struct<bround(2147483647, -1):int> +-- !query output +-2147483646 + + +-- !query +SELECT bround(-2147483647, -1) +-- !query schema +struct<bround(-2147483647, -1):int> +-- !query output +2147483646 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org