This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 609e749  [SPARK-34960][SQL] Aggregate push down for ORC
609e749 is described below

commit 609e7498326ebedb06904a1f5bab59b739380b1a
Author: Cheng Su <chen...@fb.com>
AuthorDate: Thu Oct 28 17:29:15 2021 -0700

    [SPARK-34960][SQL] Aggregate push down for ORC

    ### What changes were proposed in this pull request?

    This PR adds the aggregate push down feature for the ORC data source v2 reader. At a high level, the PR does the following:
    * The supported aggregate expressions are MIN/MAX/COUNT, the same as [Parquet aggregate push down](https://github.com/apache/spark/pull/33639).
    * BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType and DateType are allowed in MIN/MAX aggregate push down. All other column types are not allowed in MIN/MAX aggregate push down.
    * All column types are supported in COUNT aggregate push down.
    * Nested columns' sub-fields are disallowed in aggregate push down.
    * If the file does not have valid statistics, Spark throws an exception and fails the query.
    * If the aggregate has a filter or a group-by column, the aggregate is not pushed down.

    At the code level, the PR does the following:
    * `OrcScanBuilder`: `pushAggregation()` checks whether the aggregation can be pushed down. Most of the checking logic is shared between Parquet and ORC, and is extracted into `AggregatePushDownUtils.getSchemaForPushedAggregation()`. `OrcScanBuilder` creates an `OrcScan` with the aggregation and the aggregation data schema.
    * `OrcScan`: `createReaderFactory` creates an ORC reader factory with the aggregation and schema, similar to the change in `ParquetScan`.
    * `OrcPartitionReaderFactory`: `buildReaderWithAggregates` creates an ORC reader with aggregate push down (i.e. it reads the ORC file footer to process column statistics instead of reading the actual data in the file). `buildColumnarReaderWithAggregates` creates a columnar ORC reader similarly. Both delegate the real footer-reading work to `OrcUtils.createAggInternalRowFromFooter`.
    * `OrcUtils.createAggInternalRowFromFooter`: reads the ORC file footer to process column statistics (the real heavy lifting happens here). Similar to `ParquetUtils.createAggInternalRowFromFooter`. It leverages utility methods such as `OrcFooterReader.readStatistics`.
    * `OrcFooterReader`: `readStatistics` reads the ORC `ColumnStatistics[]` into Spark `OrcColumnStatistics`. The transformation is needed because ORC `ColumnStatistics[]` stores all column statistics in a flattened array, which is hard to process, whereas Spark `OrcColumnStatistics` stores the statistics in a nested tree structure (similar to `StructType`). This is used by `OrcUtils.createAggInternalRowFromFooter`.
    * `OrcColumnStatistics`: the easy-to-manipulate structure for ORC `ColumnStatistics`. This is used by `OrcFooterReader.readStatistics`.

    ### Why are the changes needed?

    To improve the performance of queries with aggregates.

    ### Does this PR introduce _any_ user-facing change?

    Yes. A user-facing config `spark.sql.orc.aggregatePushdown` is added to control enabling/disabling aggregate push down for ORC. The feature is disabled by default.

    ### How was this patch tested?

    Added unit tests in `FileSourceAggregatePushDownSuite.scala`. Refactored all unit tests in https://github.com/apache/spark/pull/33639, and they now work for both Parquet and ORC.

    Closes #34298 from c21/orc-agg.
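    A minimal usage sketch (the output path and table below are hypothetical, not part of this commit): with the new config enabled, eligible MIN/MAX/COUNT aggregates are answered from ORC footer statistics and the scan reports them as `PushedAggregation`:

    ```scala
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("orc-agg-pushdown-sketch").getOrCreate()

    // Enable the new flag (disabled by default).
    spark.conf.set("spark.sql.orc.aggregatePushdown", "true")

    // Write some ORC data (hypothetical path).
    spark.range(100).selectExpr("id", "id % 3 AS p")
      .write.mode("overwrite").format("orc").save("/tmp/orc_agg_demo")

    // MIN/MAX/COUNT with no data filter and no group-by are eligible for push down.
    val df = spark.read.format("orc").load("/tmp/orc_agg_demo")
      .agg("id" -> "min", "id" -> "max", "id" -> "count")

    // The optimized plan's ORC scan is expected to show something like
    //   PushedAggregation: [MIN(id), MAX(id), COUNT(id)]
    df.explain()
    df.show()
    ```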
Authored-by: Cheng Su <chen...@fb.com> Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com> --- .../org/apache/spark/sql/internal/SQLConf.scala | 10 + .../org/apache/spark/sql/types/StructType.scala | 2 +- .../datasources/orc/OrcColumnStatistics.java | 80 +++++ .../execution/datasources/orc/OrcFooterReader.java | 67 +++++ .../datasources/AggregatePushDownUtils.scala | 141 +++++++++ .../datasources/orc/OrcDeserializer.scala | 16 + .../sql/execution/datasources/orc/OrcUtils.scala | 122 +++++++- .../datasources/parquet/ParquetUtils.scala | 41 --- .../v2/orc/OrcPartitionReaderFactory.scala | 93 ++++-- .../sql/execution/datasources/v2/orc/OrcScan.scala | 45 ++- .../datasources/v2/orc/OrcScanBuilder.scala | 43 ++- .../v2/parquet/ParquetPartitionReaderFactory.scala | 14 +- .../datasources/v2/parquet/ParquetScan.scala | 10 +- .../v2/parquet/ParquetScanBuilder.scala | 93 ++---- .../scala/org/apache/spark/sql/FileScanSuite.scala | 2 +- ...cala => FileSourceAggregatePushDownSuite.scala} | 324 ++++++++++++--------- 16 files changed, 804 insertions(+), 299 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fe3204b..def6bbc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -960,6 +960,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown") + .doc("If true, aggregates will be pushed down to ORC for optimization. Support MIN, MAX and " + + "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date " + + "type. For COUNT, support all data types.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") .doc("When true, the Orc data source merges schemas collected from all data files, " + "otherwise the schema is picked from a random data file.") @@ -3706,6 +3714,8 @@ class SQLConf extends Serializable with Logging { def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED) + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 205b08f..6707fb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java new file mode 100644 
index 0000000..8adb9e8 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Columns statistics interface wrapping ORC {@link ColumnStatistics}s. + * + * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer, + * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure, + * according to data types. The flatten array stores all data types (including nested types) in + * tree pre-ordering. This is used for aggregate push down in ORC. + * + * For nested data types (array, map and struct), the sub-field statistics are stored recursively + * inside parent column's children field. Here is an example of {@link OrcColumnStatistics}: + * + * Data schema: + * c1: int + * c2: struct<f1: int, f2: float> + * c3: map<key: int, value: string> + * c4: array<int> + * + * OrcColumnStatistics + * | (children) + * --------------------------------------------- + * / | \ \ + * c1 c2 c3 c4 + * (integer) (struct) (map) (array) +* (min:1, | (children) | (children) | (children) + * max:10) ----- ----- element + * / \ / \ (integer) + * c2.f1 c2.f2 key value + * (integer) (float) (integer) (string) + * (min:0.1, (min:"a", + * max:100.5) max:"zzz") + */ +public class OrcColumnStatistics { + private final ColumnStatistics statistics; + private final List<OrcColumnStatistics> children; + + public OrcColumnStatistics(ColumnStatistics statistics) { + this.statistics = statistics; + this.children = new ArrayList<>(); + } + + public ColumnStatistics getStatistics() { + return statistics; + } + + public OrcColumnStatistics get(int ordinal) { + if (ordinal < 0 || ordinal >= children.size()) { + throw new IndexOutOfBoundsException( + String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size())); + } + return children.get(ordinal); + } + + public void add(OrcColumnStatistics newChild) { + children.add(newChild); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java new file mode 100644 index 0000000..546b048 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Queue; + +/** + * {@link OrcFooterReader} is a util class which encapsulates the helper + * methods of reading ORC file footer. + */ +public class OrcFooterReader { + + /** + * Read the columns statistics from ORC file footer. + * + * @param orcReader the reader to read ORC file footer. + * @return Statistics for all columns in the file. + */ + public static OrcColumnStatistics readStatistics(Reader orcReader) { + TypeDescription orcSchema = orcReader.getSchema(); + ColumnStatistics[] orcStatistics = orcReader.getStatistics(); + StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema); + return convertStatistics(sparkSchema, new LinkedList<>(Arrays.asList(orcStatistics))); + } + + /** + * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnStatistics}. + * The queue of ORC {@link ColumnStatistics}s are assumed to be ordered as tree pre-order. + */ + private static OrcColumnStatistics convertStatistics( + DataType sparkSchema, Queue<ColumnStatistics> orcStatistics) { + OrcColumnStatistics statistics = new OrcColumnStatistics(orcStatistics.remove()); + if (sparkSchema instanceof StructType) { + for (StructField field : ((StructType) sparkSchema).fields()) { + statistics.add(convertStatistics(field.dataType(), orcStatistics)); + } + } else if (sparkSchema instanceof MapType) { + statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), orcStatistics)); + statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), orcStatistics)); + } else if (sparkSchema instanceof ArrayType) { + statistics.add(convertStatistics(((ArrayType) sparkSchema).elementType(), orcStatistics)); + } + return statistics; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala new file mode 100644 index 0000000..6340d97 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. + */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. + */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNames: Set[String], + dataFilters: Seq[Expression]): Option[StructType] = { + + var finalSchema = new StructType() + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.apply(col.fieldNames.head) + } + + def isPartitionCol(col: NamedReference) = { + partitionNames.contains(col.fieldNames.head) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + // not push down Parquet Binary because min/max could be truncated + // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary + // could be Spark StringType, BinaryType or DecimalType. + // not push down for ORC with same reason. + case BooleanType | ByteType | ShortType | IntegerType + | LongType | FloatType | DoubleType | DateType => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + case _ => + false + } + } + + if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) { + // Parquet/ORC footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. 
SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // However, if the filter is on partition column, max/min/count can still be pushed down + // Todo: add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + return None + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return None + case min: Min => + if (!processMinOrMax(min)) return None + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return None + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return None + } + + Some(finalSchema) + } + + /** + * Check if two Aggregation `a` and `b` is equal or not. + */ + def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + } + + /** + * Convert the aggregates result from `InternalRow` to `ColumnarBatch`. + * This is used for columnar reader. + */ + def convertAggregatesRowToBatch( + aggregatesAsRow: InternalRow, + aggregatesSchema: StructType, + offHeap: Boolean): ColumnarBatch = { + val converter = new RowToColumnConverter(aggregatesSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggregatesSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggregatesSchema) + } + converter.convert(aggregatesAsRow, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 1476083..9140833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -68,6 +68,22 @@ class OrcDeserializer( resultRow } + def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): InternalRow = { + var targetColumnIndex = 0 + while (targetColumnIndex < fieldWriters.length) { + if (fieldWriters(targetColumnIndex) != null) { + val value = orcValues(requestedColIds(targetColumnIndex)) + if (value == null) { + resultRow.setNullAt(targetColumnIndex) + } else { + fieldWriters(targetColumnIndex)(value) + } + } + targetColumnIndex += 1 + } + resultRow + } + /** * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 475448a..b262415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -25,15 +25,19 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer} +import org.apache.hadoop.hive.serde2.io.DateWritable +import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} +import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.types._ @@ -87,7 +91,7 @@ object OrcUtils extends Logging { } } - private def toCatalystSchema(schema: TypeDescription): StructType = { + def toCatalystSchema(schema: TypeDescription): StructType = { import TypeDescription.Category def toCatalystType(orcType: TypeDescription): DataType = { @@ -377,4 +381,116 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + var columnsStatistics: OrcColumnStatistics = null + try { + columnsStatistics = OrcFooterReader.readStatistics(reader) + } catch { case e: Exception => + throw new SparkException( + s"Cannot read columns statistics in file: $filePath. Please consider disabling " + + s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e) + } + + // Get column statistics with column name. 
+ def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + // Return null if number of non-null values is zero. + def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + if (statistics.getNumberOfValues == 0) { + return null + } + + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for DoubleColumnStatistics") + } + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + case (min: Min, index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also stores number + // of all values (null and non-null) separately. + val nonNullRowsCount = if (isPartitionColumn) { + columnsStatistics.getStatistics.getNumberOfValues + } else { + getColumnStatistics(columnName).getNumberOfValues + } + new LongWritable(nonNullRowsCount) + case (_: CountStar, _) => + // Count(*) includes both null and non-null values. 
+ new LongWritable(columnsStatistics.getStatistics.getNumberOfValues) + case (x, _) => + throw new IllegalArgumentException( + s"createAggInternalRowFromFooter should not take $x as the aggregate expression") + } + + val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray) + orcValuesDeserializer.deserializeFromValues(aggORCValues) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 1093f9c..0e4b9283 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -32,12 +32,9 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} -import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.PartitioningUtils -import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -202,44 +199,6 @@ object ParquetUtils { } /** - * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of - * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader - * to read data from Parquet and aggregate at Spark layer. Instead we want - * to get the aggregates (Max/Min/Count) result using the statistics information - * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. - * - * @return Aggregate results in the format of ColumnarBatch - */ - private[sql] def createAggColumnarBatchFromFooter( - footer: ParquetMetadata, - filePath: String, - dataSchema: StructType, - partitionSchema: StructType, - aggregation: Aggregation, - aggSchema: StructType, - offHeap: Boolean, - datetimeRebaseMode: LegacyBehaviorPolicy.Value, - isCaseSensitive: Boolean): ColumnarBatch = { - val row = createAggInternalRowFromFooter( - footer, - filePath, - dataSchema, - partitionSchema, - aggregation, - aggSchema, - datetimeRebaseMode, - isCaseSensitive) - val converter = new RowToColumnConverter(aggSchema) - val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(1, aggSchema) - } else { - OnHeapColumnVector.allocateColumns(1, aggSchema) - } - converter.convert(row, columnVectors.toArray) - new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) - } - - /** * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics * information from Parquet footer file. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index c5020cb..246f160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -23,15 +23,16 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription} import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitionedFile} import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -55,7 +56,8 @@ case class OrcPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + filters: Array[Filter], + aggregation: Option[Aggregation]) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -81,17 +83,14 @@ case class OrcPartitionReaderFactory( override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -128,17 +127,14 @@ case class OrcPartitionReaderFactory( override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildColumnarReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - 
Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -173,4 +169,67 @@ case class OrcPartitionReaderFactory( } } + private def createORCReader(filePath: Path, conf: Configuration): Reader = { + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) + + pushDownPredicates(filePath, conf) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + OrcFile.createReader(filePath, readerOptions) + } + + /** + * Build reader with aggregate push down. + */ + private def buildReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[InternalRow] = { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema) + } + } + + override def next(): Boolean = hasNext + + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } + } + + /** + * Build columnar reader with aggregate push down. + */ + private def buildColumnarReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[ColumnarBatch] = { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private lazy val batch: ColumnarBatch = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + val row = OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema) + AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) + } + } + + override def next(): Boolean = hasNext + + override def get(): ColumnarBatch = { + hasNext = false + batch + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 7619e3c..6b9d181 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -21,8 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -37,10 +38,25 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true + override def isSplitable(path: Path): Boolean = { + // If 
aggregate is pushed down, only the file footer will be read once, + // so file should be not split across multiple tasks. + pushedAggregate.isEmpty + } + + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `OrcScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) { + readDataSchema + } else { + super.readSchema() + } + } override def createReaderFactory(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( @@ -48,24 +64,39 @@ case class OrcScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate) } override def equals(obj: Any): Boolean = obj match { case o: OrcScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && o.pushedAggregate.nonEmpty) { + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, o.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && o.pushedAggregate.isEmpty + } super.equals(o) && dataSchema == o.dataSchema && options == o.options && - equivalentFilters(pushedFilters, o.pushedFilters) - + equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index cfa396f..d2c17fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf @@ -35,18 +36,31 @@ case class 
OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates { + lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), options, pushedAggregations, pushedDataFilters, partitionFilters, + dataFilters) } override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { @@ -58,4 +72,23 @@ case class OrcScanBuilder( Array.empty[Filter] } } + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.orcAggregatePushDown) { + return false + } + + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 111018b..6f021ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -175,24 +175,26 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[ColumnarBatch] { private var hasNext = true - private val row: ColumnarBatch = { + private val batch: ColumnarBatch = { val footer = getFooter(file) if (footer != null && footer.getBlocks.size > 0) { - ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, - partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, + val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, + 
dataSchema, partitionSchema, aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + AggregatePushDownUtils.convertAggregatesRowToBatch( + row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined) } else { null } } override def next(): Boolean = { - hasNext && row != null + hasNext && batch != null } override def get(): ColumnarBatch = { hasNext = false - row + batch } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 42dc287..b92ed82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf @@ -101,7 +101,7 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { - equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) } else { pushedAggregate.isEmpty && p.pushedAggregate.isEmpty } @@ -130,10 +130,4 @@ case class ParquetScan( Map("PushedAggregation" -> pushedAggregationsStr) ++ Map("PushedGroupBy" -> pushedGroupByStr) } - - private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { - a.aggregateExpressions.sortBy(_.hashCode()) - .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && - a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index da49381..74d11b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import 
org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -87,87 +86,31 @@ case class ParquetScanBuilder( override def pushedFilters(): Array[Filter] = pushedParquetFilters override def pushAggregation(aggregation: Aggregation): Boolean = { - - def getStructFieldForCol(col: NamedReference): StructField = { - schema.nameToField(col.fieldNames.head) - } - - def isPartitionCol(col: NamedReference) = { - partitionNameSet.contains(col.fieldNames.head) - } - - def processMinOrMax(agg: AggregateFunc): Boolean = { - val (column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") - } - - if (isPartitionCol(column)) { - // don't push down partition column, footer doesn't have max/min for partition column - return false - } - val structField = getStructFieldForCol(column) - - structField.dataType match { - // not push down complex type - // not push down Timestamp because INT96 sort order is undefined, - // Parquet doesn't return statistics for INT96 - // not push down Parquet Binary because min/max could be truncated - // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary - // could be Spark StringType, BinaryType or DecimalType - case BooleanType | ByteType | ShortType | IntegerType - | LongType | FloatType | DoubleType | DateType => - finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) - true - case _ => - false - } - } - - if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { - // Parquet footer has max/min/count for columns - // e.g. SELECT COUNT(col1) FROM t - // but footer doesn't have max/min/count for a column if max/min/count - // are combined with filter or group by - // e.g. 
SELECT COUNT(col1) FROM t WHERE col2 = 8 - // SELECT COUNT(col1) FROM t GROUP BY col2 - // However, if the filter is on partition column, max/min/count can still be pushed down - // Todo: add support if groupby column is partition col - // (https://issues.apache.org/jira/browse/SPARK-36646) + if (!sparkSession.sessionState.conf.parquetAggregatePushDown) { return false } - aggregation.groupByColumns.foreach { col => - if (col.fieldNames.length != 1) return false - finalSchema = finalSchema.add(getStructFieldForCol(col)) + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false } - - aggregation.aggregateExpressions.foreach { - case max: Max => - if (!processMinOrMax(max)) return false - case min: Min => - if (!processMinOrMax(min)) return false - case count: Count => - if (count.column.fieldNames.length != 1 || count.isDistinct) return false - finalSchema = - finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) - case _: CountStar => - finalSchema = finalSchema.add(StructField("count(*)", LongType)) - case _ => - return false - } - this.pushedAggregations = Some(aggregation) - true } override def build(): Scan = { // the `finalSchema` is either pruned in pushAggregation (if aggregates are // pushed down), or pruned in readDataSchema() (in regular column pruning). These // two are mutual exclusive. - if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, partitionFilters, dataFilters) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 604a892..14b59ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -358,7 +358,7 @@ class FileScanSuite extends FileScanSuiteBase { Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => - OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), + OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df), Seq.empty), ("CSVScan", (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala similarity index 70% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index 77ecd28..a3d01e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -15,33 +15,39 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.datasources import java.sql.{Date, Timestamp} import org.apache.spark.SparkConf -import org.apache.spark.sql._ +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.functions.min import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType} /** - * A test suite that tests Max/Min/Count push down. + * A test suite that tests aggregate push down for Parquet and ORC. */ -abstract class ParquetAggregatePushDownSuite +trait FileSourceAggregatePushDownSuite extends QueryTest - with ParquetTest + with FileBasedDataSourceTest with SharedSparkSession with ExplainSuiteHelper { + import testImplicits._ - test("aggregate push down - nested column: Max(top level column) not push down") { + protected def format: String + // The SQL config key for enabling aggregate push down. + protected val aggPushDownEnabledKey: String + + test("nested column: Max(top level column) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -53,11 +59,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(top level column) push down") { + test("nested column: Count(top level column) push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = sql("SELECT Count(_1) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -70,11 +75,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Max(nested column) not push down") { + test("nested column: Max(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey-> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1._2[0]) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -86,11 +90,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(nested column) not push down") { + test("nested column: Count(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = 
sql("SELECT Count(_1._2[0]) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -103,13 +106,13 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Max(partition Col): not push dow") { + test("Max(partition column): not push down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + withSQLConf(aggPushDownEnabledKey -> "true") { val max = sql("SELECT Max(p) FROM tmp") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -123,15 +126,16 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Count(partition Col): push down") { + test("Count(partition column): push down") { withTempPath { dir => - spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); - Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", - vectorizedReaderEnabledKey -> enableVectorizedReader) { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { val count = sql("SELECT COUNT(p) FROM tmp") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -146,12 +150,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Filter alias over aggregate") { + test("filter alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -164,12 +167,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - alias over aggregate") { + test("alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -182,12 +184,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate over alias not push down") { + test("aggregate over alias not push down") { val data = 
Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val df = spark.table("t") val query = df.select($"_1".as("col1")).agg(min($"col1")) query.queryExecution.optimizedPlan.collect { @@ -201,12 +202,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - query with group by not push down") { + test("query with group by not push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is group by val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") selectAgg.queryExecution.optimizedPlan.collect { @@ -220,12 +220,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate with data filter cannot be pushed down") { + test("aggregate with data filter cannot be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is filter val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") selectAgg.queryExecution.optimizedPlan.collect { @@ -239,14 +238,14 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate with partition filter can be pushed down") { + test("aggregate with partition filter can be pushed down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { val max = sql("SELECT max(id), min(id), count(id) FROM tmp WHERE p = 0") max.queryExecution.optimizedPlan.collect { @@ -262,12 +261,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - push down only if all the aggregates can be pushed down") { + test("push down only if all the aggregates can be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // not push down since sum can't be pushed down val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -284,9 +282,8 @@ abstract class ParquetAggregatePushDownSuite test("aggregate push down - MIN/MAX/COUNT") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, 
null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + " count(*), count(_1), count(_2), count(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -308,7 +305,13 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - different data types") { + private def testPushDownForAllDataTypes( + inputRows: Seq[Row], + expectedMinWithAllTypes: Seq[Row], + expectedMinWithOutTSAndBinary: Seq[Row], + expectedMaxWithAllTypes: Seq[Row], + expectedMaxWithOutTSAndBinary: Seq[Row], + expectedCount: Seq[Row]): Unit = { implicit class StringToDate(s: String) { def date: Date = Date.valueOf(s) } @@ -317,49 +320,6 @@ abstract class ParquetAggregatePushDownSuite def ts: Timestamp = Timestamp.valueOf(s) } - val rows = - Seq( - Row( - "a string", - true, - 10.toByte, - "Spark SQL".getBytes, - 12.toShort, - 3, - Long.MaxValue, - 0.15.toFloat, - 0.75D, - Decimal("12.345678"), - ("2021-01-01").date, - ("2015-01-01 23:50:59.123").ts), - Row( - "test string", - false, - 1.toByte, - "Parquet".getBytes, - 2.toShort, - null, - Long.MinValue, - 0.25.toFloat, - 0.85D, - Decimal("1.2345678"), - ("2015-01-01").date, - ("2021-01-01 23:50:59.123").ts), - Row( - null, - true, - 10000.toByte, - "Spark ML".getBytes, - 222.toShort, - 113, - 11111111L, - 0.25.toFloat, - 0.75D, - Decimal("12345.678"), - ("2004-06-19").date, - ("1999-08-26 10:43:59.123").ts) - ) - val schema = StructType(List(StructField("StringCol", StringType, true), StructField("BooleanCol", BooleanType, false), StructField("ByteCol", ByteType, false), @@ -373,13 +333,13 @@ abstract class ParquetAggregatePushDownSuite StructField("DateCol", DateType, false), StructField("TimestampCol", TimestampType, false)).toArray) - val rdd = sparkContext.parallelize(rows) + val rdd = sparkContext.parallelize(inputRows) withTempPath { file => - spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath) withTempView("test") { - spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test") Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { val testMinWithAllTypes = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + @@ -389,7 +349,8 @@ abstract class ParquetAggregatePushDownSuite // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down // In addition, Parquet Binary min/max could be truncated, so we disable aggregate - // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType) + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. 
testMinWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -397,9 +358,7 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment) } - checkAnswer(testMinWithAllTypes, Seq(Row("a string", false, 1.toByte, - "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, - 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes) val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), min(ByteCol), " + "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + @@ -419,8 +378,7 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMinWithOutTSAndBinary, Seq(Row(false, 1.toByte, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date))) + checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary) val testMaxWithAllTypes = sql("SELECT max(StringCol), max(BooleanCol), " + "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), " + @@ -430,7 +388,8 @@ abstract class ParquetAggregatePushDownSuite // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down // In addition, Parquet Binary min/max could be truncated, so we disable aggregate - // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType) + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. testMaxWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -438,9 +397,7 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment) } - checkAnswer(testMaxWithAllTypes, Seq(Row("test string", true, 16.toByte, - "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes) val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), max(ByteCol), " + "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + @@ -460,8 +417,7 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMaxWithoutTSAndBinary, Seq(Row(true, 16.toByte, - 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date))) + checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary) val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + @@ -487,22 +443,97 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } - checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + checkAnswer(testCount, expectedCount) } } } } } - test("aggregate push down - column name case sensitivity") { + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark 
SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + testPushDownForAllDataTypes( + rows, + Seq(Row("a string", false, 1.toByte, + "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)), + Seq(Row(false, 1.toByte, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)), + Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)), + Seq(Row(true, 16.toByte, + 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)), + Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)) + ) + + // Test for 0 row (empty file) + val nullRow = Row.fromSeq((1 to 12).map(_ => null)) + val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 8).map(_ => null)) + val zeroCount = Row.fromSeq((1 to 12).map(_ => 0)) + testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary), + Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount)) + } + + test("column name case sensitivity") { Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -518,18 +549,41 @@ abstract class ParquetAggregatePushDownSuite } } +abstract class ParquetAggregatePushDownSuite + extends FileSourceAggregatePushDownSuite with ParquetTest { + + override def format: String = "parquet" + override protected val aggPushDownEnabledKey: String = + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key +} + class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") } class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +} + +abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite { + + override def format: String = "orc" + override protected val aggPushDownEnabledKey: String = + SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key +} + +class OrcV1AggregatePushDownSuite extends 
OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc") +} + +class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +}
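
For reference, a minimal usage sketch of the behavior this suite exercises (not part of the patch). It assumes a local SparkSession and a writable /tmp path, enables the new ORC setting via SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED, and runs MIN/MAX/COUNT so the pushed aggregation can be inspected in the explain output, mirroring what the tests assert with checkKeywordsExistsInExplain:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    object OrcAggPushDownSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical standalone driver; the suite itself uses SharedSparkSession instead.
        val spark = SparkSession.builder()
          .appName("orc-aggregate-pushdown-sketch")
          .master("local[*]")
          .config(SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key, "true") // config introduced by this patch
          .getOrCreate()

        // Write a small ORC dataset, then aggregate over it.
        spark.range(100).toDF("id").write.mode("overwrite").orc("/tmp/orc_agg_demo")
        spark.read.orc("/tmp/orc_agg_demo").createOrReplaceTempView("t")

        val df = spark.sql("SELECT min(id), max(id), count(id) FROM t")
        // With push down enabled, MIN/MAX/COUNT should be answered from ORC footer
        // statistics, and the pushed aggregation appears in the scan node of the plan.
        df.explain()
        df.show()

        spark.stop()
      }
    }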