This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aec79534abf [SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT
aec79534abf is described below

commit aec79534abf819e7981babc73d13450ea8e49b08
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Jul 20 11:13:08 2022 +0800

    [SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT
    
    ### What changes were proposed in this pull request?
    
    This PR refactors the v2 agg pushdown code. The main change is that we no
    longer build the `Scan` immediately when pushing down an aggregate. We did
    so before because we wanted to know the data schema with the aggregate
    pushed, so that we could add casts when rewriting the query plan after
    pushdown. The problem is that the `Scan` is built too early, so no further
    operators can be pushed down, while it's common to see LIMIT/OFFSET after
    an aggregate.
    
    The idea of the refactor is that we don't need to know the data schema with
    the aggregate pushed. We just set an expectation (the data types should be
    the same as those of the Spark aggregate functions), use it to define the
    output of `ScanBuilderHolder`, and then rewrite the query plan. Later on,
    when we build the `Scan` and replace `ScanBuilderHolder` with
    `DataSourceV2ScanRelation`, we check the actual data schema and add a
    `Project` to cast types if necessary.
    
    ### Why are the changes needed?
    
    To support pushing down LIMIT/OFFSET after an aggregate.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Updated tests in `JDBCV2Suite`.
    
    Closes #37195 from cloud-fan/agg.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../datasources/v2/V2ScanRelationPushDown.scala    | 419 +++++++++++----------
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |  38 +-
 2 files changed, 254 insertions(+), 203 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 8951c37e127..f1e0e6d80c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
@@ -44,6 +44,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       pushDownFilters,
       pushDownAggregates,
       pushDownLimitAndOffset,
+      buildScanWithPushedAggregate,
       pruneColumns)
 
     pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -92,189 +93,201 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
 
   def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
-    case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
-      child match {
-        case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
-          if filters.isEmpty && CollapseProject.canCollapseExpressions(
-            resultExpressions, project, alwaysInline = true) =>
-          sHolder.builder match {
-            case r: SupportsPushDownAggregates =>
-              val aliasMap = getAliasMap(project)
-              val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
-              val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap))
-
-              val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-              val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
-              val normalizedAggregates = DataSourceStrategy.normalizeExprs(
-                aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
-              val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
-                actualGroupExprs, sHolder.relation.output)
-              val translatedAggregates = DataSourceStrategy.translateAggregation(
-                normalizedAggregates, normalizedGroupingExpressions)
-              val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
-                if (translatedAggregates.isEmpty ||
-                  r.supportCompletePushDown(translatedAggregates.get) ||
-                  translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
-                  (actualResultExprs, aggregates, translatedAggregates)
-                } else {
-                  // scalastyle:off
-                  // The data source doesn't support the complete push-down of this aggregation.
-                  // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
-                  // pushed, completely or partially.
-                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-                  // SELECT avg(c1) FROM t GROUP BY c2;
-                  // The original logical plan is
-                  // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
-                  // +- ScanOperation[...]
-                  //
-                  // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
-                  // we have the following
-                  // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
-                  // +- ScanOperation[...]
-                  // scalastyle:on
-                  val newResultExpressions = actualResultExprs.map { expr =>
-                    expr.transform {
-                      case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
-                        val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
-                        val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
-                        avg.evaluateExpression transform {
-                          case a: Attribute if a.semanticEquals(avg.sum) =>
-                            addCastIfNeeded(sum, avg.sum.dataType)
-                          case a: Attribute if a.semanticEquals(avg.count) =>
-                            addCastIfNeeded(count, avg.count.dataType)
-                        }
-                    }
-                  }.asInstanceOf[Seq[NamedExpression]]
-                  // Because aggregate expressions changed, translate them again.
-                  aggExprToOutputOrdinal.clear()
-                  val newAggregates =
-                    collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
-                  val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
-                    newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
-                  (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
-                    newNormalizedAggregates, normalizedGroupingExpressions))
+    case agg: Aggregate => rewriteAggregate(agg)
+  }
+
+  private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
+    case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
+        r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions(
+        agg.aggregateExpressions, project, alwaysInline = true) =>
+      val aliasMap = getAliasMap(project)
+      val actualResultExprs = agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap))
+      val actualGroupExprs = agg.groupingExpressions.map(replaceAlias(_, aliasMap))
+
+      val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+      val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
+      val normalizedAggExprs = DataSourceStrategy.normalizeExprs(
+        aggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+      val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs(
+        actualGroupExprs, holder.relation.output)
+      val translatedAggOpt = DataSourceStrategy.translateAggregation(
+        normalizedAggExprs, normalizedGroupingExpr)
+      if (translatedAggOpt.isEmpty) {
+        // Cannot translate the catalyst aggregate, return the query plan unchanged.
+        return agg
+      }
+
+      val (finalResultExprs, finalAggExprs, translatedAgg, canCompletePushDown) = {
+        if (r.supportCompletePushDown(translatedAggOpt.get)) {
+          (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, true)
+        } else if (!translatedAggOpt.get.aggregateExpressions().exists(_.isInstanceOf[Avg])) {
+          (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false)
+        } else {
+          // scalastyle:off
+          // The data source doesn't support the complete push-down of this aggregation.
+          // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+          // pushed, completely or partially.
+          // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+          // SELECT avg(c1) FROM t GROUP BY c2;
+          // The original logical plan is
+          // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+          // +- ScanOperation[...]
+          //
+          // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
+          // we have the following
+          // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+          // +- ScanOperation[...]
+          // scalastyle:on
+          val newResultExpressions = actualResultExprs.map { expr =>
+            expr.transform {
+              case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+                val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+                val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+                avg.evaluateExpression transform {
+                  case a: Attribute if a.semanticEquals(avg.sum) =>
+                    addCastIfNeeded(sum, avg.sum.dataType)
+                  case a: Attribute if a.semanticEquals(avg.count) =>
+                    addCastIfNeeded(count, avg.count.dataType)
                 }
-              }
+            }
+          }.asInstanceOf[Seq[NamedExpression]]
+          // Because aggregate expressions changed, translate them again.
+          aggExprToOutputOrdinal.clear()
+          val newAggregates =
+            collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+          val newNormalizedAggExprs = DataSourceStrategy.normalizeExprs(
+            newAggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+          val newTranslatedAggOpt = DataSourceStrategy.translateAggregation(
+            newNormalizedAggExprs, normalizedGroupingExpr)
+          if (newTranslatedAggOpt.isEmpty) {
+            // Ideally we should never reach here. But if we end up with not able to translate
+            // new aggregate with AVG replaced by SUM/COUNT, revert to the original one.
+            (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false)
+          } else {
+            (newResultExpressions, newNormalizedAggExprs, newTranslatedAggOpt.get,
+              r.supportCompletePushDown(newTranslatedAggOpt.get))
+          }
+        }
+      }
 
-              if (finalTranslatedAggregates.isEmpty) {
-                aggNode // return original plan node
-              } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) &&
-                !supportPartialAggPushDown(finalTranslatedAggregates.get)) {
-                aggNode // return original plan node
-              } else {
-                val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
-                if (pushedAggregates.isEmpty) {
-                  aggNode // return original plan node
-                } else {
-                  // No need to do column pruning because only the aggregate columns are used as
-                  // DataSourceV2ScanRelation output columns. All the other columns are not
-                  // included in the output.
-                  val scan = sHolder.builder.build()
-
-                  // scalastyle:off
-                  // use the group by columns and aggregate columns as the output columns
-                  // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-                  // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-                  // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
-                  // We want to have the following logical plan:
-                  // == Optimized Logical Plan ==
-                  // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-                  // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
-                  // scalastyle:on
-                  val newOutput = scan.readSchema().toAttributes
-                  assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
-                  val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-                  val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
-                    case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId)
-                    case ((expr, attr), ordinal) =>
-                      if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
-                        groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
-                      }
-                      attr
-                  }
-                  val aggOutput = newOutput.drop(groupAttrs.length)
-                  val output = groupAttrs ++ aggOutput
-
-                  logInfo(
-                    s"""
-                       |Pushing operators to ${sHolder.relation.name}
-                       |Pushed Aggregate Functions:
-                       | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
-                       |Pushed Group by:
-                       | ${pushedAggregates.get.groupByExpressions.mkString(", ")}
-                       |Output: ${output.mkString(", ")}
-                      """.stripMargin)
-
-                  val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
-                  val scanRelation =
-                    DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-                  if (r.supportCompletePushDown(pushedAggregates.get)) {
-                    val projectExpressions = finalResultExpressions.map { expr =>
-                      expr.transformDown {
-                        case agg: AggregateExpression =>
-                          val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-                          val child =
-                            addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
-                          Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
-                        case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
-                          val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
-                          addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
-                      }
-                    }.asInstanceOf[Seq[NamedExpression]]
-                    Project(projectExpressions, scanRelation)
+      if (!canCompletePushDown && !supportPartialAggPushDown(translatedAgg)) {
+        return agg
+      }
+      if (!r.pushAggregation(translatedAgg)) {
+        return agg
+      }
+
+      // scalastyle:off
+      // We name the output columns of group expressions and aggregate functions by
+      // ordinal: `group_col_0`, `group_col_1`, ..., `agg_func_0`, `agg_func_1`, ...
+      // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+      // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+      // Use group_col_0, agg_func_0, agg_func_1 as output for ScanBuilderHolder.
+      // We want to have the following logical plan:
+      // == Optimized Logical Plan ==
+      // Aggregate [group_col_0#10], [min(agg_func_0#21) AS min(c1)#17, max(agg_func_1#22) AS max(c1)#18]
+      // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22]
+      // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation.
+      // scalastyle:on
+      val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
+        AttributeReference(s"group_col_$i", e.dataType)()
+      }
+      val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) =>
+        AttributeReference(s"agg_func_$i", e.dataType)()
+      }
+      val newOutput = groupOutput ++ aggOutput
+      val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+      normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) =>
+        if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+          groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
+        }
+      }
+
+      holder.pushedAggregate = Some(translatedAgg)
+      holder.output = newOutput
+      logInfo(
+        s"""
+           |Pushing operators to ${holder.relation.name}
+           |Pushed Aggregate Functions:
+           | ${translatedAgg.aggregateExpressions().mkString(", ")}
+           |Pushed Group by:
+           | ${translatedAgg.groupByExpressions.mkString(", ")}
+         """.stripMargin)
+
+      if (canCompletePushDown) {
+        val projectExpressions = finalResultExprs.map { expr =>
+          expr.transformDown {
+            case agg: AggregateExpression =>
+              val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+              Alias(aggOutput(ordinal), agg.resultAttribute.name)(agg.resultAttribute.exprId)
+            case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+              val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+              expr match {
+                case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId)
+                case _ => groupOutput(ordinal)
+              }
+          }
+        }.asInstanceOf[Seq[NamedExpression]]
+        Project(projectExpressions, holder)
+      } else {
+        // scalastyle:off
+        // Change the optimized logical plan to reflect the pushed down aggregate
+        // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+        // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+        // The original logical plan is
+        // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+        // +- RelationV2[c1#9, c2#10] ...
+        //
+        // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+        // we have the following
+        // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+        // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+        //
+        // We want to change it to
+        // == Optimized Logical Plan ==
+        // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+        // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+        // scalastyle:on
+        val aggExprs = finalResultExprs.map(_.transform {
+          case agg: AggregateExpression =>
+            val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+            val aggAttribute = aggOutput(ordinal)
+            val aggFunction: aggregate.AggregateFunction =
+              agg.aggregateFunction match {
+                case max: aggregate.Max =>
+                  max.copy(child = aggAttribute)
+                case min: aggregate.Min =>
+                  min.copy(child = aggAttribute)
+                case sum: aggregate.Sum =>
+                  // To keep the dataType of `Sum` unchanged, we need to cast the
+                  // data-source-aggregated result to `Sum.child.dataType` if it's decimal.
+                  // See `SumBase.resultType`
+                  val newChild = if (sum.dataType.isInstanceOf[DecimalType]) {
+                    addCastIfNeeded(aggAttribute, sum.child.dataType)
                   } else {
-                    val plan = Aggregate(output.take(groupingExpressions.length),
-                      finalResultExpressions, scanRelation)
-
-                    // scalastyle:off
-                    // Change the optimized logical plan to reflect the pushed down aggregate
-                    // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-                    // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-                    // The original logical plan is
-                    // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                    // +- RelationV2[c1#9, c2#10] ...
-                    //
-                    // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
-                    // we have the following
-                    // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-                    // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-                    //
-                    // We want to change it to
-                    // == Optimized Logical Plan ==
-                    // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-                    // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-                    // scalastyle:on
-                    plan.transformExpressions {
-                      case agg: AggregateExpression =>
-                        val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-                        val aggAttribute = aggOutput(ordinal)
-                        val aggFunction: aggregate.AggregateFunction =
-                          agg.aggregateFunction match {
-                            case max: aggregate.Max =>
-                              max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
-                            case min: aggregate.Min =>
-                              min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
-                            case sum: aggregate.Sum =>
-                              sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
-                            case _: aggregate.Count =>
-                              aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
-                            case other => other
-                          }
-                        agg.copy(aggregateFunction = aggFunction)
-                      case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
-                        val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
-                        addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
-                    }
+                    aggAttribute
                   }
-                }
+                  sum.copy(child = newChild)
+                case _: aggregate.Count =>
+                  aggregate.Sum(aggAttribute)
+                case other => other
               }
-            case _ => aggNode
-          }
-        case _ => aggNode
+            agg.copy(aggregateFunction = aggFunction)
+          case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+            val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+            expr match {
+              case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId)
+              case _ => groupOutput(ordinal)
+            }
+        }).asInstanceOf[Seq[NamedExpression]]
+        Aggregate(groupOutput, aggExprs, holder)
       }
