This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 89765f5 [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow of sum aggregation 89765f5 is described below commit 89765f556f26252aed1add71a9da84209ff03493 Author: Gengliang Wang <gengliang.w...@databricks.com> AuthorDate: Thu Aug 13 03:52:12 2020 +0000 [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow of sum aggregation ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/29125 In branch 3.0: 1. for hash aggregation, before https://github.com/apache/spark/pull/29125 there will be a runtime exception on decimal overflow of sum aggregation; after https://github.com/apache/spark/pull/29125, there could be a wrong result. 2. for sort aggregation, with/without https://github.com/apache/spark/pull/29125, there could be a wrong result on decimal overflow. While in master branch (the future 3.1 release), the problem doesn't exist since in https://github.com/apache/spark/pull/27627 there is a flag for marking whether overflow happens in aggregation buffer. However, the aggregation buffer is written in streaming checkpoints. Thus, we can't change the aggregation buffer to resolve the issue. As there is no easy solution for returning null/throwing exception regarding `spark.sql.ansi.enabled` on overflow in branch 3.0, we have to make a choice here: always throw exception on decimal value overflow of sum aggregation. ### Why are the changes needed? Avoid returning wrong result in decimal value sum aggregation. ### Does this PR introduce _any_ user-facing change? Yes, there is always exception on decimal value overflow of sum aggregation, instead of a possible wrong result. ### How was this patch tested? Unit test case Closes #29404 from gengliangwang/fixSum. 
Authored-by: Gengliang Wang <gengliang.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/expressions/aggregate/Sum.scala | 19 +++++++++-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d2daaac..d442549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -71,23 +71,36 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) override lazy val updateExpressions: Seq[Expression] = { + val sumWithChild = resultType match { + case d: DecimalType => + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false) + case _ => + coalesce(sum, zero) + child.cast(sumDataType) + } + if (child.nullable) { Seq( /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + coalesce(sumWithChild, sum) ) } else { Seq( /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) + sumWithChild ) } } override lazy val mergeExpressions: Seq[Expression] = { + val sumWithRight = resultType match { + case d: DecimalType => + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false) + + case _ => coalesce(sum.left, zero) + sum.right + } Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + coalesce(sumWithRight, sum.left) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 54327b3..8c0358e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.Matchers.the +import org.apache.spark.SparkException import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1044,6 +1045,42 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sql(queryTemplate("FIRST")), Row(1)) checkAnswer(sql(queryTemplate("LAST")), Row(3)) } + + private def exceptionOnDecimalOverflow(df: DataFrame): Unit = { + val msg = intercept[SparkException] { + df.collect() + }.getCause.getMessage + assert(msg.contains("cannot be represented as Decimal(38, 18)")) + } + + test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") { + val decimalString = "1" + "0" * 19 + val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val hashAgg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key")) + .groupBy("key") + .agg(sum($"d").alias("sumD")) + .select($"sumD") + exceptionOnDecimalOverflow(hashAgg) + + val sortAgg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"), + lit("1").as("key")).groupBy("key") + .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr") + exceptionOnDecimalOverflow(sortAgg) + } + + test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") { + val decimalString = "5" + "0" * 19 + val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1)) + .union(spark.range(0, 1, 1, 1)) + val agg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key")) + .groupBy("key") + .agg(sum($"d").alias("sumD")) + .select($"sumD") + exceptionOnDecimalOverflow(agg) + } } case class B(c: Option[Double]) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org