This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 111efca381 [core] Support pushing down Decimal filter in parquet format (#7175)
111efca381 is described below
commit 111efca381eccc8c36cc0b07008cc64df1fc3f8e
Author: Kerwin Zhang <[email protected]>
AuthorDate: Mon Feb 2 16:51:31 2026 +0800
[core] Support pushing down Decimal filter in parquet format (#7175)
---
.../parquet/filter2/predicate/ParquetFilters.java | 75 ++++++---
.../paimon/format/parquet/ParquetFiltersTest.java | 179 +++++++++++++++++++++
.../paimon/spark/sql/PaimonPushDownTestBase.scala | 41 +++++
3 files changed, 272 insertions(+), 23 deletions(-)
diff --git
a/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
b/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
index 1c2b2106b9..8643ebb6d3 100644
---
a/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
+++
b/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
@@ -19,6 +19,8 @@
package org.apache.parquet.filter2.predicate;
import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.format.parquet.ParquetSchemaConverter;
import org.apache.paimon.predicate.FieldRef;
import org.apache.paimon.predicate.FunctionVisitor;
import org.apache.paimon.predicate.LeafPredicate;
@@ -95,32 +97,38 @@ public class ParquetFilters {
@Override
public FilterPredicate visitLessThan(FieldRef fieldRef, Object
literal) {
- return new Operators.Lt(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.Lt(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
public FilterPredicate visitGreaterOrEqual(FieldRef fieldRef, Object
literal) {
- return new Operators.GtEq(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.GtEq(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
public FilterPredicate visitNotEqual(FieldRef fieldRef, Object
literal) {
- return new Operators.NotEq(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.NotEq(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
public FilterPredicate visitLessOrEqual(FieldRef fieldRef, Object
literal) {
- return new Operators.LtEq(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.LtEq(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
public FilterPredicate visitEqual(FieldRef fieldRef, Object literal) {
- return new Operators.Eq(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.Eq(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
public FilterPredicate visitGreaterThan(FieldRef fieldRef, Object
literal) {
- return new Operators.Gt(toParquetColumn(fieldRef),
toParquetObject(literal));
+ return new Operators.Gt(
+ toParquetColumn(fieldRef), toParquetObject(literal,
fieldRef.type()));
}
@Override
@@ -164,21 +172,22 @@ public class ParquetFilters {
@Override
public FilterPredicate visitIn(FieldRef fieldRef, List<Object>
literals) {
Operators.Column<?> column = toParquetColumn(fieldRef);
+ org.apache.paimon.types.DataType type = fieldRef.type();
if (column instanceof Operators.LongColumn) {
return FilterApi.in(
- (Operators.LongColumn) column, convertSets(literals,
Long.class));
+ (Operators.LongColumn) column, convertSets(literals,
Long.class, type));
} else if (column instanceof Operators.IntColumn) {
return FilterApi.in(
- (Operators.IntColumn) column, convertSets(literals,
Integer.class));
+ (Operators.IntColumn) column, convertSets(literals,
Integer.class, type));
} else if (column instanceof Operators.DoubleColumn) {
return FilterApi.in(
- (Operators.DoubleColumn) column, convertSets(literals,
Double.class));
+ (Operators.DoubleColumn) column, convertSets(literals,
Double.class, type));
} else if (column instanceof Operators.FloatColumn) {
return FilterApi.in(
- (Operators.FloatColumn) column, convertSets(literals,
Float.class));
+ (Operators.FloatColumn) column, convertSets(literals,
Float.class, type));
} else if (column instanceof Operators.BinaryColumn) {
return FilterApi.in(
- (Operators.BinaryColumn) column, convertSets(literals,
Binary.class));
+ (Operators.BinaryColumn) column, convertSets(literals,
Binary.class, type));
}
throw new UnsupportedOperationException();
@@ -187,21 +196,22 @@ public class ParquetFilters {
@Override
public FilterPredicate visitNotIn(FieldRef fieldRef, List<Object>
literals) {
Operators.Column<?> column = toParquetColumn(fieldRef);
+ org.apache.paimon.types.DataType type = fieldRef.type();
if (column instanceof Operators.LongColumn) {
return FilterApi.notIn(
- (Operators.LongColumn) column, convertSets(literals,
Long.class));
+ (Operators.LongColumn) column, convertSets(literals,
Long.class, type));
} else if (column instanceof Operators.IntColumn) {
return FilterApi.notIn(
- (Operators.IntColumn) column, convertSets(literals,
Integer.class));
+ (Operators.IntColumn) column, convertSets(literals,
Integer.class, type));
} else if (column instanceof Operators.DoubleColumn) {
return FilterApi.notIn(
- (Operators.DoubleColumn) column, convertSets(literals,
Double.class));
+ (Operators.DoubleColumn) column, convertSets(literals,
Double.class, type));
} else if (column instanceof Operators.FloatColumn) {
return FilterApi.notIn(
- (Operators.FloatColumn) column, convertSets(literals,
Float.class));
+ (Operators.FloatColumn) column, convertSets(literals,
Float.class, type));
} else if (column instanceof Operators.BinaryColumn) {
return FilterApi.notIn(
- (Operators.BinaryColumn) column, convertSets(literals,
Binary.class));
+ (Operators.BinaryColumn) column, convertSets(literals,
Binary.class, type));
}
throw new UnsupportedOperationException();
@@ -213,10 +223,11 @@ public class ParquetFilters {
}
}
- private static <T> Set<T> convertSets(List<Object> values, Class<T>
kclass) {
+ private static <T> Set<T> convertSets(
+ List<Object> values, Class<T> kclass,
org.apache.paimon.types.DataType type) {
Set<T> converted = new HashSet<>();
for (Object value : values) {
- Comparable<?> cmp = toParquetObject(value);
+ Comparable<?> cmp = toParquetObject(value, type);
if (kclass.isInstance(cmp)) {
converted.add((T) cmp);
} else {
@@ -230,7 +241,8 @@ public class ParquetFilters {
return fieldRef.type().accept(new
ConvertToColumnTypeVisitor(fieldRef.name()));
}
- private static Comparable<?> toParquetObject(Object value) {
+ private static Comparable<?> toParquetObject(
+ Object value, org.apache.paimon.types.DataType type) {
if (value == null) {
return null;
}
@@ -248,9 +260,19 @@ public class ParquetFilters {
return Binary.fromString(value.toString());
} else if (value instanceof byte[]) {
return Binary.fromReusedByteArray((byte[]) value);
+ } else if (value instanceof Decimal) {
+ Decimal decimal = (Decimal) value;
+ int precision = decimal.precision();
+ if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+ return (int) decimal.toUnscaledLong();
+ } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+ return decimal.toUnscaledLong();
+ } else {
+ return Binary.fromConstantByteArray(decimal.toUnscaledBytes());
+ }
}
- // TODO Support Decimal and Timestamp
+ // TODO Support Timestamp
throw new UnsupportedOperationException();
}
@@ -328,13 +350,20 @@ public class ParquetFilters {
return FilterApi.intColumn(name);
}
- // TODO we can support decimal and timestamp
-
@Override
public Operators.Column<?> visit(DecimalType decimalType) {
- throw new UnsupportedOperationException();
+ int precision = decimalType.getPrecision();
+ if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+ return FilterApi.intColumn(name);
+ } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+ return FilterApi.longColumn(name);
+ } else {
+ return FilterApi.binaryColumn(name);
+ }
}
+ // TODO we can support timestamp
+
@Override
public Operators.Column<?> visit(TimestampType timestampType) {
throw new UnsupportedOperationException();
diff --git
a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
index 402fc782d4..3cd98e1434 100644
---
a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
+++
b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
@@ -18,10 +18,12 @@
package org.apache.paimon.format.parquet;
+import org.apache.paimon.data.Decimal;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.BigIntType;
import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DecimalType;
import org.apache.paimon.types.DoubleType;
import org.apache.paimon.types.FloatType;
import org.apache.paimon.types.RowType;
@@ -35,6 +37,7 @@ import org.apache.parquet.filter2.predicate.ParquetFilters;
import org.apache.parquet.io.api.Binary;
import org.junit.jupiter.api.Test;
+import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.Collectors;
@@ -231,6 +234,182 @@ class ParquetFiltersTest {
true);
}
+ @Test
+ public void testDecimal32Bit() {
+ // precision <= 9 uses INT32
+ int precision = 9;
+ int scale = 2;
+ PredicateBuilder builder =
+ new PredicateBuilder(
+ new RowType(
+ Collections.singletonList(
+ new DataField(
+ 0,
+ "decimal1",
+ new DecimalType(precision,
scale)))));
+
+ Decimal value = Decimal.fromBigDecimal(new BigDecimal("123.45"),
precision, scale);
+ int expectedIntVal = (int) value.toUnscaledLong(); // 12345
+
+ test(builder.isNull(0), "eq(decimal1, null)", true);
+ test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+ test(builder.equal(0, value), "eq(decimal1, " + expectedIntVal + ")",
true);
+ test(builder.notEqual(0, value), "noteq(decimal1, " + expectedIntVal +
")", true);
+ test(builder.lessThan(0, value), "lt(decimal1, " + expectedIntVal +
")", true);
+ test(builder.lessOrEqual(0, value), "lteq(decimal1, " + expectedIntVal
+ ")", true);
+ test(builder.greaterThan(0, value), "gt(decimal1, " + expectedIntVal +
")", true);
+ test(builder.greaterOrEqual(0, value), "gteq(decimal1, " +
expectedIntVal + ")", true);
+ }
+
+ @Test
+ public void testDecimal64Bit() {
+ // 9 < precision <= 18 uses INT64
+ int precision = 18;
+ int scale = 4;
+ PredicateBuilder builder =
+ new PredicateBuilder(
+ new RowType(
+ Collections.singletonList(
+ new DataField(
+ 0,
+ "decimal1",
+ new DecimalType(precision,
scale)))));
+
+ Decimal value =
+ Decimal.fromBigDecimal(new BigDecimal("12345678901234.5678"),
precision, scale);
+ long expectedLongVal = value.toUnscaledLong();
+
+ test(builder.isNull(0), "eq(decimal1, null)", true);
+ test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+ test(builder.equal(0, value), "eq(decimal1, " + expectedLongVal + ")",
true);
+ test(builder.notEqual(0, value), "noteq(decimal1, " + expectedLongVal
+ ")", true);
+ test(builder.lessThan(0, value), "lt(decimal1, " + expectedLongVal +
")", true);
+ test(builder.lessOrEqual(0, value), "lteq(decimal1, " +
expectedLongVal + ")", true);
+ test(builder.greaterThan(0, value), "gt(decimal1, " + expectedLongVal
+ ")", true);
+ test(builder.greaterOrEqual(0, value), "gteq(decimal1, " +
expectedLongVal + ")", true);
+ }
+
+ @Test
+ public void testDecimalBinary() {
+ // precision > 18 uses Binary
+ int precision = 38;
+ int scale = 10;
+ PredicateBuilder builder =
+ new PredicateBuilder(
+ new RowType(
+ Collections.singletonList(
+ new DataField(
+ 0,
+ "decimal1",
+ new DecimalType(precision,
scale)))));
+
+ Decimal value =
+ Decimal.fromBigDecimal(
+ new BigDecimal("12345678901234567890.1234567890"),
precision, scale);
+ Binary expectedBinary =
Binary.fromConstantByteArray(value.toUnscaledBytes());
+
+ test(builder.isNull(0), "eq(decimal1, null)", true);
+ test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+ test(
+ builder.equal(0, value),
+ FilterApi.eq(FilterApi.binaryColumn("decimal1"),
expectedBinary),
+ true);
+ test(
+ builder.notEqual(0, value),
+ FilterApi.notEq(FilterApi.binaryColumn("decimal1"),
expectedBinary),
+ true);
+ test(
+ builder.lessThan(0, value),
+ FilterApi.lt(FilterApi.binaryColumn("decimal1"),
expectedBinary),
+ true);
+ test(
+ builder.greaterThan(0, value),
+ FilterApi.gt(FilterApi.binaryColumn("decimal1"),
expectedBinary),
+ true);
+ }
+
+ @Test
+ public void testInFilterDecimal32Bit() {
+ int precision = 9;
+ int scale = 2;
+ PredicateBuilder builder =
+ new PredicateBuilder(
+ new RowType(
+ Collections.singletonList(
+ new DataField(
+ 0,
+ "decimal1",
+ new DecimalType(precision,
scale)))));
+
+ Decimal v1 = Decimal.fromBigDecimal(new BigDecimal("100.00"),
precision, scale);
+ Decimal v2 = Decimal.fromBigDecimal(new BigDecimal("200.00"),
precision, scale);
+ Decimal v3 = Decimal.fromBigDecimal(new BigDecimal("300.00"),
precision, scale);
+
+ // With fewer than 21 literals, in() expands to or(eq, eq, eq) and notIn() to and(noteq, noteq, noteq)
+ test(
+ builder.in(0, Arrays.asList(v1, v2, v3)),
+ "or(or(eq(decimal1, "
+ + (int) v1.toUnscaledLong()
+ + "), eq(decimal1, "
+ + (int) v2.toUnscaledLong()
+ + ")), eq(decimal1, "
+ + (int) v3.toUnscaledLong()
+ + "))",
+ true);
+
+ test(
+ builder.notIn(0, Arrays.asList(v1, v2, v3)),
+ "and(and(noteq(decimal1, "
+ + (int) v1.toUnscaledLong()
+ + "), noteq(decimal1, "
+ + (int) v2.toUnscaledLong()
+ + ")), noteq(decimal1, "
+ + (int) v3.toUnscaledLong()
+ + "))",
+ true);
+ }
+
+ @Test
+ public void testInFilterDecimal64Bit() {
+ int precision = 18;
+ int scale = 4;
+ PredicateBuilder builder =
+ new PredicateBuilder(
+ new RowType(
+ Collections.singletonList(
+ new DataField(
+ 0,
+ "decimal1",
+ new DecimalType(precision,
scale)))));
+
+ Decimal v1 = Decimal.fromBigDecimal(new
BigDecimal("10000000000.0000"), precision, scale);
+ Decimal v2 = Decimal.fromBigDecimal(new
BigDecimal("20000000000.0000"), precision, scale);
+ Decimal v3 = Decimal.fromBigDecimal(new
BigDecimal("30000000000.0000"), precision, scale);
+
+ // With fewer than 21 literals, in() expands to or(eq, eq, eq) and notIn() to and(noteq, noteq, noteq)
+ test(
+ builder.in(0, Arrays.asList(v1, v2, v3)),
+ "or(or(eq(decimal1, "
+ + v1.toUnscaledLong()
+ + "), eq(decimal1, "
+ + v2.toUnscaledLong()
+ + ")), eq(decimal1, "
+ + v3.toUnscaledLong()
+ + "))",
+ true);
+
+ test(
+ builder.notIn(0, Arrays.asList(v1, v2, v3)),
+ "and(and(noteq(decimal1, "
+ + v1.toUnscaledLong()
+ + "), noteq(decimal1, "
+ + v2.toUnscaledLong()
+ + ")), noteq(decimal1, "
+ + v3.toUnscaledLong()
+ + "))",
+ true);
+ }
+
private void test(Predicate predicate, FilterPredicate parquetPredicate,
boolean canPushDown) {
FilterCompat.Filter filter =
ParquetFilters.convert(PredicateBuilder.splitAnd(predicate));
if (canPushDown) {
diff --git
a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
index 1731231a48..747b140cb3 100644
---
a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
+++
b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
@@ -766,6 +766,47 @@ abstract class PaimonPushDownTestBase extends
PaimonSparkTestBase with AdaptiveS
}
}
+ test(s"Paimon pushdown: parquet decimal filter") {
+ withTable("T") {
+ spark.sql(s"""
+ |CREATE TABLE T (a DECIMAL(18,4), b STRING) using paimon
TBLPROPERTIES
+ |(
+ |'file.format' = 'parquet',
+ |'parquet.block.size' = '100',
+ |'target-file-size' = '10g'
+ |)
+ |""".stripMargin)
+
+ spark.sql("""INSERT INTO T VALUES
+ |(CAST(100.1234 AS DECIMAL(18,4)), 'a'),
+ |(CAST(200.5678 AS DECIMAL(18,4)), 'b'),
+ |(CAST(300.9999 AS DECIMAL(18,4)), 'c'),
+ |(CAST(150.0000 AS DECIMAL(18,4)), 'd')
+ |""".stripMargin)
+
+ // Test equals filter
+ checkAnswer(
+ spark.sql("SELECT * FROM T WHERE a = CAST(100.1234 AS DECIMAL(18,4))"),
+ Row(new java.math.BigDecimal("100.1234"), "a") :: Nil
+ )
+
+ // Test comparison filter
+ checkAnswer(
+ spark.sql("SELECT * FROM T WHERE a < CAST(200.0000 AS DECIMAL(18,4))
ORDER BY a"),
+ Row(new java.math.BigDecimal("100.1234"), "a") ::
+ Row(new java.math.BigDecimal("150.0000"), "d") :: Nil
+ )
+
+ // Test in filter
+ checkAnswer(
+ spark.sql(
+ "SELECT * FROM T WHERE a IN (CAST(100.1234 AS DECIMAL(18,4)),
CAST(300.9999 AS DECIMAL(18,4))) ORDER BY a"),
+ Row(new java.math.BigDecimal("100.1234"), "a") ::
+ Row(new java.math.BigDecimal("300.9999"), "c") :: Nil
+ )
+ }
+ }
+
private def getScanBuilder(tableName: String = "T"): ScanBuilder = {
SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
}