This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new d9749048e3 Spark: Fix aggregate pushdown (#15070)
d9749048e3 is described below
commit d9749048e38eb947883c5e65e05b12733d35b7d9
Author: Vrishabh <[email protected]>
AuthorDate: Tue Mar 17 05:49:40 2026 +0530
Spark: Fix aggregate pushdown (#15070)
* Fix aggregate pushdown
* Address review comment
* Add additional test case
* Address review comment.
* Minor comment updates
---
.../apache/iceberg/expressions/BoundAggregate.java | 5 +++
.../apache/iceberg/expressions/MaxAggregate.java | 4 ++
.../apache/iceberg/expressions/MinAggregate.java | 4 ++
.../iceberg/spark/sql/TestAggregatePushDown.java | 45 ++++++++++++++++++++++
4 files changed, 58 insertions(+)
diff --git
a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
index 03c371df25..72f611df20 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
@@ -51,6 +51,11 @@ public class BoundAggregate<T, C> extends
Aggregate<BoundTerm<T>> implements Bou
this.getClass().getName() + " does not implement newAggregator()");
}
+ boolean containsNan(DataFile file, int fieldId) {
+ Long nanCount = safeGet(file.nanValueCounts(), fieldId);
+ return nanCount != null && nanCount > 0;
+ }
+
@Override
public BoundReference<?> ref() {
return term().ref();
diff --git a/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java
b/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java
index d37af7470d..2948ffa421 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/MaxAggregate.java
@@ -40,6 +40,10 @@ public class MaxAggregate<T> extends ValueAggregate<T> {
@Override
protected boolean hasValue(DataFile file) {
+ // Can't determine max from metadata when NaN values are present since it
could be -NaN or +NaN
+ if (containsNan(file, fieldId)) {
+ return false;
+ }
boolean hasBound = safeContainsKey(file.upperBounds(), fieldId);
Long valueCount = safeGet(file.valueCounts(), fieldId);
Long nullCount = safeGet(file.nullValueCounts(), fieldId);
diff --git a/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java
b/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java
index 667b66d650..cf13f92562 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/MinAggregate.java
@@ -40,6 +40,10 @@ public class MinAggregate<T> extends ValueAggregate<T> {
@Override
protected boolean hasValue(DataFile file) {
+ // Can't determine min from metadata when NaN values are present since it
could be -NaN or +NaN
+ if (containsNan(file, fieldId)) {
+ return false;
+ }
boolean hasBound = safeContainsKey(file.lowerBounds(), fieldId);
Long valueCount = safeGet(file.valueCounts(), fieldId);
Long nullCount = safeGet(file.nullValueCounts(), fieldId);
diff --git
a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index ce0a0f26a0..4baaf2d1fb 100644
---
a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++
b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -766,6 +766,51 @@ public class TestAggregatePushDown extends CatalogTestBase
{
assertEquals("expected and actual should equal", expected, actual);
}
+ @TestTemplate
+ public void testNanWithLowerAndUpperBoundMetrics() {
+ sql("CREATE TABLE %s (id int, data float) USING iceberg PARTITIONED BY
(id)", tableName);
+ sql(
+ "INSERT INTO %s VALUES (1, float('nan')),"
+ + "(1, float('nan')), "
+ + "(1, 10.0), "
+ + "(2, 2), "
+ + "(2, float('nan')), "
+ + "(3, float('nan')), "
+ + "(3, 1)",
+ tableName);
+
+ // Validate all files have an upper bound, a lower bound and a nan count
+ String countsQuery =
+ "select readable_metrics.data.nan_value_count > 0, "
+ + "isnull(readable_metrics.data.lower_bound), "
+ + "isnull(readable_metrics.data.upper_bound) "
+ + "from %s.files";
+
+ Object[] expectedResult = new Object[] {true, false, false};
+ assertThat(sql(countsQuery, tableName))
+ .as("Data files should contain nan count, lower bound and upper
bound.")
+ .allMatch(row -> Arrays.equals(row, expectedResult));
+
+ // Check aggregates are not pushed down
+ String select = "SELECT count(*), max(data), min(data), count(data) FROM
%s";
+
+ List<Object[]> explain = sql("EXPLAIN " + select, tableName);
+ String explainString =
explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+ boolean explainContainsPushDownAggregates =
+ (explainString.contains("max(data)")
+ || explainString.contains("min(data)")
+ || explainString.contains("count(data)"));
+
+ assertThat(explainContainsPushDownAggregates)
+ .as("explain should not contain the pushed down aggregates")
+ .isFalse();
+
+ List<Object[]> actual = sql(select, tableName);
+ List<Object[]> expected = Lists.newArrayList();
+ expected.add(new Object[] {7L, Float.NaN, 1.0F, 7L});
+ assertEquals("expected and actual should equal", expected, actual);
+ }
+
@TestTemplate
public void testInfinity() {
sql(