This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 68dfff7  [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 
partial aggregate push-down `AVG`
68dfff7 is described below

commit 68dfff767aa3faaaf3c614514dcd8ef322256579
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Mar 31 19:18:58 2022 +0800

    [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate 
push-down `AVG`
    
    ### What changes were proposed in this pull request?
    https://github.com/apache/spark/pull/35130 supports partial aggregate 
push-down `AVG` for DS V2.
    The behavior isn't consistent with `Average` if overflow occurs in ANSI 
mode.
    This PR closely follows the implementation of `Average` to respect 
overflow in ANSI mode.
    
    ### Why are the changes needed?
    Make the behavior consistent with `Average` if overflow occurs in ANSI mode.
    
    ### Does this PR introduce _any_ user-facing change?
    'Yes'.
    Users will now see an overflow exception thrown in ANSI mode.
    
    ### How was this patch tested?
    New tests.
    
    Closes #35320 from beliefer/SPARK-37839_followup.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit e6839ad7340bc9eb5df03df2a62110bdda805e6b)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/expressions/aggregate/Average.scala   |  4 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    | 21 ++++------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 45 +++++++++++++++++++++-
 3 files changed, 52 insertions(+), 18 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 05f7eda..533f7f2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -76,8 +76,8 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sum = AttributeReference("sum", sumDataType)()
-  private lazy val count = AttributeReference("count", LongType)()
+  lazy val sum = AttributeReference("sum", sumDataType)()
+  lazy val count = AttributeReference("count", LongType)()
 
   override lazy val aggBufferAttributes = sum :: count :: Nil
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c8ef8b0..5371829 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, 
Attribute, AttributeReference, Cast, Divide, DivideDTInterval, 
DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, 
NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, 
SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, 
Attribute, AttributeReference, Cast, Expression, IntegerLiteral, 
NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, 
SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
@@ -32,7 +32,7 @@ import 
org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, 
StructType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper 
with AliasHelper {
@@ -138,18 +138,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] 
with PredicateHelper wit
                       case AggregateExpression(avg: aggregate.Average, _, 
isDistinct, _, _) =>
                         val sum = 
aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
                         val count = 
aggregate.Count(avg.child).toAggregateExpression(isDistinct)
-                        // Closely follow `Average.evaluateExpression`
-                        avg.dataType match {
-                          case _: YearMonthIntervalType =>
-                            If(EqualTo(count, Literal(0L)),
-                              Literal(null, YearMonthIntervalType()), 
DivideYMInterval(sum, count))
-                          case _: DayTimeIntervalType =>
-                            If(EqualTo(count, Literal(0L)),
-                              Literal(null, DayTimeIntervalType()), 
DivideDTInterval(sum, count))
-                          case _ =>
-                            // TODO deal with the overflow issue
-                            Divide(addCastIfNeeded(sum, avg.dataType),
-                              addCastIfNeeded(count, avg.dataType), false)
+                        avg.evaluateExpression transform {
+                          case a: Attribute if a.semanticEquals(avg.sum) =>
+                            addCastIfNeeded(sum, avg.sum.dataType)
+                          case a: Attribute if a.semanticEquals(avg.count) =>
+                            addCastIfNeeded(count, avg.count.dataType)
                         }
                     }
                   }.asInstanceOf[Seq[NamedExpression]]
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index a5e3a71..67a0290 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
         """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" 
INTEGER)""").executeUpdate()
       conn.prepareStatement(
         """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" 
INTEGER)""").executeUpdate()
+
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price 
NUMERIC(23, 3))")
+        .executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 11111111111111111111.123)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
     }
   }
 
@@ -484,8 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession 
with ExplainSuiteHel
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"),
       Seq(Row("test", "people", false), Row("test", "empty_table", false),
-        Row("test", "employee", false), Row("test", "dept", false), 
Row("test", "person", false),
-        Row("test", "view1", false), Row("test", "view2", false)))
+        Row("test", "employee", false), Row("test", "item", false), 
Row("test", "dept", false),
+        Row("test", "person", false), Row("test", "view1", false), Row("test", 
"view2", false)))
   }
 
   test("SQL API: create table as select") {
@@ -1106,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with 
SharedSparkSession with ExplainSuiteHel
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
   }
+
+  test("scan with aggregate push-down: partial push-down AVG with overflow") {
+    def createDataFrame: DataFrame = spark.read
+      .option("partitionColumn", "id")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.item")
+      .agg(avg($"PRICE").as("avg"))
+
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+        val df = createDataFrame
+        checkAggregateRemoved(df, false)
+        df.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+            checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        }
+        if (ansiEnabled) {
+          val e = intercept[SparkException] {
+            df.collect()
+          }
+          assert(e.getCause.isInstanceOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as 
Decimal") ||
+            e.getCause.getMessage.contains("Overflow in sum of decimals"))
+        } else {
+          checkAnswer(df, Seq(Row(null)))
+        }
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to