This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 609e749  [SPARK-34960][SQL] Aggregate push down for ORC
609e749 is described below

commit 609e7498326ebedb06904a1f5bab59b739380b1a
Author: Cheng Su <chen...@fb.com>
AuthorDate: Thu Oct 28 17:29:15 2021 -0700

    [SPARK-34960][SQL] Aggregate push down for ORC
    
    ### What changes were proposed in this pull request?
    
    This PR is to add aggregate push down feature for ORC data source v2 reader.
    
    At a high level, the PR does:
    
    * The supported aggregate expressions are MIN/MAX/COUNT, the same as
      [Parquet aggregate push down](https://github.com/apache/spark/pull/33639).
    * BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
      DoubleType and DateType are allowed in MIN/MAX aggregate push down. All
      other column types are not allowed in MIN/MAX aggregate push down.
    * All column types are supported in COUNT aggregate push down.
    * Sub-fields of nested columns are disallowed in aggregate push down.
    * If a file does not have valid statistics, Spark throws an exception and
      fails the query.
    * If the aggregate has a filter or a group-by column, it is not pushed down
      (see the sketch after this list).
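
    For illustration, a minimal, hedged usage sketch (table `t` and its columns
    are hypothetical) of which queries qualify once the feature is enabled:

    ```scala
    // Assumes an ORC-backed table t(id INT, price FLOAT, dt DATE) is registered.
    spark.conf.set("spark.sql.orc.aggregatePushdown", "true")

    // Eligible: MIN/MAX on supported primitive types, COUNT and COUNT(*),
    // with no filter and no GROUP BY.
    spark.sql("SELECT MIN(id), MAX(price), COUNT(dt), COUNT(*) FROM t")

    // Not eligible: aggregates combined with a data filter or GROUP BY
    // fall back to a normal scan.
    spark.sql("SELECT MIN(id) FROM t WHERE price > 10")
    spark.sql("SELECT COUNT(id) FROM t GROUP BY dt")
    ```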
    
    At the code level, the PR does the following:
    * `OrcScanBuilder`: `pushAggregation()` checks whether the aggregation can be
      pushed down. Most of the checking logic is shared between Parquet and ORC
      and is extracted into `AggregatePushDownUtils.getSchemaForPushedAggregation()`.
      `OrcScanBuilder` creates an `OrcScan` with the aggregation and the
      aggregation data schema.
    * `OrcScan`: `createReaderFactory` creates an ORC reader factory with the
      aggregation and schema. Similar to the corresponding change in `ParquetScan`.
    * `OrcPartitionReaderFactory`: `buildReaderWithAggregates` creates an ORC
      reader with aggregate push down (i.e. it reads the ORC file footer to
      process column statistics instead of reading the actual data in the file).
      `buildColumnarReaderWithAggregates` creates a columnar ORC reader in the
      same way. Both delegate the real work of reading the footer to
      `OrcUtils.createAggInternalRowFromFooter`.
    * `OrcUtils.createAggInternalRowFromFooter`: reads the ORC file footer and
      processes the column statistics (the real heavy lifting happens here; see
      the sketch after this list). Similar to
      `ParquetUtils.createAggInternalRowFromFooter`. It leverages utility methods
      such as `OrcFooterReader.readStatistics`.
    * `OrcFooterReader`: `readStatistics` reads the ORC `ColumnStatistics[]` into
      Spark's `OrcColumnStatistics`. The transformation is needed because ORC
      stores all column statistics in a flattened array, which is hard to
      process, whereas Spark's `OrcColumnStatistics` stores them in a nested tree
      structure (e.g. like `StructType`). This is used by
      `OrcUtils.createAggInternalRowFromFooter`.
    * `OrcColumnStatistics`: the easy-to-manipulate structure for ORC
      `ColumnStatistics`. This is used by `OrcFooterReader.readStatistics`.
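
    The footer-only read at the heart of this path can be sketched roughly as
    follows (file path and configuration are placeholders and error handling is
    omitted; this is an illustration, not the exact code added by the PR):

    ```scala
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.OrcFile

    val conf = new Configuration()
    val path = new Path("/tmp/data/part-00000.orc")  // hypothetical file
    // Opening the reader only touches the footer; no row data is scanned.
    val reader = OrcFile.createReader(path, OrcFile.readerOptions(conf))
    try {
      // Flat array of per-column statistics in tree pre-order (index 0 is the
      // root struct); OrcFooterReader.readStatistics converts this flat array
      // into the nested OrcColumnStatistics tree.
      val stats = reader.getStatistics
      // COUNT(*) equivalent: the total row count of the file.
      println(s"rows=${reader.getNumberOfRows}, statistics entries=${stats.length}")
    } finally {
      reader.close()
    }
    ```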
    
    ### Why are the changes needed?
    
    To improve the performance of queries with aggregates.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A user-facing config `spark.sql.orc.aggregatePushdown` is added to
    enable or disable aggregate push down for ORC. The feature is disabled by
    default.
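
    A hedged example of enabling the feature and checking that push down took
    effect (path and column names are hypothetical); the ORC scan's plan
    metadata now reports `PushedAggregation`:

    ```scala
    spark.conf.set("spark.sql.orc.aggregatePushdown", "true")
    spark.read.format("orc").load("/tmp/data").createOrReplaceTempView("t")
    // When push down kicks in, the ORC scan node in the explain output should
    // list the pushed aggregates, e.g. "PushedAggregation: [MIN(id), COUNT(*)]".
    spark.sql("SELECT MIN(id), COUNT(*) FROM t").explain()
    ```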
    
    ### How was this patch tested?
    
    Added unit tests in `FileSourceAggregatePushDownSuite.scala`. Refactored the
    unit tests from https://github.com/apache/spark/pull/33639 so that they now
    cover both Parquet and ORC.
    
    Closes #34298 from c21/orc-agg.
    
    Authored-by: Cheng Su <chen...@fb.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 +
 .../org/apache/spark/sql/types/StructType.scala    |   2 +-
 .../datasources/orc/OrcColumnStatistics.java       |  80 +++++
 .../execution/datasources/orc/OrcFooterReader.java |  67 +++++
 .../datasources/AggregatePushDownUtils.scala       | 141 +++++++++
 .../datasources/orc/OrcDeserializer.scala          |  16 +
 .../sql/execution/datasources/orc/OrcUtils.scala   | 122 +++++++-
 .../datasources/parquet/ParquetUtils.scala         |  41 ---
 .../v2/orc/OrcPartitionReaderFactory.scala         |  93 ++++--
 .../sql/execution/datasources/v2/orc/OrcScan.scala |  45 ++-
 .../datasources/v2/orc/OrcScanBuilder.scala        |  43 ++-
 .../v2/parquet/ParquetPartitionReaderFactory.scala |  14 +-
 .../datasources/v2/parquet/ParquetScan.scala       |  10 +-
 .../v2/parquet/ParquetScanBuilder.scala            |  93 ++----
 .../scala/org/apache/spark/sql/FileScanSuite.scala |   2 +-
 ...cala => FileSourceAggregatePushDownSuite.scala} | 324 ++++++++++++---------
 16 files changed, 804 insertions(+), 299 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fe3204b..def6bbc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -960,6 +960,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ORC_AGGREGATE_PUSHDOWN_ENABLED = 
buildConf("spark.sql.orc.aggregatePushdown")
+    .doc("If true, aggregates will be pushed down to ORC for optimization. 
Support MIN, MAX and " +
+      "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, 
float and date " +
+      "type. For COUNT, support all data types.")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema")
     .doc("When true, the Orc data source merges schemas collected from all 
data files, " +
       "otherwise the schema is picked from a random data file.")
@@ -3706,6 +3714,8 @@ class SQLConf extends Serializable with Logging {
 
   def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
 
+  def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED)
+
   def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED)
 
   def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 205b08f..6707fb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends 
DataType with Seq[Stru
   def names: Array[String] = fieldNames
 
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
-  private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f 
=> f.name -> f).toMap
+  private lazy val nameToField: Map[String, StructField] = fields.map(f => 
f.name -> f).toMap
   private lazy val nameToIndex: Map[String, Int] = 
fieldNames.zipWithIndex.toMap
 
   override def equals(that: Any): Boolean = {
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java
new file mode 100644
index 0000000..8adb9e8
--- /dev/null
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Column statistics wrapper around ORC {@link ColumnStatistics}s.
+ *
+ * Because ORC {@link ColumnStatistics}s are stored as a flattened array in the ORC file footer,
+ * this class is used to convert ORC {@link ColumnStatistics}s from that array into a nested tree
+ * structure, according to the data types. The flattened array stores all data types (including
+ * nested types) in tree pre-order. This is used for aggregate push down in ORC.
+ *
+ * For nested data types (array, map and struct), the sub-field statistics are 
stored recursively
+ * inside parent column's children field. Here is an example of {@link 
OrcColumnStatistics}:
+ *
+ * Data schema:
+ * c1: int
+ * c2: struct<f1: int, f2: float>
+ * c3: map<key: int, value: string>
+ * c4: array<int>
+ *
+ *                        OrcColumnStatistics
+ *                                | (children)
+ *             ---------------------------------------------
+ *            /         |                 \                 \
+ *           c1        c2                 c3                c4
+ *      (integer)    (struct)            (map)             (array)
+ *       (min:1,        | (children)       | (children)      | (children)
+ *       max:10)      -----              -----             element
+ *                   /     \            /     \           (integer)
+ *                c2.f1    c2.f2      key     value
+ *              (integer) (float)  (integer) (string)
+ *                       (min:0.1,           (min:"a",
+ *                       max:100.5)          max:"zzz")
+ */
+public class OrcColumnStatistics {
+  private final ColumnStatistics statistics;
+  private final List<OrcColumnStatistics> children;
+
+  public OrcColumnStatistics(ColumnStatistics statistics) {
+    this.statistics = statistics;
+    this.children = new ArrayList<>();
+  }
+
+  public ColumnStatistics getStatistics() {
+    return statistics;
+  }
+
+  public OrcColumnStatistics get(int ordinal) {
+    if (ordinal < 0 || ordinal >= children.size()) {
+      throw new IndexOutOfBoundsException(
+        String.format("Ordinal %d out of bounds of statistics size %d", 
ordinal, children.size()));
+    }
+    return children.get(ordinal);
+  }
+
+  public void add(OrcColumnStatistics newChild) {
+    children.add(newChild);
+  }
+}
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
new file mode 100644
index 0000000..546b048
--- /dev/null
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+import org.apache.orc.Reader;
+import org.apache.orc.TypeDescription;
+import org.apache.spark.sql.types.*;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.Queue;
+
+/**
+ * {@link OrcFooterReader} is a utility class that encapsulates helper
+ * methods for reading the ORC file footer.
+ */
+public class OrcFooterReader {
+
+  /**
+   * Read the column statistics from the ORC file footer.
+   *
+   * @param orcReader the reader to read ORC file footer.
+   * @return Statistics for all columns in the file.
+   */
+  public static OrcColumnStatistics readStatistics(Reader orcReader) {
+    TypeDescription orcSchema = orcReader.getSchema();
+    ColumnStatistics[] orcStatistics = orcReader.getStatistics();
+    StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema);
+    return convertStatistics(sparkSchema, new 
LinkedList<>(Arrays.asList(orcStatistics)));
+  }
+
+  /**
+   * Convert a queue of ORC {@link ColumnStatistics}s into a Spark {@link OrcColumnStatistics}.
+   * The queue of ORC {@link ColumnStatistics}s is assumed to be ordered in tree pre-order.
+   */
+  private static OrcColumnStatistics convertStatistics(
+      DataType sparkSchema, Queue<ColumnStatistics> orcStatistics) {
+    OrcColumnStatistics statistics = new 
OrcColumnStatistics(orcStatistics.remove());
+    if (sparkSchema instanceof StructType) {
+      for (StructField field : ((StructType) sparkSchema).fields()) {
+        statistics.add(convertStatistics(field.dataType(), orcStatistics));
+      }
+    } else if (sparkSchema instanceof MapType) {
+      statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), 
orcStatistics));
+      statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), 
orcStatistics));
+    } else if (sparkSchema instanceof ArrayType) {
+      statistics.add(convertStatistics(((ArrayType) 
sparkSchema).elementType(), orcStatistics));
+    }
+    return statistics;
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
new file mode 100644
index 0000000..6340d97
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, 
Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, 
OnHeapColumnVector}
+import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, 
DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, 
StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * Utility class for aggregate push down to Parquet and ORC.
+ */
+object AggregatePushDownUtils {
+
+  /**
+   * Get the data schema for aggregate to be pushed down.
+   */
+  def getSchemaForPushedAggregation(
+      aggregation: Aggregation,
+      schema: StructType,
+      partitionNames: Set[String],
+      dataFilters: Seq[Expression]): Option[StructType] = {
+
+    var finalSchema = new StructType()
+
+    def getStructFieldForCol(col: NamedReference): StructField = {
+      schema.apply(col.fieldNames.head)
+    }
+
+    def isPartitionCol(col: NamedReference) = {
+      partitionNames.contains(col.fieldNames.head)
+    }
+
+    def processMinOrMax(agg: AggregateFunc): Boolean = {
+      val (column, aggType) = agg match {
+        case max: Max => (max.column, "max")
+        case min: Min => (min.column, "min")
+        case _ =>
+          throw new IllegalArgumentException(s"Unexpected type of 
AggregateFunc ${agg.describe}")
+      }
+
+      if (isPartitionCol(column)) {
+        // don't push down partition column, footer doesn't have max/min for 
partition column
+        return false
+      }
+      val structField = getStructFieldForCol(column)
+
+      structField.dataType match {
+        // not push down complex type
+        // not push down Timestamp because INT96 sort order is undefined,
+        // Parquet doesn't return statistics for INT96
+        // not push down Parquet Binary because min/max could be truncated
+        // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary
+        // could be Spark StringType, BinaryType or DecimalType.
+        // not push down for ORC for the same reason.
+        case BooleanType | ByteType | ShortType | IntegerType
+             | LongType | FloatType | DoubleType | DateType =>
+          finalSchema = finalSchema.add(structField.copy(s"$aggType(" + 
structField.name + ")"))
+          true
+        case _ =>
+          false
+      }
+    }
+
+    if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
+      // Parquet/ORC footer has max/min/count for columns
+      // e.g. SELECT COUNT(col1) FROM t
+      // but footer doesn't have max/min/count for a column if max/min/count
+      // are combined with filter or group by
+      // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+      //      SELECT COUNT(col1) FROM t GROUP BY col2
+      // However, if the filter is on partition column, max/min/count can 
still be pushed down
+      // Todo:  add support if groupby column is partition col
+      //        (https://issues.apache.org/jira/browse/SPARK-36646)
+      return None
+    }
+
+    aggregation.aggregateExpressions.foreach {
+      case max: Max =>
+        if (!processMinOrMax(max)) return None
+      case min: Min =>
+        if (!processMinOrMax(min)) return None
+      case count: Count =>
+        if (count.column.fieldNames.length != 1 || count.isDistinct) return 
None
+        finalSchema =
+          finalSchema.add(StructField(s"count(" + count.column.fieldNames.head 
+ ")", LongType))
+      case _: CountStar =>
+        finalSchema = finalSchema.add(StructField("count(*)", LongType))
+      case _ =>
+        return None
+    }
+
+    Some(finalSchema)
+  }
+
+  /**
+   * Check whether two Aggregations `a` and `b` are equivalent.
+   */
+  def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
+    a.aggregateExpressions.sortBy(_.hashCode())
+      .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
+      
a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+  }
+
+  /**
+   * Convert the aggregates result from `InternalRow` to `ColumnarBatch`.
+   * This is used for columnar reader.
+   */
+  def convertAggregatesRowToBatch(
+      aggregatesAsRow: InternalRow,
+      aggregatesSchema: StructType,
+      offHeap: Boolean): ColumnarBatch = {
+    val converter = new RowToColumnConverter(aggregatesSchema)
+    val columnVectors = if (offHeap) {
+      OffHeapColumnVector.allocateColumns(1, aggregatesSchema)
+    } else {
+      OnHeapColumnVector.allocateColumns(1, aggregatesSchema)
+    }
+    converter.convert(aggregatesAsRow, columnVectors.toArray)
+    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index 1476083..9140833 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -68,6 +68,22 @@ class OrcDeserializer(
     resultRow
   }
 
+  def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): 
InternalRow = {
+    var targetColumnIndex = 0
+    while (targetColumnIndex < fieldWriters.length) {
+      if (fieldWriters(targetColumnIndex) != null) {
+        val value = orcValues(requestedColIds(targetColumnIndex))
+        if (value == null) {
+          resultRow.setNullAt(targetColumnIndex)
+        } else {
+          fieldWriters(targetColumnIndex)(value)
+        }
+      }
+      targetColumnIndex += 1
+    }
+    resultRow
+  }
+
   /**
    * Creates a writer to write ORC values to Catalyst data structure at the 
given ordinal.
    */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 475448a..b262415 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -25,15 +25,19 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer}
+import org.apache.hadoop.hive.serde2.io.DateWritable
+import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, 
FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable}
+import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, 
DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, 
OrcFile, Reader, TypeDescription, Writer}
 
-import org.apache.spark.SPARK_VERSION_SHORT
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, 
Count, CountStar, Max, Min}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
 import org.apache.spark.sql.types._
@@ -87,7 +91,7 @@ object OrcUtils extends Logging {
     }
   }
 
-  private def toCatalystSchema(schema: TypeDescription): StructType = {
+  def toCatalystSchema(schema: TypeDescription): StructType = {
     import TypeDescription.Category
 
     def toCatalystType(orcType: TypeDescription): DataType = {
@@ -377,4 +381,116 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read
+   * data from ORC and aggregate at the Spark layer. Instead we get the partial aggregate
+   * (Max/Min/Count) results using the statistics information from the ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    var columnsStatistics: OrcColumnStatistics = null
+    try {
+      columnsStatistics = OrcFooterReader.readStatistics(reader)
+    } catch { case e: Exception =>
+      throw new SparkException(
+        s"Cannot read columns statistics in file: $filePath. Please consider 
disabling " +
+        s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' 
to false.", e)
+    }
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    // Return null if number of non-null values is zero.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      if (statistics.getNumberOfValues == 0) {
+        return null
+      }
+
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMinMaxFromColumnStatistics should not take type $dataType " 
+
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMinMaxFromColumnStatistics should not take type $dataType " 
+
+                "for DoubleColumnStatistics")
+          }
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else 
s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMinMaxFromColumnStatistics should not take 
${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = 
partitionSchema.fields.map(_.name).contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns the number of
+          // non-null values for an individual column's ColumnStatistics. In addition,
+          // ORC also stores the number of all values (null and non-null) separately.
+          val nonNullRowsCount = if (isPartitionColumn) {
+            columnsStatistics.getStatistics.getNumberOfValues
+          } else {
+            getColumnStatistics(columnName).getNumberOfValues
+          }
+          new LongWritable(nonNullRowsCount)
+        case (_: CountStar, _) =>
+          // Count(*) includes both null and non-null values.
+          new LongWritable(columnsStatistics.getStatistics.getNumberOfValues)
+        case (x, _) =>
+          throw new IllegalArgumentException(
+            s"createAggInternalRowFromFooter should not take $x as the 
aggregate expression")
+      }
+
+    val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until 
aggSchema.length).toArray)
+    orcValuesDeserializer.deserializeFromValues(aggORCValues)
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 1093f9c..0e4b9283 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -32,12 +32,9 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, 
Count, CountStar, Max, Min}
-import org.apache.spark.sql.execution.RowToColumnConverter
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
-import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, 
OnHeapColumnVector}
 import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, 
PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 object ParquetUtils {
   def inferSchema(
@@ -202,44 +199,6 @@ object ParquetUtils {
   }
 
   /**
-   * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the 
case of
-   * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need 
buildColumnarReader
-   * to read data from Parquet and aggregate at Spark layer. Instead we want
-   * to get the aggregates (Max/Min/Count) result using the statistics 
information
-   * from Parquet footer file, and then construct a ColumnarBatch from these 
aggregate results.
-   *
-   * @return Aggregate results in the format of ColumnarBatch
-   */
-  private[sql] def createAggColumnarBatchFromFooter(
-      footer: ParquetMetadata,
-      filePath: String,
-      dataSchema: StructType,
-      partitionSchema: StructType,
-      aggregation: Aggregation,
-      aggSchema: StructType,
-      offHeap: Boolean,
-      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
-      isCaseSensitive: Boolean): ColumnarBatch = {
-    val row = createAggInternalRowFromFooter(
-      footer,
-      filePath,
-      dataSchema,
-      partitionSchema,
-      aggregation,
-      aggSchema,
-      datetimeRebaseMode,
-      isCaseSensitive)
-    val converter = new RowToColumnConverter(aggSchema)
-    val columnVectors = if (offHeap) {
-      OffHeapColumnVector.allocateColumns(1, aggSchema)
-    } else {
-      OnHeapColumnVector.allocateColumns(1, aggSchema)
-    }
-    converter.convert(row, columnVectors.toArray)
-    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
-  }
-
-  /**
    * Calculate the pushed down aggregates (Max/Min/Count) result using the 
statistics
    * information from Parquet footer file.
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index c5020cb..246f160 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -23,15 +23,16 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
 import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
-import org.apache.orc.{OrcConf, OrcFile, TypeDescription}
+import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription}
 import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce.OrcInputFormat
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitionedFile}
 import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, 
OrcDeserializer, OrcFilters, OrcUtils}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
@@ -55,7 +56,8 @@ case class OrcPartitionReaderFactory(
     dataSchema: StructType,
     readDataSchema: StructType,
     partitionSchema: StructType,
-    filters: Array[Filter]) extends FilePartitionReaderFactory {
+    filters: Array[Filter],
+    aggregation: Option[Aggregation]) extends FilePartitionReaderFactory {
   private val resultSchema = StructType(readDataSchema.fields ++ 
partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
@@ -81,17 +83,14 @@ case class OrcPartitionReaderFactory(
 
   override def buildReader(file: PartitionedFile): 
PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
-
-    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, 
isCaseSensitive)
-
     val filePath = new Path(new URI(file.filePath))
 
-    pushDownPredicates(filePath, conf)
+    if (aggregation.nonEmpty) {
+      return buildReaderWithAggregates(filePath, conf)
+    }
 
-    val fs = filePath.getFileSystem(conf)
-    val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
-      Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { 
reader =>
+      Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
         OrcUtils.requestedColumnIds(
           isCaseSensitive, dataSchema, readDataSchema, reader, conf)
       }
@@ -128,17 +127,14 @@ case class OrcPartitionReaderFactory(
 
   override def buildColumnarReader(file: PartitionedFile): 
PartitionReader[ColumnarBatch] = {
     val conf = broadcastedConf.value.value
-
-    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, 
isCaseSensitive)
-
     val filePath = new Path(new URI(file.filePath))
 
-    pushDownPredicates(filePath, conf)
+    if (aggregation.nonEmpty) {
+      return buildColumnarReaderWithAggregates(filePath, conf)
+    }
 
-    val fs = filePath.getFileSystem(conf)
-    val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
-      Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { 
reader =>
+      Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
         OrcUtils.requestedColumnIds(
           isCaseSensitive, dataSchema, readDataSchema, reader, conf)
       }
@@ -173,4 +169,67 @@ case class OrcPartitionReaderFactory(
     }
   }
 
+  private def createORCReader(filePath: Path, conf: Configuration): Reader = {
+    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, 
isCaseSensitive)
+
+    pushDownPredicates(filePath, conf)
+
+    val fs = filePath.getFileSystem(conf)
+    val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+    OrcFile.createReader(filePath, readerOptions)
+  }
+
+  /**
+   * Build reader with aggregate push down.
+   */
+  private def buildReaderWithAggregates(
+      filePath: Path,
+      conf: Configuration): PartitionReader[InternalRow] = {
+    new PartitionReader[InternalRow] {
+      private var hasNext = true
+      private lazy val row: InternalRow = {
+        Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
+          OrcUtils.createAggInternalRowFromFooter(
+            reader, filePath.toString, dataSchema, partitionSchema, 
aggregation.get, readDataSchema)
+        }
+      }
+
+      override def next(): Boolean = hasNext
+
+      override def get(): InternalRow = {
+        hasNext = false
+        row
+      }
+
+      override def close(): Unit = {}
+    }
+  }
+
+  /**
+   * Build columnar reader with aggregate push down.
+   */
+  private def buildColumnarReaderWithAggregates(
+      filePath: Path,
+      conf: Configuration): PartitionReader[ColumnarBatch] = {
+    new PartitionReader[ColumnarBatch] {
+      private var hasNext = true
+      private lazy val batch: ColumnarBatch = {
+        Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
+          val row = OrcUtils.createAggInternalRowFromFooter(
+            reader, filePath.toString, dataSchema, partitionSchema, 
aggregation.get,
+            readDataSchema)
+          AggregatePushDownUtils.convertAggregatesRowToBatch(row, 
readDataSchema, offHeap = false)
+        }
+      }
+
+      override def next(): Boolean = hasNext
+
+      override def get(): ColumnarBatch = {
+        hasNext = false
+        batch
+      }
+
+      override def close(): Unit = {}
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 7619e3c..6b9d181 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -21,8 +21,9 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
@@ -37,10 +38,25 @@ case class OrcScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
+    pushedAggregate: Option[Aggregation] = None,
     pushedFilters: Array[Filter],
     partitionFilters: Seq[Expression] = Seq.empty,
     dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
-  override def isSplitable(path: Path): Boolean = true
+  override def isSplitable(path: Path): Boolean = {
+    // If aggregate is pushed down, only the file footer will be read once,
+    // so the file should not be split across multiple tasks.
+    pushedAggregate.isEmpty
+  }
+
+  override def readSchema(): StructType = {
+    // If aggregate is pushed down, the schema has already been pruned in
+    // `OrcScanBuilder`, so there is no need to call super.readSchema().
+    if (pushedAggregate.nonEmpty) {
+      readDataSchema
+    } else {
+      super.readSchema()
+    }
+  }
 
   override def createReaderFactory(): PartitionReaderFactory = {
     val broadcastedConf = sparkSession.sparkContext.broadcast(
@@ -48,24 +64,39 @@ case class OrcScan(
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
+      dataSchema, readDataSchema, readPartitionSchema, pushedFilters, 
pushedAggregate)
   }
 
   override def equals(obj: Any): Boolean = obj match {
     case o: OrcScan =>
+      val pushedDownAggEqual = if (pushedAggregate.nonEmpty && 
o.pushedAggregate.nonEmpty) {
+        AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, 
o.pushedAggregate.get)
+      } else {
+        pushedAggregate.isEmpty && o.pushedAggregate.isEmpty
+      }
       super.equals(o) && dataSchema == o.dataSchema && options == o.options &&
-        equivalentFilters(pushedFilters, o.pushedFilters)
-
+        equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual
     case _ => false
   }
 
   override def hashCode(): Int = getClass.hashCode()
 
+  lazy private val (pushedAggregationsStr, pushedGroupByStr) = if 
(pushedAggregate.nonEmpty) {
+    (seqToString(pushedAggregate.get.aggregateExpressions),
+      seqToString(pushedAggregate.get.groupByColumns))
+  } else {
+    ("[]", "[]")
+  }
+
   override def description(): String = {
-    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+    super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
+      ", PushedAggregation: " + pushedAggregationsStr +
+      ", PushedGroupBy: " + pushedGroupByStr
   }
 
   override def getMetaData(): Map[String, String] = {
-    super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
+    super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) 
++
+      Map("PushedAggregation" -> pushedAggregationsStr) ++
+      Map("PushedGroupBy" -> pushedGroupByStr)
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index cfa396f..d2c17fd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.orc.OrcFilters
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
 import org.apache.spark.sql.internal.SQLConf
@@ -35,18 +36,31 @@ case class OrcScanBuilder(
     schema: StructType,
     dataSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+  with SupportsPushDownAggregates {
+
   lazy val hadoopConf = {
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
     sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
   }
 
+  private var finalSchema = new StructType()
+
+  private var pushedAggregations = Option.empty[Aggregation]
+
   override protected val supportsNestedSchemaPruning: Boolean = true
 
   override def build(): Scan = {
-    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
-      readPartitionSchema(), options, pushedDataFilters, partitionFilters, 
dataFilters)
+    // The `finalSchema` is either pruned in pushAggregation() (if aggregates are
+    // pushed down) or pruned in readDataSchema() (regular column pruning). The
+    // two cases are mutually exclusive.
+    if (pushedAggregations.isEmpty) {
+      finalSchema = readDataSchema()
+    }
+    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
+      readPartitionSchema(), options, pushedAggregations, pushedDataFilters, 
partitionFilters,
+      dataFilters)
   }
 
   override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
@@ -58,4 +72,23 @@ case class OrcScanBuilder(
       Array.empty[Filter]
     }
   }
+
+  override def pushAggregation(aggregation: Aggregation): Boolean = {
+    if (!sparkSession.sessionState.conf.orcAggregatePushDown) {
+      return false
+    }
+
+    AggregatePushDownUtils.getSchemaForPushedAggregation(
+      aggregation,
+      schema,
+      partitionNameSet,
+      dataFilters) match {
+
+      case Some(schema) =>
+        finalSchema = schema
+        this.pushedAggregations = Some(aggregation)
+        true
+      case _ => false
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 111018b..6f021ff 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
-import org.apache.spark.sql.execution.datasources.{DataSourceUtils, 
PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
DataSourceUtils, PartitionedFile, RecordReaderIterator}
 import org.apache.spark.sql.execution.datasources.parquet._
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
@@ -175,24 +175,26 @@ case class ParquetPartitionReaderFactory(
     } else {
       new PartitionReader[ColumnarBatch] {
         private var hasNext = true
-        private val row: ColumnarBatch = {
+        private val batch: ColumnarBatch = {
           val footer = getFooter(file)
           if (footer != null && footer.getBlocks.size > 0) {
-            ParquetUtils.createAggColumnarBatchFromFooter(footer, 
file.filePath, dataSchema,
-              partitionSchema, aggregation.get, readDataSchema, 
enableOffHeapColumnVector,
+            val row = ParquetUtils.createAggInternalRowFromFooter(footer, 
file.filePath,
+              dataSchema, partitionSchema, aggregation.get, readDataSchema,
               getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+            AggregatePushDownUtils.convertAggregatesRowToBatch(
+              row, readDataSchema, enableOffHeapColumnVector && 
Option(TaskContext.get()).isDefined)
           } else {
             null
           }
         }
 
         override def next(): Boolean = {
-          hasNext && row != null
+          hasNext && batch != null
         }
 
         override def get(): ColumnarBatch = {
           hasNext = false
-          row
+          batch
         }
 
         override def close(): Unit = {}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 42dc287..b92ed82 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, 
ParquetReadSupport, ParquetWriteSupport}
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.internal.SQLConf
@@ -101,7 +101,7 @@ case class ParquetScan(
   override def equals(obj: Any): Boolean = obj match {
     case p: ParquetScan =>
       val pushedDownAggEqual = if (pushedAggregate.nonEmpty && 
p.pushedAggregate.nonEmpty) {
-        equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
+        AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, 
p.pushedAggregate.get)
       } else {
         pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
       }
@@ -130,10 +130,4 @@ case class ParquetScan(
       Map("PushedAggregation" -> pushedAggregationsStr) ++
       Map("PushedGroupBy" -> pushedGroupByStr)
   }
-
-  private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean 
= {
-    a.aggregateExpressions.sortBy(_.hashCode())
-      .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
-      
a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
-  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index da49381..74d11b6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -20,15 +20,14 @@ package 
org.apache.spark.sql.execution.datasources.v2.parquet
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, 
Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitioningAwareFileIndex}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, 
SparkToParquetSchemaConverter}
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
 import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, 
DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, 
StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class ParquetScanBuilder(
@@ -87,87 +86,31 @@ case class ParquetScanBuilder(
   override def pushedFilters(): Array[Filter] = pushedParquetFilters
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
-
-    def getStructFieldForCol(col: NamedReference): StructField = {
-      schema.nameToField(col.fieldNames.head)
-    }
-
-    def isPartitionCol(col: NamedReference) = {
-      partitionNameSet.contains(col.fieldNames.head)
-    }
-
-    def processMinOrMax(agg: AggregateFunc): Boolean = {
-      val (column, aggType) = agg match {
-        case max: Max => (max.column, "max")
-        case min: Min => (min.column, "min")
-        case _ =>
-          throw new IllegalArgumentException(s"Unexpected type of 
AggregateFunc ${agg.describe}")
-      }
-
-      if (isPartitionCol(column)) {
-        // don't push down partition column, footer doesn't have max/min for 
partition column
-        return false
-      }
-      val structField = getStructFieldForCol(column)
-
-      structField.dataType match {
-        // not push down complex type
-        // not push down Timestamp because INT96 sort order is undefined,
-        // Parquet doesn't return statistics for INT96
-        // not push down Parquet Binary because min/max could be truncated
-        // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary
-        // could be Spark StringType, BinaryType or DecimalType
-        case BooleanType | ByteType | ShortType | IntegerType
-             | LongType | FloatType | DoubleType | DateType =>
-          finalSchema = finalSchema.add(structField.copy(s"$aggType(" + 
structField.name + ")"))
-          true
-        case _ =>
-          false
-      }
-    }
-
-    if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
-      aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) {
-      // Parquet footer has max/min/count for columns
-      // e.g. SELECT COUNT(col1) FROM t
-      // but footer doesn't have max/min/count for a column if max/min/count
-      // are combined with filter or group by
-      // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
-      //      SELECT COUNT(col1) FROM t GROUP BY col2
-      // However, if the filter is on partition column, max/min/count can 
still be pushed down
-      // Todo:  add support if groupby column is partition col
-      //        (https://issues.apache.org/jira/browse/SPARK-36646)
+    if (!sparkSession.sessionState.conf.parquetAggregatePushDown) {
       return false
     }
 
-    aggregation.groupByColumns.foreach { col =>
-      if (col.fieldNames.length != 1) return false
-      finalSchema = finalSchema.add(getStructFieldForCol(col))
+    AggregatePushDownUtils.getSchemaForPushedAggregation(
+      aggregation,
+      schema,
+      partitionNameSet,
+      dataFilters) match {
+
+      case Some(schema) =>
+        finalSchema = schema
+        this.pushedAggregations = Some(aggregation)
+        true
+      case _ => false
     }
-
-    aggregation.aggregateExpressions.foreach {
-      case max: Max =>
-        if (!processMinOrMax(max)) return false
-      case min: Min =>
-        if (!processMinOrMax(min)) return false
-      case count: Count =>
-        if (count.column.fieldNames.length != 1 || count.isDistinct) return 
false
-        finalSchema =
-          finalSchema.add(StructField(s"count(" + count.column.fieldNames.head 
+ ")", LongType))
-      case _: CountStar =>
-        finalSchema = finalSchema.add(StructField("count(*)", LongType))
-      case _ =>
-        return false
-    }
-    this.pushedAggregations = Some(aggregation)
-    true
   }
 
   override def build(): Scan = {
     // the `finalSchema` is either pruned in pushAggregation (if aggregates are
     // pushed down), or pruned in readDataSchema() (in regular column 
pruning). These
     // two are mutual exclusive.
-    if (pushedAggregations.isEmpty) finalSchema = readDataSchema()
+    if (pushedAggregations.isEmpty) {
+      finalSchema = readDataSchema()
+    }
     ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
       readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
       partitionFilters, dataFilters)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
index 604a892..14b59ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -358,7 +358,7 @@ class FileScanSuite extends FileScanSuiteBase {
       Seq.empty),
     ("OrcScan",
       (s, fi, ds, rds, rps, f, o, pf, df) =>
-        OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, 
df),
+        OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, 
f, pf, df),
       Seq.empty),
     ("CSVScan",
       (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, 
f, pf, df),
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
similarity index 70%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 77ecd28..a3d01e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -15,33 +15,39 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.parquet
+package org.apache.spark.sql.execution.datasources
 
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql._
+import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.orc.OrcTest
+import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.functions.min
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType}
 
 /**
- * A test suite that tests Max/Min/Count push down.
+ * A test suite that tests aggregate push down for Parquet and ORC.
  */
-abstract class ParquetAggregatePushDownSuite
+trait FileSourceAggregatePushDownSuite
   extends QueryTest
-  with ParquetTest
+  with FileBasedDataSourceTest
   with SharedSparkSession
   with ExplainSuiteHelper {
+
   import testImplicits._
 
-  test("aggregate push down - nested column: Max(top level column) not push 
down") {
+  protected def format: String
+  // The SQL config key for enabling aggregate push down.
+  protected val aggPushDownEnabledKey: String
+
+  test("nested column: Max(top level column) not push down") {
     val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
-    withSQLConf(
-      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
-      withParquetTable(data, "t") {
+    withSQLConf(aggPushDownEnabledKey -> "true") {
+      withDataSourceTable(data, "t") {
         val max = sql("SELECT Max(_1) FROM t")
         max.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -53,11 +59,10 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - nested column: Count(top level column) push 
down") {
+  test("nested column: Count(top level column) push down") {
     val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
-    withSQLConf(
-      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
-      withParquetTable(data, "t") {
+    withSQLConf(aggPushDownEnabledKey -> "true") {
+      withDataSourceTable(data, "t") {
         val count = sql("SELECT Count(_1) FROM t")
         count.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -70,11 +75,10 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - nested column: Max(nested column) not push 
down") {
+  test("nested column: Max(nested sub-field) not push down") {
     val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
-    withSQLConf(
-      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
-      withParquetTable(data, "t") {
+    withSQLConf(aggPushDownEnabledKey -> "true") {
+      withDataSourceTable(data, "t") {
         val max = sql("SELECT Max(_1._2[0]) FROM t")
         max.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -86,11 +90,10 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - nested column: Count(nested column) not push 
down") {
+  test("nested column: Count(nested sub-field) not push down") {
     val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
-    withSQLConf(
-      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
-      withParquetTable(data, "t") {
+    withSQLConf(aggPushDownEnabledKey -> "true") {
+      withDataSourceTable(data, "t") {
         val count = sql("SELECT Count(_1._2[0]) FROM t")
         count.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -103,13 +106,13 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - Max(partition Col): not push dow") {
+  test("Max(partition column): not push down") {
     withTempPath { dir =>
       spark.range(10).selectExpr("id", "id % 3 as p")
-        .write.partitionBy("p").parquet(dir.getCanonicalPath)
+        .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
       withTempView("tmp") {
-        spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
-        withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+        withSQLConf(aggPushDownEnabledKey -> "true") {
           val max = sql("SELECT Max(p) FROM tmp")
           max.queryExecution.optimizedPlan.collect {
             case _: DataSourceV2ScanRelation =>
@@ -123,15 +126,16 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - Count(partition Col): push down") {
+  test("Count(partition column): push down") {
     withTempPath { dir =>
-      spark.range(10).selectExpr("id", "id % 3 as p")
-        .write.partitionBy("p").parquet(dir.getCanonicalPath)
+      spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as 
p")
+        .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
       withTempView("tmp") {
-        spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
-        Seq("false", "true").foreach { enableVectorizedReader =>
-          withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
-            vectorizedReaderEnabledKey -> enableVectorizedReader) {
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+        val enableVectorizedReader = Seq("false", "true")
+        for (testVectorizedReader <- enableVectorizedReader) {
+          withSQLConf(aggPushDownEnabledKey -> "true",
+            vectorizedReaderEnabledKey -> testVectorizedReader) {
             val count = sql("SELECT COUNT(p) FROM tmp")
             count.queryExecution.optimizedPlan.collect {
               case _: DataSourceV2ScanRelation =>
@@ -146,12 +150,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - Filter alias over aggregate") {
+  test("filter alias over aggregate") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res 
> 1")
         selectAgg.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -164,12 +167,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - alias over aggregate") {
+  test("alias over aggregate") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as 
minPlus2 FROM t")
         selectAgg.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
@@ -182,12 +184,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - aggregate over alias not push down") {
+  test("aggregate over alias not push down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         val df = spark.table("t")
         val query = df.select($"_1".as("col1")).agg(min($"col1"))
         query.queryExecution.optimizedPlan.collect {
@@ -201,12 +202,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - query with group by not push down") {
+  test("query with group by not push down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         // aggregate not pushed down if there is group by
         val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
         selectAgg.queryExecution.optimizedPlan.collect {
@@ -220,12 +220,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - aggregate with data filter cannot be pushed 
down") {
+  test("aggregate with data filter cannot be pushed down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         // aggregate not pushed down if there is filter
         val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
         selectAgg.queryExecution.optimizedPlan.collect {
@@ -239,14 +238,14 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - aggregate with partition filter can be pushed 
down") {
+  test("aggregate with partition filter can be pushed down") {
     withTempPath { dir =>
       spark.range(10).selectExpr("id", "id % 3 as p")
-        .write.partitionBy("p").parquet(dir.getCanonicalPath)
+        .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
       withTempView("tmp") {
-        spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
         Seq("false", "true").foreach { enableVectorizedReader =>
-          withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+          withSQLConf(aggPushDownEnabledKey -> "true",
             vectorizedReaderEnabledKey -> enableVectorizedReader) {
             val max = sql("SELECT max(id), min(id), count(id) FROM tmp WHERE p 
= 0")
             max.queryExecution.optimizedPlan.collect {
@@ -262,12 +261,11 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - push down only if all the aggregates can be 
pushed down") {
+  test("push down only if all the aggregates can be pushed down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         // not push down since sum can't be pushed down
         val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
         selectAgg.queryExecution.optimizedPlan.collect {
@@ -284,9 +282,8 @@ abstract class ParquetAggregatePushDownSuite
   test("aggregate push down - MIN/MAX/COUNT") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
-    withParquetTable(data, "t") {
-      withSQLConf(
-        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
         val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), 
max(_1), max(_1)," +
           " count(*), count(_1), count(_2), count(_3) FROM t")
         selectAgg.queryExecution.optimizedPlan.collect {
@@ -308,7 +305,13 @@ abstract class ParquetAggregatePushDownSuite
     }
   }
 
-  test("aggregate push down - different data types") {
+  private def testPushDownForAllDataTypes(
+      inputRows: Seq[Row],
+      expectedMinWithAllTypes: Seq[Row],
+      expectedMinWithOutTSAndBinary: Seq[Row],
+      expectedMaxWithAllTypes: Seq[Row],
+      expectedMaxWithOutTSAndBinary: Seq[Row],
+      expectedCount: Seq[Row]): Unit = {
     implicit class StringToDate(s: String) {
       def date: Date = Date.valueOf(s)
     }
@@ -317,49 +320,6 @@ abstract class ParquetAggregatePushDownSuite
       def ts: Timestamp = Timestamp.valueOf(s)
     }
 
-    val rows =
-      Seq(
-        Row(
-          "a string",
-          true,
-          10.toByte,
-          "Spark SQL".getBytes,
-          12.toShort,
-          3,
-          Long.MaxValue,
-          0.15.toFloat,
-          0.75D,
-          Decimal("12.345678"),
-          ("2021-01-01").date,
-          ("2015-01-01 23:50:59.123").ts),
-        Row(
-          "test string",
-          false,
-          1.toByte,
-          "Parquet".getBytes,
-          2.toShort,
-          null,
-          Long.MinValue,
-          0.25.toFloat,
-          0.85D,
-          Decimal("1.2345678"),
-          ("2015-01-01").date,
-          ("2021-01-01 23:50:59.123").ts),
-        Row(
-          null,
-          true,
-          10000.toByte,
-          "Spark ML".getBytes,
-          222.toShort,
-          113,
-          11111111L,
-          0.25.toFloat,
-          0.75D,
-          Decimal("12345.678"),
-          ("2004-06-19").date,
-          ("1999-08-26 10:43:59.123").ts)
-      )
-
     val schema = StructType(List(StructField("StringCol", StringType, true),
       StructField("BooleanCol", BooleanType, false),
       StructField("ByteCol", ByteType, false),
@@ -373,13 +333,13 @@ abstract class ParquetAggregatePushDownSuite
       StructField("DateCol", DateType, false),
       StructField("TimestampCol", TimestampType, false)).toArray)
 
-    val rdd = sparkContext.parallelize(rows)
+    val rdd = sparkContext.parallelize(inputRows)
     withTempPath { file =>
-      spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath)
+      spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath)
       withTempView("test") {
-        spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test")
+        spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test")
         Seq("false", "true").foreach { enableVectorizedReader =>
-          withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+          withSQLConf(aggPushDownEnabledKey -> "true",
             vectorizedReaderEnabledKey -> enableVectorizedReader) {
 
             val testMinWithAllTypes = sql("SELECT min(StringCol), 
min(BooleanCol), min(ByteCol), " +
@@ -389,7 +349,8 @@ abstract class ParquetAggregatePushDownSuite
            // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
             // so aggregates are not pushed down
            // In addition, Parquet Binary min/max could be truncated, so we disable aggregate
-            // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType)
+            // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType).
+            // Also do not push down for ORC for the same reason.
             testMinWithAllTypes.queryExecution.optimizedPlan.collect {
               case _: DataSourceV2ScanRelation =>
                 val expected_plan_fragment =
@@ -397,9 +358,7 @@ abstract class ParquetAggregatePushDownSuite
                checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment)
             }
 
-            checkAnswer(testMinWithAllTypes, Seq(Row("a string", false, 
1.toByte,
-              "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 
0.15.toFloat, 0.75D,
-              1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+            checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes)
 
             val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), 
min(ByteCol), " +
               "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
@@ -419,8 +378,7 @@ abstract class ParquetAggregatePushDownSuite
                checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment)
             }
 
-            checkAnswer(testMinWithOutTSAndBinary, Seq(Row(false, 1.toByte,
-              2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)))
+            checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary)
 
             val testMaxWithAllTypes = sql("SELECT max(StringCol), 
max(BooleanCol), " +
               "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), 
max(LongCol), " +
@@ -430,7 +388,8 @@ abstract class ParquetAggregatePushDownSuite
            // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
             // so aggregates are not pushed down
            // In addition, Parquet Binary min/max could be truncated, so we disable aggregate
-            // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType)
+            // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType).
+            // Also do not push down for ORC for the same reason.
             testMaxWithAllTypes.queryExecution.optimizedPlan.collect {
               case _: DataSourceV2ScanRelation =>
                 val expected_plan_fragment =
@@ -438,9 +397,7 @@ abstract class ParquetAggregatePushDownSuite
                checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment)
             }
 
-            checkAnswer(testMaxWithAllTypes, Seq(Row("test string", true, 
16.toByte,
-              "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 
0.25.toFloat, 0.85D,
-              12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+            checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes)
 
             val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), 
max(ByteCol), " +
               "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
@@ -460,8 +417,7 @@ abstract class ParquetAggregatePushDownSuite
                checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment)
             }
 
-            checkAnswer(testMaxWithoutTSAndBinary, Seq(Row(true, 16.toByte,
-              222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)))
+            checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary)
 
             val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
               " count(ByteCol), count(BinaryCol), count(ShortCol), 
count(IntegerCol)," +
@@ -487,22 +443,97 @@ abstract class ParquetAggregatePushDownSuite
                 checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
             }
 
-            checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+            checkAnswer(testCount, expectedCount)
           }
         }
       }
     }
   }
 
-  test("aggregate push down - column name case sensitivity") {
+  test("aggregate push down - different data types") {
+    implicit class StringToDate(s: String) {
+      def date: Date = Date.valueOf(s)
+    }
+
+    implicit class StringToTs(s: String) {
+      def ts: Timestamp = Timestamp.valueOf(s)
+    }
+
+    val rows =
+      Seq(
+        Row(
+          "a string",
+          true,
+          10.toByte,
+          "Spark SQL".getBytes,
+          12.toShort,
+          3,
+          Long.MaxValue,
+          0.15.toFloat,
+          0.75D,
+          Decimal("12.345678"),
+          ("2021-01-01").date,
+          ("2015-01-01 23:50:59.123").ts),
+        Row(
+          "test string",
+          false,
+          1.toByte,
+          "Parquet".getBytes,
+          2.toShort,
+          null,
+          Long.MinValue,
+          0.25.toFloat,
+          0.85D,
+          Decimal("1.2345678"),
+          ("2015-01-01").date,
+          ("2021-01-01 23:50:59.123").ts),
+        Row(
+          null,
+          true,
+          10000.toByte,
+          "Spark ML".getBytes,
+          222.toShort,
+          113,
+          11111111L,
+          0.25.toFloat,
+          0.75D,
+          Decimal("12345.678"),
+          ("2004-06-19").date,
+          ("1999-08-26 10:43:59.123").ts)
+      )
+
+    testPushDownForAllDataTypes(
+      rows,
+      Seq(Row("a string", false, 1.toByte,
+        "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 
0.75D,
+        1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)),
+      Seq(Row(false, 1.toByte,
+        2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)),
+      Seq(Row("test string", true, 16.toByte,
+        "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 
0.25.toFloat, 0.85D,
+        12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)),
+      Seq(Row(true, 16.toByte,
+        222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)),
+      Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))
+    )
+
+    // Test with 0 rows (empty file)
+    val nullRow = Row.fromSeq((1 to 12).map(_ => null))
+    val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 8).map(_ => null))
+    val zeroCount = Row.fromSeq((1 to 12).map(_ => 0))
+    testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary),
+      Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount))
+  }
+
+  test("column name case sensitivity") {
     Seq("false", "true").foreach { enableVectorizedReader =>
-      withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+      withSQLConf(aggPushDownEnabledKey -> "true",
         vectorizedReaderEnabledKey -> enableVectorizedReader) {
         withTempPath { dir =>
           spark.range(10).selectExpr("id", "id % 3 as p")
-            .write.partitionBy("p").parquet(dir.getCanonicalPath)
+            .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
           withTempView("tmp") {
-            spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+            spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
             val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
             selectAgg.queryExecution.optimizedPlan.collect {
               case _: DataSourceV2ScanRelation =>
@@ -518,18 +549,41 @@ abstract class ParquetAggregatePushDownSuite
   }
 }
 
+abstract class ParquetAggregatePushDownSuite
+  extends FileSourceAggregatePushDownSuite with ParquetTest {
+
+  override def format: String = "parquet"
+  override protected val aggPushDownEnabledKey: String =
+    SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key
+}
+
 class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
 
   override protected def sparkConf: SparkConf =
-    super
-      .sparkConf
-      .set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+    super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
 }
 
 class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
 
   override protected def sparkConf: SparkConf =
-    super
-      .sparkConf
-      .set(SQLConf.USE_V1_SOURCE_LIST, "")
+    super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "")
+}
+
+abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite {
+
+  override def format: String = "orc"
+  override protected val aggPushDownEnabledKey: String =
+    SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key
+}
+
+class OrcV1AggregatePushDownSuite extends OrcAggregatePushDownSuite {
+
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc")
+}
+
+class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite {
+
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "")
 }
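
A usage sketch, not part of this commit: the new ORC push down can be exercised
the same way the suites above do, assuming a running SparkSession named `spark`,
the v2 ORC reader active (i.e. SQLConf.USE_V1_SOURCE_LIST does not contain
"orc"), and an illustrative output path and view name:

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf

// Illustrative sketch only. Enable ORC aggregate push down (off by default).
spark.conf.set(SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key, "true")

// Write a small partitioned ORC table, mirroring the tests above.
spark.range(10).selectExpr("id", "id % 3 as p")
  .write.partitionBy("p").format("orc").save("/tmp/orc_agg_demo")   // hypothetical path

spark.read.format("orc").load("/tmp/orc_agg_demo").createOrReplaceTempView("demo")

// MIN/MAX/COUNT with only a partition filter stays eligible for push down.
val agg = spark.sql("SELECT max(id), min(id), count(id) FROM demo WHERE p = 0")

// As in the suites above, the v2 scan appears in the optimized plan; the pushed
// aggregates should be visible in this query's EXPLAIN output.
agg.queryExecution.optimizedPlan.collect { case scan: DataSourceV2ScanRelation => scan }
agg.show()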