+
+    case _ => agg
   }
 
-  private def collectAggregates(resultExpressions: Seq[NamedExpression],
+  private def collectAggregates(
+      resultExpressions: Seq[NamedExpression],
       aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
     var ordinal = 0
     resultExpressions.flatMap { expr =>
@@ -292,15 +305,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
   }
 
   private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
-    // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
-    // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down.
-    agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists {
+    // We can only partially push down min/max/sum/count without DISTINCT.
+    agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().forall {
       case sum: Sum => !sum.isDistinct
       case count: Count => !count.isDistinct
-      case avg: Avg => !avg.isDistinct
-      case _: GeneralAggregateFunc => false
-      case _: UserDefinedAggregateFunc => false
-      case _ => true
+      case _: Min | _: Max | _: CountStar => true
+      case _ => false
     }
   }
 
@@ -311,6 +321,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       Cast(expression, expectedDataType)
     }
 
+  def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined =>
+      // No need to do column pruning because only the aggregate columns are used as
+      // DataSourceV2ScanRelation output columns. All the other columns are not
+      // included in the output.
+      val scan = holder.builder.build()
+      val realOutput = scan.readSchema().toAttributes
+      assert(realOutput.length == holder.output.length,
+        "The data source returns unexpected number of columns")
+      val wrappedScan = getWrappedScan(scan, holder)
+      val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
+      val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
+        // The data source may return columns with arbitrary data types and it's safer to cast them
+        // to the expected data type.
+        assert(Cast.canCast(a1.dataType, a2.dataType))
+        Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId)
+      }
+      Project(projectList, scanRelation)
+  }
+
   def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
     case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
       // column pruning
