This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new a5dc3ef83b [spark] support to push down min/max aggregation (#5270)
a5dc3ef83b is described below
commit a5dc3ef83b01f6276360f18e842dd9c0d2749804
Author: Yann Byron <[email protected]>
AuthorDate: Fri Mar 21 13:23:41 2025 +0800
[spark] support to push down min/max aggregation (#5270)
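With this change, the Spark connector can answer MIN/MAX aggregations (in addition to
COUNT(*)) directly from Paimon file-level statistics, as long as the group-by keys are
partition columns, the aggregated columns have simple comparable types (boolean,
integral, float/double, date), and the required column statistics exist in every data
file. A rough sketch of the kind of query that can now be served from statistics
(table and column names below are illustrative only, not part of this patch):

    // assumes an existing Paimon append table `t` with partition column `day`
    // and an INT column `c1`; both queries can be answered from file statistics
    spark.sql("SELECT COUNT(*), MIN(c1), MAX(c1) FROM t")
    spark.sql("SELECT day, MIN(c1), MAX(c1) FROM t GROUP BY day")

When statistics are missing (for example 'metadata.stats-mode' = 'none') or a
non-partition filter remains after pushdown, the connector falls back to a regular
scan-based aggregation.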
---
.../apache/paimon/stats/SimpleStatsEvolution.java | 24 ++++
.../org/apache/paimon/table/source/DataSplit.java | 44 ++++++
.../org/apache/paimon/table/source/SplitTest.java | 110 ++++++++++++++-
.../apache/paimon/spark/PaimonScanBuilder.scala | 40 +++---
.../paimon/spark/aggregate/AggFuncEvaluator.scala | 96 +++++++++++++
.../spark/aggregate/AggregatePushDownUtils.scala | 124 +++++++++++++++++
.../paimon/spark/aggregate/LocalAggregator.scala | 93 +++++--------
.../paimon/spark/sql/PushDownAggregatesTest.scala | 150 +++++++++++++++++----
8 files changed, 572 insertions(+), 109 deletions(-)
diff --git a/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java b/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
index fb029eccdb..b1c7cfebee 100644
--- a/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
+++ b/paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java
@@ -64,6 +64,30 @@ public class SimpleStatsEvolution {
this.emptyNullCounts = new GenericArray(new Object[fieldNames.size()]);
}
+ public InternalRow evolution(InternalRow row, @Nullable List<String> denseFields) {
+ InternalRow result = row;
+
+ if (denseFields != null && denseFields.isEmpty()) {
+ result = emptyValues;
+ } else if (denseFields != null) {
+ int[] denseIndexMapping =
+ indexMappings.computeIfAbsent(
+ denseFields,
+ k -> fieldNames.stream().mapToInt(denseFields::indexOf).toArray());
+ result = ProjectedRow.from(denseIndexMapping).replaceRow(result);
+ }
+
+ if (indexMapping != null) {
+ result = ProjectedRow.from(indexMapping).replaceRow(result);
+ }
+
+ if (castFieldGetters != null) {
+ result = CastedRow.from(castFieldGetters).replaceRow(result);
+ }
+
+ return result;
+ }
+
public Result evolution(
SimpleStats stats, @Nullable Long rowCount, @Nullable List<String> denseFields) {
InternalRow minValues = stats.minValues();
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
index 39f9269f41..5e39d3a71b 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
@@ -19,6 +19,7 @@
package org.apache.paimon.table.source;
import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.InternalRow;
import org.apache.paimon.io.DataFileMeta;
import org.apache.paimon.io.DataFileMeta08Serializer;
import org.apache.paimon.io.DataFileMeta09Serializer;
@@ -28,7 +29,12 @@ import org.apache.paimon.io.DataInputView;
import org.apache.paimon.io.DataInputViewStreamWrapper;
import org.apache.paimon.io.DataOutputView;
import org.apache.paimon.io.DataOutputViewStreamWrapper;
+import org.apache.paimon.predicate.CompareUtils;
+import org.apache.paimon.stats.SimpleStatsEvolution;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.DataField;
import org.apache.paimon.utils.FunctionWithIOException;
+import org.apache.paimon.utils.InternalRowUtils;
import org.apache.paimon.utils.SerializationUtils;
import javax.annotation.Nullable;
@@ -141,6 +147,44 @@ public class DataSplit implements Split {
return partialMergedRowCount();
}
+ public Object minValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+ Object minValue = null;
+ for (DataFileMeta dataFile : dataFiles) {
+ SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+ InternalRow minValues =
+ evolution.evolution(
+ dataFile.valueStats().minValues(), dataFile.valueStatsCols());
+ Object other = InternalRowUtils.get(minValues, fieldIndex, dataField.type());
+ if (minValue == null) {
+ minValue = other;
+ } else if (other != null) {
+ if (CompareUtils.compareLiteral(dataField.type(), minValue, other) > 0) {
+ minValue = other;
+ }
+ }
+ }
+ return minValue;
+ }
+
+ public Object maxValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+ Object maxValue = null;
+ for (DataFileMeta dataFile : dataFiles) {
+ SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+ InternalRow maxValues =
+ evolution.evolution(
+ dataFile.valueStats().maxValues(), dataFile.valueStatsCols());
+ Object other = InternalRowUtils.get(maxValues, fieldIndex, dataField.type());
+ if (maxValue == null) {
+ maxValue = other;
+ } else if (other != null) {
+ if (CompareUtils.compareLiteral(dataField.type(), maxValue, other) < 0) {
+ maxValue = other;
+ }
+ }
+ }
+ return maxValue;
+ }
+
/**
* Obtain merged row count as much as possible. There are two scenarios where accurate row count
* can be calculated:
diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
index a088f40dab..a87a645711 100644
--- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
@@ -28,18 +28,30 @@ import org.apache.paimon.io.DataInputDeserializer;
import org.apache.paimon.io.DataOutputViewStreamWrapper;
import org.apache.paimon.manifest.FileSource;
import org.apache.paimon.stats.SimpleStats;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.BigIntType;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DoubleType;
+import org.apache.paimon.types.FloatType;
+import org.apache.paimon.types.IntType;
+import org.apache.paimon.types.SmallIntType;
+import org.apache.paimon.types.TimestampType;
import org.apache.paimon.utils.IOUtils;
import org.apache.paimon.utils.InstantiationUtil;
import org.junit.jupiter.api.Test;
+import javax.annotation.Nullable;
+
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import static org.apache.paimon.data.BinaryArray.fromLongArray;
@@ -84,6 +96,70 @@ public class SplitTest {
assertThat(split.mergedRowCount()).isEqualTo(5700L);
}
+ @Test
+ public void testSplitMinMaxValue() {
+ Map<Long, List<DataField>> schemas = new HashMap<>();
+
+ Timestamp minTs = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-01T00:00:00"));
+ Timestamp maxTs1 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-01T00:00:00"));
+ Timestamp maxTs2 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-12T00:00:00"));
+ BinaryRow min1 = newBinaryRow(new Object[] {10, 123L, 888.0D, minTs});
+ BinaryRow max1 = newBinaryRow(new Object[] {99, 456L, 999.0D, maxTs1});
+ SimpleStats valueStats1 = new SimpleStats(min1, max1, fromLongArray(new Long[] {0L}));
+
+ BinaryRow min2 = newBinaryRow(new Object[] {5, 0L, 777.0D, minTs});
+ BinaryRow max2 = newBinaryRow(new Object[] {90, 789L, 899.0D, maxTs2});
+ SimpleStats valueStats2 = new SimpleStats(min2, max2, fromLongArray(new Long[] {0L}));
+
+ // test the common case.
+ DataFileMeta d1 = newDataFile(100, valueStats1, null);
+ DataFileMeta d2 = newDataFile(100, valueStats2, null);
+ DataSplit split1 = newDataSplit(true, Arrays.asList(d1, d2), null);
+
+ DataField intField = new DataField(0, "c_int", new IntType());
+ DataField longField = new DataField(1, "c_long", new BigIntType());
+ DataField doubleField = new DataField(2, "c_double", new DoubleType());
+ DataField tsField = new DataField(3, "c_ts", new TimestampType());
+ schemas.put(1L, Arrays.asList(intField, longField, doubleField, tsField));
+
+ SimpleStatsEvolutions evolutions = new SimpleStatsEvolutions(schemas::get, 1);
+ assertThat(split1.minValue(0, intField, evolutions)).isEqualTo(5);
+ assertThat(split1.maxValue(0, intField, evolutions)).isEqualTo(99);
+ assertThat(split1.minValue(1, longField, evolutions)).isEqualTo(0L);
+ assertThat(split1.maxValue(1, longField, evolutions)).isEqualTo(789L);
+ assertThat(split1.minValue(2, doubleField, evolutions)).isEqualTo(777D);
+ assertThat(split1.maxValue(2, doubleField, evolutions)).isEqualTo(999D);
+ assertThat(split1.minValue(3, tsField, evolutions)).isEqualTo(minTs);
+ assertThat(split1.maxValue(3, tsField, evolutions)).isEqualTo(maxTs2);
+
+ // test the case where valueStatsCols is non-null and the file schema
+ // differs from the table schema.
+ BinaryRow min3 = newBinaryRow(new Object[] {10, 123L, minTs});
+ BinaryRow max3 = newBinaryRow(new Object[] {99, 456L, maxTs1});
+ SimpleStats valueStats3 = new SimpleStats(min3, max3, fromLongArray(new Long[] {0L}));
+ BinaryRow min4 = newBinaryRow(new Object[] {5, 0L, minTs});
+ BinaryRow max4 = newBinaryRow(new Object[] {90, 789L, maxTs2});
+ SimpleStats valueStats4 = new SimpleStats(min4, max4, fromLongArray(new Long[] {0L}));
+ List<String> valueStatsCols2 = Arrays.asList("c_int", "c_long", "c_ts");
+ DataFileMeta d3 = newDataFile(100, valueStats3, valueStatsCols2);
+ DataFileMeta d4 = newDataFile(100, valueStats4, valueStatsCols2);
+ DataSplit split2 = newDataSplit(true, Arrays.asList(d3, d4), null);
+
+ DataField smallField = new DataField(4, "c_small", new SmallIntType());
+ DataField floatField = new DataField(5, "c_float", new FloatType());
+ schemas.put(2L, Arrays.asList(intField, smallField, tsField, floatField));
+
+ evolutions = new SimpleStatsEvolutions(schemas::get, 2);
+ assertThat(split2.minValue(0, intField, evolutions)).isEqualTo(5);
+ assertThat(split2.maxValue(0, intField, evolutions)).isEqualTo(99);
+ assertThat(split2.minValue(1, smallField, evolutions)).isEqualTo(null);
+ assertThat(split2.maxValue(1, smallField, evolutions)).isEqualTo(null);
+ assertThat(split2.minValue(2, tsField, evolutions)).isEqualTo(minTs);
+ assertThat(split2.maxValue(2, tsField, evolutions)).isEqualTo(maxTs2);
+ assertThat(split2.minValue(3, floatField, evolutions)).isEqualTo(null);
+ assertThat(split2.maxValue(3, floatField, evolutions)).isEqualTo(null);
+ }
+
@Test
public void testSerializer() throws IOException {
DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -436,18 +512,23 @@ public class SplitTest {
}
private DataFileMeta newDataFile(long rowCount) {
+ return newDataFile(rowCount, null, null);
+ }
+
+ private DataFileMeta newDataFile(
+ long rowCount, SimpleStats rowStats, @Nullable List<String> valueStatsCols) {
return DataFileMeta.forAppend(
"my_data_file.parquet",
1024 * 1024,
rowCount,
- null,
+ rowStats,
0L,
- rowCount,
+ rowCount - 1,
1,
Collections.emptyList(),
null,
null,
- null,
+ valueStatsCols,
null);
}
@@ -467,4 +548,27 @@ public class SplitTest {
}
return builder.build();
}
+
+ private BinaryRow newBinaryRow(Object[] objs) {
+ BinaryRow row = new BinaryRow(objs.length);
+ BinaryRowWriter writer = new BinaryRowWriter(row);
+ writer.reset();
+ for (int i = 0; i < objs.length; i++) {
+ if (objs[i] instanceof Integer) {
+ writer.writeInt(i, (Integer) objs[i]);
+ } else if (objs[i] instanceof Long) {
+ writer.writeLong(i, (Long) objs[i]);
+ } else if (objs[i] instanceof Float) {
+ writer.writeFloat(i, (Float) objs[i]);
+ } else if (objs[i] instanceof Double) {
+ writer.writeDouble(i, (Double) objs[i]);
+ } else if (objs[i] instanceof Timestamp) {
+ writer.writeTimestamp(i, (Timestamp) objs[i], 5);
+ } else {
+ throw new UnsupportedOperationException("It's not supported.");
+ }
+ }
+ writer.complete();
+ return row;
+ }
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index c9afa07021..5fe1737c0d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -19,8 +19,8 @@
package org.apache.paimon.spark
import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder}
-import org.apache.paimon.spark.aggregate.LocalAggregator
-import org.apache.paimon.table.Table
+import org.apache.paimon.spark.aggregate.{AggregatePushDownUtils, LocalAggregator}
+import org.apache.paimon.table.{FileStoreTable, Table}
import org.apache.paimon.table.source.DataSplit
import org.apache.spark.sql.PaimonUtils
@@ -101,13 +101,12 @@ class PaimonScanBuilder(table: Table)
return true
}
- // Only support when there is no post scan predicates.
- if (hasPostScanPredicates) {
+ if (!table.isInstanceOf[FileStoreTable]) {
return false
}
- val aggregator = new LocalAggregator(table)
- if (!aggregator.pushAggregation(aggregation)) {
+ // Only support when there is no post scan predicates.
+ if (hasPostScanPredicates) {
return false
}
@@ -116,19 +115,26 @@ class PaimonScanBuilder(table: Table)
val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava)
readBuilder.withFilter(pushedPartitionPredicate)
}
- val dataSplits =
+ val dataSplits = if (AggregatePushDownUtils.hasMinMaxAggregation(aggregation)) {
+ readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
+ } else {
readBuilder.dropStats().newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
- if (!dataSplits.forall(_.mergedRowCountAvailable())) {
- return false
}
- dataSplits.foreach(aggregator.update)
- localScan = Some(
- PaimonLocalScan(
- aggregator.result(),
- aggregator.resultSchema(),
- table,
- pushedPaimonPredicates))
- true
+ if (AggregatePushDownUtils.canPushdownAggregation(table, aggregation, dataSplits.toSeq)) {
+ val aggregator = new LocalAggregator(table.asInstanceOf[FileStoreTable])
+ aggregator.initialize(aggregation)
+ dataSplits.foreach(aggregator.update)
+ localScan = Some(
+ PaimonLocalScan(
+ aggregator.result(),
+ aggregator.resultSchema(),
+ table,
+ pushedPaimonPredicates)
+ )
+ true
+ } else {
+ false
+ }
}
override def build(): Scan = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala
new file mode 100644
index 0000000000..fcb64e3064
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggFuncEvaluator.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.aggregate
+
+import org.apache.paimon.data.BinaryString
+import org.apache.paimon.predicate.CompareUtils
+import org.apache.paimon.spark.SparkTypeUtils
+import org.apache.paimon.stats.SimpleStatsEvolutions
+import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.types.DataField
+
+import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.unsafe.types.UTF8String
+
+trait AggFuncEvaluator[T] {
+ def update(dataSplit: DataSplit): Unit
+
+ def result(): T
+
+ def resultType: DataType
+
+ def prettyName: String
+}
+
+class CountStarEvaluator extends AggFuncEvaluator[Long] {
+ private var _result: Long = 0L
+
+ override def update(dataSplit: DataSplit): Unit = {
+ _result += dataSplit.mergedRowCount()
+ }
+
+ override def result(): Long = _result
+
+ override def resultType: DataType = LongType
+
+ override def prettyName: String = "count_star"
+}
+
+case class MinEvaluator(idx: Int, dataField: DataField, evolutions: SimpleStatsEvolutions)
+ extends AggFuncEvaluator[Any] {
+ private var _result: Any = _
+
+ override def update(dataSplit: DataSplit): Unit = {
+ val other = dataSplit.minValue(idx, dataField, evolutions)
+ if (_result == null || CompareUtils.compareLiteral(dataField.`type`(), _result, other) > 0) {
+ _result = other
+ }
+ }
+
+ override def result(): Any = _result match {
+ case s: BinaryString => UTF8String.fromString(s.toString)
+ case a => a
+ }
+
+ override def resultType: DataType = SparkTypeUtils.fromPaimonType(dataField.`type`())
+
+ override def prettyName: String = "min"
+}
+
+case class MaxEvaluator(idx: Int, dataField: DataField, evolutions: SimpleStatsEvolutions)
+ extends AggFuncEvaluator[Any] {
+ private var _result: Any = _
+
+ override def update(dataSplit: DataSplit): Unit = {
+ val other = dataSplit.maxValue(idx, dataField, evolutions)
+ if (_result == null || CompareUtils.compareLiteral(dataField.`type`(), _result, other) < 0) {
+ _result = other
+ }
+ }
+
+ override def result(): Any = _result match {
+ case s: BinaryString => UTF8String.fromString(s.toString)
+ case a => a
+ }
+
+ override def resultType: DataType = SparkTypeUtils.fromPaimonType(dataField.`type`())
+
+ override def prettyName: String = "max"
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
new file mode 100644
index 0000000000..c6abec1acd
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/AggregatePushDownUtils.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.aggregate
+
+import org.apache.paimon.table.Table
+import org.apache.paimon.table.source.DataSplit
+import org.apache.paimon.types._
+
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+object AggregatePushDownUtils {
+
+ def canPushdownAggregation(
+ table: Table,
+ aggregation: Aggregation,
+ dataSplits: Seq[DataSplit]): Boolean = {
+
+ var hasMinMax = false
+ val minmaxColumns = mutable.HashSet.empty[String]
+ var hasCount = false
+
+ def getDataFieldForCol(colName: String): DataField = {
+ table.rowType.getField(colName)
+ }
+
+ def isPartitionCol(colName: String) = {
+ table.partitionKeys.contains(colName)
+ }
+
+ def processMinOrMax(agg: AggregateFunc): Boolean = {
+ val columnName = agg match {
+ case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+ V2ColumnUtils.extractV2Column(max.column).get
+ case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+ V2ColumnUtils.extractV2Column(min.column).get
+ case _ => return false
+ }
+
+ val dataField = getDataFieldForCol(columnName)
+
+ dataField.`type`() match {
+ // not push down complex type
+ // not push down Timestamp because INT96 sort order is undefined,
+ // Parquet doesn't return statistics for INT96
+ // not push down Parquet Binary because min/max could be truncated
+ // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary
+ // could be Spark StringType, BinaryType or DecimalType.
+ // not push down for ORC with same reason.
+ case _: BooleanType | _: TinyIntType | _: SmallIntType | _: IntType | _: BigIntType |
+ _: FloatType | _: DoubleType | _: DateType =>
+ minmaxColumns.add(columnName)
+ hasMinMax = true
+ true
+ case _ =>
+ false
+ }
+ }
+
+ aggregation.groupByExpressions.map(V2ColumnUtils.extractV2Column).foreach {
+ colName =>
+ // don't push down if the group-by columns are not the same as the partition columns
+ // (order doesn't matter because reordering can be done at the data source layer)
+ if (colName.isEmpty || !isPartitionCol(colName.get)) return false
+ }
+
+ aggregation.aggregateExpressions.foreach {
+ case max: Max =>
+ if (!processMinOrMax(max)) return false
+ case min: Min =>
+ if (!processMinOrMax(min)) return false
+ case _: CountStar =>
+ hasCount = true
+ case _ =>
+ return false
+ }
+
+ if (hasMinMax) {
+ dataSplits.forall {
+ dataSplit =>
+ dataSplit.dataFiles().asScala.forall {
+ dataFile =>
+ // A null valueStatsCols means statistics are available for all columns
+ dataFile.valueStatsCols() == null ||
+ minmaxColumns.forall(dataFile.valueStatsCols().contains)
+ }
+ }
+ } else if (hasCount) {
+ dataSplits.forall(_.mergedRowCountAvailable())
+ } else {
+ true
+ }
+ }
+
+ def hasMinMaxAggregation(aggregation: Aggregation): Boolean = {
+ var hasMinMax = false
+ aggregation.aggregateExpressions().foreach {
+ case _: Min | _: Max =>
+ hasMinMax = true
+ case _ =>
+ }
+ hasMinMax
+ }
+
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
index 8988e7218d..bb88aa669e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
@@ -18,34 +18,59 @@
package org.apache.paimon.spark.aggregate
+import org.apache.paimon.CoreOptions
import org.apache.paimon.data.BinaryRow
+import org.apache.paimon.schema.SchemaManager
import org.apache.paimon.spark.SparkTypeUtils
import org.apache.paimon.spark.data.SparkInternalRow
-import org.apache.paimon.table.{DataTable, Table}
+import org.apache.paimon.stats.SimpleStatsEvolutions
+import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.source.DataSplit
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar}
-import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
import scala.collection.mutable
-class LocalAggregator(table: Table) {
+class LocalAggregator(table: FileStoreTable) {
+ private val rowType = table.rowType()
private val partitionType = SparkTypeUtils.toPartitionType(table)
private val groupByEvaluatorMap = new mutable.HashMap[InternalRow, Seq[AggFuncEvaluator[_]]]()
private var requiredGroupByType: Seq[DataType] = _
private var requiredGroupByIndexMapping: Seq[Int] = _
private var aggFuncEvaluatorGetter: () => Seq[AggFuncEvaluator[_]] = _
private var isInitialized = false
+ private lazy val simpleStatsEvolutions = {
+ val schemaManager = new SchemaManager(
+ table.fileIO(),
+ table.location(),
+ CoreOptions.branch(table.schema().options()))
+ new SimpleStatsEvolutions(sid => schemaManager.schema(sid).fields(), table.schema().id())
+ }
- private def initialize(aggregation: Aggregation): Unit = {
+ def initialize(aggregation: Aggregation): Unit = {
aggFuncEvaluatorGetter = () =>
aggregation.aggregateExpressions().map {
case _: CountStar => new CountStarEvaluator()
- case _ => throw new UnsupportedOperationException()
+ case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+ val fieldName = V2ColumnUtils.extractV2Column(min.column).get
+ MinEvaluator(
+ rowType.getFieldIndex(fieldName),
+ rowType.getField(fieldName),
+ simpleStatsEvolutions)
+ case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+ val fieldName = V2ColumnUtils.extractV2Column(max.column).get
+ MaxEvaluator(
+ rowType.getFieldIndex(fieldName),
+ rowType.getField(fieldName),
+ simpleStatsEvolutions)
+ case _ =>
+ throw new UnsupportedOperationException()
}
requiredGroupByType = aggregation.groupByExpressions().map {
@@ -61,39 +86,6 @@ class LocalAggregator(table: Table) {
isInitialized = true
}
- private def supportAggregateFunction(func: AggregateFunc): Boolean = {
- func match {
- case _: CountStar => true
- case _ => false
- }
- }
-
- private def supportGroupByExpressions(exprs: Array[Expression]): Boolean = {
- // Support empty group by keys or group by partition column
- exprs.forall {
- case r: NamedReference =>
- r.fieldNames.length == 1 && table.partitionKeys().contains(r.fieldNames().head)
- case _ => false
- }
- }
-
- def pushAggregation(aggregation: Aggregation): Boolean = {
- if (!table.isInstanceOf[DataTable]) {
- return false
- }
-
- if (
- !supportGroupByExpressions(aggregation.groupByExpressions()) ||
- aggregation.aggregateExpressions().isEmpty ||
- aggregation.aggregateExpressions().exists(!supportAggregateFunction(_))
- ) {
- return false
- }
-
- initialize(aggregation)
- true
- }
-
private def requiredGroupByRow(partitionRow: BinaryRow): InternalRow = {
val projectedRow =
ProjectedRow.from(requiredGroupByIndexMapping.toArray).replaceRow(partitionRow)
@@ -139,24 +131,3 @@ class LocalAggregator(table: Table) {
StructType.apply(groupByFields ++ aggResultFields)
}
}
-
-trait AggFuncEvaluator[T] {
- def update(dataSplit: DataSplit): Unit
- def result(): T
- def resultType: DataType
- def prettyName: String
-}
-
-class CountStarEvaluator extends AggFuncEvaluator[Long] {
- private var _result: Long = 0L
-
- override def update(dataSplit: DataSplit): Unit = {
- _result += dataSplit.mergedRowCount()
- }
-
- override def result(): Long = _result
-
- override def resultType: DataType = LongType
-
- override def prettyName: String = "count_star"
-}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
index 78c02644a7..26c19ecc27 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import java.sql.Date
+
class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanHelper {
private def runAndCheckAggregate(
@@ -49,70 +51,162 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
}
}
- test("Push down aggregate - append table") {
+ test("Push down aggregate - append table without partitions") {
withTable("T") {
- spark.sql("CREATE TABLE T (c1 INT, c2 STRING) PARTITIONED BY(day
STRING)")
+ spark.sql("CREATE TABLE T (c1 INT, c2 STRING, c3 DOUBLE, c4 DATE)")
runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 0)
+ runAndCheckAggregate(
+ "SELECT COUNT(*), MIN(c1), MIN(c3), MIN(c4) FROM T",
+ Row(0, null, null, null) :: Nil,
+ 0)
+ runAndCheckAggregate(
+ "SELECT COUNT(*), MAX(c1), MAX(c3), MAX(c4) FROM T",
+ Row(0, null, null, null) :: Nil,
+ 0)
+ // count(c1) and min/max for string are not supported.
+ runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(0) :: Nil, 2)
+ runAndCheckAggregate("SELECT MIN(c2) FROM T", Row(null) :: Nil, 2)
+ runAndCheckAggregate("SELECT MAX(c2) FROM T", Row(null) :: Nil, 2)
+
// This query does not contain aggregate due to AQE optimize it to empty relation.
runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY c1", Nil, 0)
- runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(0) :: Nil, 2)
runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Row(0, 0) :: Nil, 2)
- runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Row(0, 1) :: Nil, 0)
- runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Row(0) :: Nil, 0)
- runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Row(0) :: Nil, 2)
- runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1", Row(0) :: Nil, 2)
+ runAndCheckAggregate(
+ "SELECT COUNT(*) + 1, MIN(c1) * 10, MAX(c3) + 1.0 FROM T",
+ Row(1, null, null) :: Nil,
+ 0)
+ runAndCheckAggregate(
+ "SELECT COUNT(*) as cnt, MIN(c4) as min_c4 FROM T",
+ Row(0, null) :: Nil,
+ 0)
+ // The cases with common data filters are not supported.
+ runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1 = 1", Row(0) ::
Nil, 2)
spark.sql(
- "INSERT INTO T VALUES(1, 'x', 'a'), (2, 'x', 'a'), (3, 'x', 'b'), (3,
'x', 'c'), (null, 'x', 'a')")
+ s"""
+ |INSERT INTO T VALUES (1, 'xyz', 11.1, TO_DATE('2025-01-01',
'yyyy-MM-dd')),
+ |(2, null, null, TO_DATE('2025-01-01', 'yyyy-MM-dd')), (3, 'abc',
33.3, null),
+ |(3, 'abc', null, TO_DATE('2025-03-01', 'yyyy-MM-dd')), (null,
'abc', 44.4, TO_DATE('2025-03-01', 'yyyy-MM-dd'))
+ |""".stripMargin)
+ val date1 = Date.valueOf("2025-01-01")
+ val date2 = Date.valueOf("2025-03-01")
runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(5) :: Nil, 0)
runAndCheckAggregate(
"SELECT COUNT(*) FROM T GROUP BY c1",
Row(1) :: Row(1) :: Row(1) :: Row(2) :: Nil,
2)
runAndCheckAggregate("SELECT COUNT(c1) FROM T", Row(4) :: Nil, 2)
+
+ runAndCheckAggregate("SELECT COUNT(*), MIN(c1), MAX(c1) FROM T", Row(5,
1, 3) :: Nil, 0)
+ runAndCheckAggregate(
+ "SELECT COUNT(*), MIN(c2), MAX(c2) FROM T",
+ Row(5, "abc", "xyz") :: Nil,
+ 2)
+ runAndCheckAggregate("SELECT COUNT(*), MIN(c3), MAX(c3) FROM T", Row(5,
11.1, 44.4) :: Nil, 0)
+ runAndCheckAggregate(
+ "SELECT COUNT(*), MIN(c4), MAX(c4) FROM T",
+ Row(5, date1, date2) :: Nil,
+ 0)
runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Row(5, 4) ::
Nil, 2)
runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Row(5, 6)
:: Nil, 0)
- runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Row(3)
:: Nil, 0)
- runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Row(1) :: Nil,
2)
- runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1",
Row(1) :: Nil, 2)
+ runAndCheckAggregate(
+ "SELECT COUNT(*) + 1, MIN(c1) * 10, MAX(c3) + 1.0 FROM T",
+ Row(6, 10, 45.4) :: Nil,
+ 0)
+ runAndCheckAggregate(
+ "SELECT MIN(c3) as min, MAX(c4) as max FROM T",
+ Row(11.1, date2) :: Nil,
+ 0)
+ runAndCheckAggregate("SELECT COUNT(*), MIN(c3) FROM T WHERE c1 = 3",
Row(2, 33.3) :: Nil, 2)
}
}
- test("Push down aggregate - group by partition column") {
+ test("Push down aggregate - append table with partitions") {
withTable("T") {
- spark.sql("CREATE TABLE T (c1 INT) PARTITIONED BY(day STRING, hour INT)")
+ spark.sql("CREATE TABLE T (c1 INT, c2 LONG) PARTITIONED BY(day STRING,
hour INT)")
runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Nil, 0)
- runAndCheckAggregate("SELECT day, COUNT(*) as c FROM T GROUP BY day,
hour", Nil, 0)
- runAndCheckAggregate("SELECT day, COUNT(*), hour FROM T GROUP BY day,
hour", Nil, 0)
runAndCheckAggregate(
- "SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
+ "SELECT day, hour, COUNT(*), MIN(c1), MIN(c1) FROM T GROUP BY day,
hour",
+ Nil,
+ 0)
+ runAndCheckAggregate(
+ "SELECT day, hour, COUNT(*), MIN(c2), MIN(c2) FROM T GROUP BY day,
hour",
+ Nil,
+ 0)
+ runAndCheckAggregate(
+ "SELECT day, COUNT(*), hour FROM T WHERE day= '2025-01-01' GROUP BY
day, hour",
Nil,
0)
// This query does not contain aggregate due to AQE optimize it to empty relation.
runAndCheckAggregate("SELECT day, COUNT(*) FROM T GROUP BY c1, day", Nil, 0)
spark.sql(
- "INSERT INTO T VALUES(1, 'x', 1), (2, 'x', 1), (3, 'x', 2), (3, 'x',
3), (null, 'y', null)")
+ """
+ |INSERT INTO T VALUES(1, 100L, '2025-01-01', 1), (2, null,
'2025-01-01', 1),
+ |(3, 300L, '2025-03-01', 3), (3, 330L, '2025-03-01', 3), (null,
400L, '2025-03-01', null)
+ |""".stripMargin)
- runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Row(1) ::
Row(4) :: Nil, 0)
+ runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Row(2) ::
Row(3) :: Nil, 0)
runAndCheckAggregate(
- "SELECT day, COUNT(*) as c FROM T GROUP BY day, hour",
- Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
+ "SELECT day, hour, COUNT(*) as c FROM T GROUP BY day, hour",
+ Row("2025-01-01", 1, 2) :: Row("2025-03-01", 3, 2) ::
Row("2025-03-01", null, 1) :: Nil,
0)
runAndCheckAggregate(
- "SELECT day, COUNT(*), hour FROM T GROUP BY day, hour",
- Row("x", 1, 2) :: Row("y", 1, null) :: Row("x", 2, 1) :: Row("x", 1,
3) :: Nil,
- 0)
+ "SELECT day, COUNT(*), hour, MIN(c1), MAX(c1) FROM T GROUP BY day,
hour",
+ Row("2025-01-01", 2, 1, 1, 2) :: Row("2025-03-01", 2, 3, 3, 3) :: Row(
+ "2025-03-01",
+ 1,
+ null,
+ null,
+ null) :: Nil,
+ 0
+ )
runAndCheckAggregate(
- "SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
- Row("x", 1, 2) :: Row("x", 1, 3) :: Row("x", 2, 1) :: Nil,
- 0)
+ "SELECT hour, COUNT(*), MIN(c2) as min, MAX(c2) as max FROM T WHERE
day='2025-03-01' GROUP BY day, hour",
+ Row(3, 2, 300L, 330L) :: Row(null, 1, 400L, 400L) :: Nil,
+ 0
+ )
+ runAndCheckAggregate(
+ "SELECT c1, day, COUNT(*) FROM T GROUP BY c1, day ORDER BY c1, day",
+ Row(null, "2025-03-01", 1) :: Row(1, "2025-01-01", 1) :: Row(2,
"2025-01-01", 1) :: Row(
+ 3,
+ "2025-03-01",
+ 2) :: Nil,
+ 2
+ )
+ }
+ }
+
+ test("Push down aggregate - append table with dense statistics") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (c1 INT, c2 STRING, c3 DOUBLE, c4 DATE)
+ |TBLPROPERTIES('metadata.stats-mode' = 'none')
+ |""".stripMargin)
+ spark.sql(
+ s"""
+ |INSERT INTO T VALUES (1, 'xyz', 11.1, TO_DATE('2025-01-01', 'yyyy-MM-dd')),
+ |(2, null, null, TO_DATE('2025-01-01', 'yyyy-MM-dd')), (3, 'abc', 33.3, null),
+ |(3, 'abc', null, TO_DATE('2025-03-01', 'yyyy-MM-dd')), (null, 'abc', 44.4, TO_DATE('2025-03-01', 'yyyy-MM-dd'))
+ |""".stripMargin)
+
+ val date1 = Date.valueOf("2025-01-01")
+ val date2 = Date.valueOf("2025-03-01")
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(5) :: Nil, 0)
+
+ // for metadata.stats-mode = none, no statistics are available, so min/max cannot be pushed down.
+ runAndCheckAggregate("SELECT COUNT(*), MIN(c1), MAX(c1) FROM T", Row(5,
1, 3) :: Nil, 2)
+ runAndCheckAggregate(
+ "SELECT COUNT(*), MIN(c2), MAX(c2) FROM T",
+ Row(5, "abc", "xyz") :: Nil,
+ 2)
+ runAndCheckAggregate("SELECT COUNT(*), MIN(c3), MAX(c3) FROM T", Row(5,
11.1, 44.4) :: Nil, 2)
runAndCheckAggregate(
- "SELECT day, COUNT(*) FROM T GROUP BY c1, day",
- Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
+ "SELECT COUNT(*), MIN(c4), MAX(c4) FROM T",
+ Row(5, date1, date2) :: Nil,
2)
}
}