This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d9e3d9b  [SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown
d9e3d9b is described below

commit d9e3d9b9d97e7a238062060675913e29b9184cfb
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Dec 23 23:05:20 2021 +0800

    [SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown

    ### What changes were proposed in this pull request?
    Currently, Spark pushes down aggregates as a partial aggregate plus a final aggregate. For some data sources (e.g. JDBC), we can avoid the partial and final aggregates by running the aggregation completely in the database.

    ### Why are the changes needed?
    Improve performance for aggregate pushdown.

    ### Does this PR introduce _any_ user-facing change?
    'No'. It only changes the internal implementation.

    ### How was this patch tested?
    New tests.

    Closes #34904 from beliefer/SPARK-37644.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../connector/read/SupportsPushDownAggregates.java |   8 ++
 .../datasources/v2/V2ScanRelationPushDown.scala    | 101 +++++++++++++--------
 .../datasources/v2/jdbc/JDBCScanBuilder.scala      |   3 +
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 101 ++++++++++++++++++++-
 4 files changed, 173 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 3e643b5..4e6c59e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -46,6 +46,14 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
 public interface SupportsPushDownAggregates extends ScanBuilder {
 
   /**
+   * Whether the datasource supports complete aggregation push-down. Spark can avoid the partial
+   * and final aggregates when the aggregation operation can be pushed down to the datasource completely.
+   *
+   * @return true if the aggregation can be pushed down to the datasource completely, false otherwise.
+   */
+  default boolean supportCompletePushDown() { return false; }
+
+  /**
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
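A connector opts in to the new behavior by overriding supportCompletePushDown(); when it returns true, the planner (see the V2ScanRelationPushDown change below) drops the partial/final aggregate pair and keeps only a Project over the pushed-down scan. A minimal sketch of an implementing ScanBuilder in Scala follows; the class name, the `singlePartition` flag, and the placeholder schema are hypothetical and not part of this commit:

```scala
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
import org.apache.spark.sql.types.StructType

// Hypothetical connector: ExampleScanBuilder, singlePartition and the empty
// schema are illustrative only.
class ExampleScanBuilder(singlePartition: Boolean) extends SupportsPushDownAggregates {

  private var pushed: Option[Aggregation] = None

  // Complete push-down is only safe when the source itself returns the final
  // aggregated result, e.g. when the scan produces a single partition.
  override def supportCompletePushDown(): Boolean = singlePartition

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // A real connector would inspect the aggregate functions and grouping columns
    // in `aggregation` and return false for anything it cannot compute itself.
    pushed = Some(aggregation)
    true
  }

  // Per the pushAggregation contract, the scan output must be the grouping
  // columns followed by the aggregate columns.
  override def build(): Scan = new Scan {
    override def readSchema(): StructType = StructType(Nil) // placeholder schema
  }
}
```

As with the JDBC builder changed below, the natural condition for returning true is a single-partition scan, since with multiple partitions Spark still has to merge partial results across partitions.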
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index e7c06d0..3a792f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation @@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { @@ -131,7 +131,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b } - val output = groupAttrs ++ newOutput.drop(groupAttrs.length) + val aggOutput = newOutput.drop(groupAttrs.length) + val output = groupAttrs ++ aggOutput logInfo( s""" @@ -147,40 +148,59 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... - // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
- // scalastyle:on - val aggOutput = output.drop(groupAttrs.length) - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => max.copy(child = aggOutput(ordinal)) - case min: aggregate.Min => min.copy(child = aggOutput(ordinal)) - case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal)) - case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) + if (r.supportCompletePushDown()) { + val projectExpressions = resultExpressions.map { expr => + // TODO At present, only push down group by attribute is supported. + // In future, more attribute conversion is extended here. e.g. GetStructField + expr.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val child = + addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) + Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, scanRelation) + } else { + val plan = Aggregate( + output.take(groupingExpressions.length), resultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
+ // scalastyle:on + plan.transformExpressions { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) + case min: aggregate.Min => + min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) + case sum: aggregate.Sum => + sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) + case _: aggregate.Count => + aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) + case other => other + } + agg.copy(aggregateFunction = aggFunction) + } } } case _ => aggNode @@ -189,6 +209,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = + if (aggAttribute.dataType == aggDataType) { + aggAttribute + } else { + Cast(aggAttribute, aggDataType) + } + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 1760122..01722e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -72,6 +72,9 @@ case class JDBCScanBuilder( private var pushedGroupByCols: Option[Array[String]] = None + override def supportCompletePushDown: Boolean = + jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) + override def pushAggregation(aggregation: Aggregation): Boolean = { if (!jdbcOptions.pushDownAggregate) return false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5df875b..eb3c8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -388,6 +388,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } + private def checkAggregateRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + if (removed) { + assert(aggregates.isEmpty) + } else { + assert(aggregates.nonEmpty) + } + } + test("scan with aggregate push-down: MAX MIN with filter and group by") { val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") @@ -395,6 +406,7 @@ 
class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -412,6 +424,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -425,6 +438,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: aggregate + number") { val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -436,6 +450,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(*)") { val df = sql("select COUNT(*) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -447,6 +462,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(col)") { val df = sql("select COUNT(DEPT) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -458,6 +474,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(DISTINCT col)") { val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -469,6 +486,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: SUM without filer and group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -480,6 +498,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: DISTINCT SUM without filer and group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -491,6 +510,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: SUM with group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -504,6 +524,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: DISTINCT SUM with group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -518,10 +539,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel 
test("scan with aggregate push-down: with multiple group by columns") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT, NAME") - val filters11 = df.queryExecution.optimizedPlan.collect { + val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - assert(filters11.isEmpty) + assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -534,6 +556,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row(10000, 1000), Row(12000, 1200))) } + test("scan with aggregate push-down: with concat multiple group key in project") { + val df1 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) FROM h2.test.employee" + + " where dept > 0 group by DEPT, NAME") + val filters1 = df1.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters1.isEmpty) + checkAggregateRemoved(df1) + df1.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [MAX(SALARY)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT, NAME]" + checkKeywordsExistsInExplain(df1, expected_plan_fragment) + } + checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), + Row("2#david", 10000), Row("6#jen", 12000))) + + val df2 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by DEPT, NAME") + val filters2 = df2.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters2.isEmpty) + checkAggregateRemoved(df2) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT, NAME]" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) + + val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)") + val filters3 = df3.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters3.isEmpty) + checkAggregateRemoved(df3, false) + df3.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + checkKeywordsExistsInExplain(df3, expected_plan_fragment) + } + checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) + } + test("scan with aggregate push-down: with having clause") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT having MIN(BONUS) > 1000") @@ -541,6 +617,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f // filter over aggregate not push down } assert(filters.nonEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -556,6 +633,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("select * from 
h2.test.employee") .groupBy($"DEPT") .min("SALARY").as("total") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -579,6 +657,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.nonEmpty) // filter over aggregate not pushed down + checkAggregateRemoved(df) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -594,6 +673,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = spark.table("h2.test.employee") val decrease = udf { (x: Double, y: Double) => x - y } val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value")) + checkAggregateRemoved(query) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -607,6 +687,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val cols = Seq("a", "b", "c", "d") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") + checkAggregateRemoved(df2, false) df2.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => @@ -625,6 +706,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter("SALARY > 100") .filter(name($"shortName")) .agg(sum($"SALARY").as("sum_salary")) + checkAggregateRemoved(query, false) query.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => @@ -633,4 +715,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } checkAnswer(query, Seq(Row(29000.0))) } + + test("scan with aggregate push-down: SUM(CASE WHEN) with group by") { + val df = + sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [], " + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(1), Row(2), Row(2))) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org