cloud-fan commented on a change in pull request #27627:
URL: https://github.com/apache/spark/pull/27627#discussion_r433128417



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
##########
@@ -200,14 +222,90 @@ class DataFrameSuite extends QueryTest
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
         val structDf = largeDecimals.select("a").agg(sum("a"))
-        if (!ansiEnabled) {
-          checkAnswer(structDf, Row(null))
-        } else {
-          val e = intercept[SparkException] {
-            structDf.collect
+        assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+      }
+    }
+  }
+
+  test("SPARK-28067: sum of null decimal values") {
+    Seq("true", "false").foreach { wholeStageEnabled =>
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+        Seq("true", "false").foreach { ansiEnabled =>
+          withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
+            val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+            checkAnswer(df.agg(sum($"d")), Row(null))
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-28067: Aggregate sum should not return wrong results for decimal 
overflow") {
+    Seq("true", "false").foreach { wholeStageEnabled =>
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+        Seq(true, false).foreach { ansiEnabled =>
+          withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+            val df0 = Seq(
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+            val df1 = Seq(
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+            val df = df0.union(df1)
+            val df2 = df.withColumnRenamed("decNum", "decNum2").
+              join(df, "intNum").agg(sum("decNum"))

Review comment:
       It's hard to understand this test with the joins. As I mentioned earlier in https://github.com/apache/spark/pull/27627#issuecomment-595067541, there is a way to reproduce the bug without a join. Can we simplify this test?
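
       For reference, a join-free check along those lines might look roughly like the following sketch (illustrative only, not the exact repro from the linked comment; the column name `d` and the row count are made up for the example, and whether a plain scan of this size exercises the same partial-aggregation path as the join version is an assumption):

           // 12 values of 1e19 sum to 1.2e20, which needs 21 integer digits
           // and so overflows the 20 integer digits available in the sum's
           // decimal(38,18) result type. With spark.sql.ansi.enabled=false
           // the correct answer is null rather than a silently wrapped value.
           val df = spark.range(1, 13)
             .select(expr("cast('10000000000000000000' as decimal(38,18)) as d"))
           checkAnswer(df.agg(sum($"d")), Row(null))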



