This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 89765f5  [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal 
value overflow of sum aggregation
89765f5 is described below

commit 89765f556f26252aed1add71a9da84209ff03493
Author: Gengliang Wang <gengliang.w...@databricks.com>
AuthorDate: Thu Aug 13 03:52:12 2020 +0000

    [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow 
of sum aggregation
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/29125
    In branch 3.0:
    1. For hash aggregation: before https://github.com/apache/spark/pull/29125,
    there was a runtime exception on decimal overflow of sum aggregation; after
    https://github.com/apache/spark/pull/29125, there could be a wrong result.
    2. for sort aggregation, with/without 
https://github.com/apache/spark/pull/29125, there could be a wrong result on 
decimal overflow.
    
    In the master branch (the future 3.1 release), the problem doesn't exist,
    since https://github.com/apache/spark/pull/27627 added a flag marking whether
    an overflow happened in the aggregation buffer. However, the aggregation buffer
    is written in streaming checkpoints, so we can't change the aggregation buffer
    in branch 3.0 to resolve the issue.
    
    As there is no easy way in branch 3.0 to either return null or throw an
    exception on overflow depending on `spark.sql.ansi.enabled`, we have to make a
    choice here: always throw an exception on decimal value overflow of sum
    aggregation.

    ### Why are the changes needed?
    
    Avoid returning a wrong result from decimal sum aggregation when the sum overflows.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. An exception is now always thrown on decimal value overflow of sum
    aggregation, instead of possibly returning a wrong result.
    
    ### How was this patch tested?
    
    New unit test cases in DataFrameAggregateSuite.
    
    Closes #29404 from gengliangwang/fixSum.
    
    Authored-by: Gengliang Wang <gengliang.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 19 +++++++++--
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d2daaac..d442549 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,23 +71,36 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate with ImplicitCast
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
+    val sumWithChild = resultType match {
+      case d: DecimalType =>
+        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, 
nullOnOverflow = false)
+      case _ =>
+        coalesce(sum, zero) + child.cast(sumDataType)
+    }
+
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+        coalesce(sumWithChild, sum)
       )
     } else {
       Seq(
         /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
+        sumWithChild
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
+    val sumWithRight = resultType match {
+      case d: DecimalType =>
+        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow 
= false)
+
+      case _ => coalesce(sum.left, zero) + sum.right
+    }
     Seq(
       /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+      coalesce(sumWithRight, sum.left)
     )
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 54327b3..8c0358e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec, SortAggregateExec}
@@ -1044,6 +1045,42 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
+
+  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
+    val msg = intercept[SparkException] {
+      df.collect()
+    }.getCause.getMessage
+    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate 
phase") {
+    val decimalString = "1" + "0" * 19
+    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+    val hashAgg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(hashAgg)
+
+    val sortAgg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("a").as("str"),
+      lit("1").as("key")).groupBy("key")
+      .agg(sum($"d").alias("sumD"), 
min($"str").alias("minStr")).select($"sumD", $"minStr")
+    exceptionOnDecimalOverflow(sortAgg)
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation 
phase") {
+    val decimalString = "5" + "0" * 19
+    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
+      .union(spark.range(0, 1, 1, 1))
+    val agg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(agg)
+  }
 }
 
 case class B(c: Option[Double])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to