This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 68dfff7  [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate push-down `AVG`
68dfff7 is described below

commit 68dfff767aa3faaaf3c614514dcd8ef322256579
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Mar 31 19:18:58 2022 +0800

    [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate push-down `AVG`

    ### What changes were proposed in this pull request?

    https://github.com/apache/spark/pull/35130 added support for partial aggregate push-down of `AVG` for DS V2. However, its behavior is not consistent with `Average` when overflow occurs in ANSI mode. This PR closely follows the implementation of `Average` so that overflow is respected in ANSI mode.

    ### Why are the changes needed?

    Make the behavior consistent with `Average` when overflow occurs in ANSI mode.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Users will now see the overflow exception thrown in ANSI mode.

    ### How was this patch tested?

    New tests.

    Closes #35320 from beliefer/SPARK-37839_followup.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit e6839ad7340bc9eb5df03df2a62110bdda805e6b)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/expressions/aggregate/Average.scala   |  4 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    | 21 ++++------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 45 +++++++++++++++++++++-
 3 files changed, 52 insertions(+), 18 deletions(-)
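For context: partial push-down rewrites `AVG(price)` into `SUM(price)` and `COUNT(price)` at the data source and recombines them on the Spark side, so the recombination must apply the same ANSI overflow handling as `Average` itself. The following standalone sketch (not part of the patch; the object name and local session setup are illustrative) shows the `Average` behavior the push-down path now matches, using values whose internal sum buffer overflows `DECIMAL(38, 0)`:

```scala
import scala.util.Try

import org.apache.spark.sql.SparkSession

object AvgOverflowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("avg-overflow-sketch")
      .getOrCreate()

    // Two values at the DECIMAL(38, 0) maximum: their sum needs 39 digits,
    // which overflows the DECIMAL(38, 0) sum buffer used by Average.
    val query = "SELECT AVG(col) FROM VALUES " +
      "(99999999999999999999999999999999999999BD), " +
      "(99999999999999999999999999999999999999BD) AS t(col)"

    spark.conf.set("spark.sql.ansi.enabled", "false")
    // Without ANSI, the overflow silently degrades to null:
    // prints a Success containing [null].
    println(Try(spark.sql(query).collect().toSeq))

    spark.conf.set("spark.sql.ansi.enabled", "true")
    // With ANSI, the job fails: prints a Failure whose cause is an
    // ArithmeticException, mirroring what the new test expects.
    println(Try(spark.sql(query).collect().toSeq))

    spark.stop()
  }
}
```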
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 05f7eda..533f7f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -76,8 +76,8 @@ case class Average(
     case _ => DoubleType
   }
 
-  private lazy val sum = AttributeReference("sum", sumDataType)()
-  private lazy val count = AttributeReference("count", LongType)()
+  lazy val sum = AttributeReference("sum", sumDataType)()
+  lazy val count = AttributeReference("count", LongType)()
 
   override lazy val aggBufferAttributes = sum :: count :: Nil
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c8ef8b0..5371829 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
@@ -138,18 +138,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
             val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
             val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
-            // Closely follow `Average.evaluateExpression`
-            avg.dataType match {
-              case _: YearMonthIntervalType =>
-                If(EqualTo(count, Literal(0L)),
-                  Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
-              case _: DayTimeIntervalType =>
-                If(EqualTo(count, Literal(0L)),
-                  Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
-              case _ =>
-                // TODO deal with the overflow issue
-                Divide(addCastIfNeeded(sum, avg.dataType),
-                  addCastIfNeeded(count, avg.dataType), false)
+            avg.evaluateExpression transform {
+              case a: Attribute if a.semanticEquals(avg.sum) =>
+                addCastIfNeeded(sum, avg.sum.dataType)
+              case a: Attribute if a.semanticEquals(avg.count) =>
+                addCastIfNeeded(count, avg.count.dataType)
             }
         }
       }.asInstanceOf[Seq[NamedExpression]]
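The core of the fix is visible in the last hunk above: instead of special-casing interval types and building a raw `Divide` with overflow checking disabled (the old `false` argument), the rule now reuses `Average.evaluateExpression`, which already encodes the null-safe, ANSI-aware division, and substitutes the pushed-down `SUM`/`COUNT` for the `sum`/`count` buffer attributes that the `Average.scala` hunk makes non-private. A sketch of that substitution pattern in isolation (the column name and stand-in attributes are illustrative; assumes the catalyst classes are on the classpath, e.g. in `spark-shell`):

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.types.DecimalType

// An Average over a decimal column, as the push-down rule sees it.
val price = AttributeReference("PRICE", DecimalType(23, 3))()
val avg = Average(price)

// Stand-ins for the SUM/COUNT columns coming back from the data source scan.
val pushedSum = AttributeReference("SUM(PRICE)", avg.sum.dataType)()
val pushedCount = AttributeReference("COUNT(PRICE)", avg.count.dataType)()

// evaluateExpression divides the sum buffer by the count buffer with
// ANSI-aware overflow handling baked in; swapping the buffer attributes
// wires in the pushed aggregates without duplicating that logic.
val finalAvg = avg.evaluateExpression transform {
  case a: Attribute if a.semanticEquals(avg.sum) => pushedSum
  case a: Attribute if a.semanticEquals(avg.count) => pushedCount
}

println(finalAvg) // the overflow-checked division over the pushed columns
```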
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index a5e3a71..67a0290 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate()
       conn.prepareStatement(
         """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate()
+
+      conn.prepareStatement(
+        "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))")
+        .executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 11111111111111111111.123)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+        "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
     }
   }
 
@@ -484,8 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"),
       Seq(Row("test", "people", false), Row("test", "empty_table", false),
-        Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false),
-        Row("test", "view1", false), Row("test", "view2", false)))
+        Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
+        Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false)))
   }
 
   test("SQL API: create table as select") {
@@ -1106,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
       Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
   }
+
+  test("scan with aggregate push-down: partial push-down AVG with overflow") {
+    def createDataFrame: DataFrame = spark.read
+      .option("partitionColumn", "id")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.item")
+      .agg(avg($"PRICE").as("avg"))
+
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+        val df = createDataFrame
+        checkAggregateRemoved(df, false)
+        df.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+            checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        }
+        if (ansiEnabled) {
+          val e = intercept[SparkException] {
+            df.collect()
+          }
+          assert(e.getCause.isInstanceOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+            e.getCause.getMessage.contains("Overflow in sum of decimals"))
+        } else {
+          checkAnswer(df, Seq(Row(null)))
+        }
+      }
+    }
+  }
 }
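To see the partial push-down the test asserts on, the optimized plan can be inspected directly. A hedged `spark-shell` sketch, assuming the H2 catalog is registered as `h2` and a `test.item` table like the suite's fixture exists:

```scala
import org.apache.spark.sql.functions.avg
import spark.implicits._

val df = spark.read.table("h2.test.item").agg(avg($"PRICE").as("avg"))

// The scan node should report the partially pushed aggregates, e.g.
//   PushedAggregates: [SUM(PRICE), COUNT(PRICE)]
// while the Aggregate above it recombines them with the overflow-aware
// Average expression. With spark.sql.ansi.enabled=true, df.collect() then
// raises the SparkException (caused by an ArithmeticException) that the
// new test intercepts; with it off, the result is Row(null).
df.explain()
```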