@@ -325,7 +355,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           |Output: ${output.mkString(", ")}
          """.stripMargin)
 
-      val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation])
+      val wrappedScan = getWrappedScan(scan, sHolder)
 
       val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
@@ -378,8 +408,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       }
       (operation, isPushed && !isPartiallyPushed)
     case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
-        if filter.isEmpty && CollapseProject.canCollapseExpressions(
-          order, project, alwaysInline = true) =>
+        // Without building the Scan, we do not know the resulting column names after aggregate
+        // push-down, and thus can't push down Top-N which needs to know the ordering column names.
+        // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
+        //       columns, which we know the resulting column names: the original table columns.
+        if sHolder.pushedAggregate.isEmpty && filter.isEmpty &&
+          CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
       val aliasMap = getAliasMap(project)
       val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
       val normalizedOrders = DataSourceStrategy.normalizeExprs(
@@ -480,10 +514,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       }
   }
 
-  private def getWrappedScan(
-      scan: Scan,
-      sHolder: ScanBuilderHolder,
-      aggregation: Option[Aggregation]): Scan = {
+  private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = {
     scan match {
       case v1: V1Scan =>
         val pushedFilters = sHolder.builder match {
@@ -491,7 +522,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
             f.pushedFilters()
           case _ => Array.empty[sources.Filter]
         }
-        val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample,
+        val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample,
           sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates)
         V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
       case _ => scan
@@ -500,7 +531,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
 }
 
 case class ScanBuilderHolder(
-    output: Seq[AttributeReference],
+    var output: Seq[AttributeReference],
     relation: DataSourceV2Relation,
     builder: ScanBuilder) extends LeafNode {
   var pushedLimit: Option[Int] = None
@@ -512,6 +543,8 @@ case class ScanBuilderHolder(
   var pushedSample: Option[TableSampleInfo] = None
 
   var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
+
+  var pushedAggregate: Option[Aggregation] = None
 }
 
 // A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 7e772c0febb..d64b1815007 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -265,9 +265,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .table("h2.test.employee")
       .groupBy("DEPT").sum("SALARY")
       .limit(1)
-    checkLimitRemoved(df4, false)
+    checkAggregateRemoved(df4)
+    checkLimitRemoved(df4)
     checkPushedInfo(df4,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedLimit: LIMIT 1")
     checkAnswer(df4, Seq(Row(1, 19000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -340,9 +344,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .table("h2.test.employee")
       .groupBy("DEPT").sum("SALARY")
       .offset(1)
-    checkOffsetRemoved(df5, false)
+    checkAggregateRemoved(df5)
+    checkOffsetRemoved(df5)
     checkPushedInfo(df5,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1")
     checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -477,10 +485,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .groupBy("DEPT").sum("SALARY")
       .limit(2)
       .offset(1)
-    checkLimitRemoved(df10, false)
-    checkOffsetRemoved(df10, false)
+    checkAggregateRemoved(df10)
+    checkLimitRemoved(df10)
+    checkOffsetRemoved(df10)
     checkPushedInfo(df10,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedLimit: LIMIT 2",
+      "PushedOffset: OFFSET 1")
     checkAnswer(df10, Seq(Row(2, 22000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
@@ -612,10 +625,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true)))
 
     val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1")
-    checkLimitRemoved(df10, false)
-    checkOffsetRemoved(df10, false)
+    checkAggregateRemoved(df10)
+    checkLimitRemoved(df10)
+    checkOffsetRemoved(df10)
     checkPushedInfo(df10,
-      "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedLimit: LIMIT 2",
+      "PushedOffset: OFFSET 1")
     checkAnswer(df10, Seq(Row(2, 22000.00)))
 
     val name = udf { (x: String) => x.matches("cat|dav|amy") }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to