This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 111efca381 [core] Support pushing down Decimal filter in parquet 
format (#7175)
111efca381 is described below

commit 111efca381eccc8c36cc0b07008cc64df1fc3f8e
Author: Kerwin Zhang <[email protected]>
AuthorDate: Mon Feb 2 16:51:31 2026 +0800

    [core] Support pushing down Decimal filter in parquet format (#7175)
---
 .../parquet/filter2/predicate/ParquetFilters.java  |  75 ++++++---
 .../paimon/format/parquet/ParquetFiltersTest.java  | 179 +++++++++++++++++++++
 .../paimon/spark/sql/PaimonPushDownTestBase.scala  |  41 +++++
 3 files changed, 272 insertions(+), 23 deletions(-)

diff --git a/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java b/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
index 1c2b2106b9..8643ebb6d3 100644
--- a/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
+++ b/paimon-format/src/main/java/org/apache/parquet/filter2/predicate/ParquetFilters.java
@@ -19,6 +19,8 @@
 package org.apache.parquet.filter2.predicate;
 
 import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.format.parquet.ParquetSchemaConverter;
 import org.apache.paimon.predicate.FieldRef;
 import org.apache.paimon.predicate.FunctionVisitor;
 import org.apache.paimon.predicate.LeafPredicate;
@@ -95,32 +97,38 @@ public class ParquetFilters {
 
         @Override
         public FilterPredicate visitLessThan(FieldRef fieldRef, Object 
literal) {
-            return new Operators.Lt(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.Lt(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
         public FilterPredicate visitGreaterOrEqual(FieldRef fieldRef, Object 
literal) {
-            return new Operators.GtEq(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.GtEq(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
         public FilterPredicate visitNotEqual(FieldRef fieldRef, Object 
literal) {
-            return new Operators.NotEq(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.NotEq(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
         public FilterPredicate visitLessOrEqual(FieldRef fieldRef, Object 
literal) {
-            return new Operators.LtEq(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.LtEq(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
         public FilterPredicate visitEqual(FieldRef fieldRef, Object literal) {
-            return new Operators.Eq(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.Eq(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
         public FilterPredicate visitGreaterThan(FieldRef fieldRef, Object 
literal) {
-            return new Operators.Gt(toParquetColumn(fieldRef), 
toParquetObject(literal));
+            return new Operators.Gt(
+                    toParquetColumn(fieldRef), toParquetObject(literal, 
fieldRef.type()));
         }
 
         @Override
@@ -164,21 +172,22 @@ public class ParquetFilters {
         @Override
         public FilterPredicate visitIn(FieldRef fieldRef, List<Object> 
literals) {
             Operators.Column<?> column = toParquetColumn(fieldRef);
+            org.apache.paimon.types.DataType type = fieldRef.type();
             if (column instanceof Operators.LongColumn) {
                 return FilterApi.in(
-                        (Operators.LongColumn) column, convertSets(literals, 
Long.class));
+                        (Operators.LongColumn) column, convertSets(literals, 
Long.class, type));
             } else if (column instanceof Operators.IntColumn) {
                 return FilterApi.in(
-                        (Operators.IntColumn) column, convertSets(literals, 
Integer.class));
+                        (Operators.IntColumn) column, convertSets(literals, 
Integer.class, type));
             } else if (column instanceof Operators.DoubleColumn) {
                 return FilterApi.in(
-                        (Operators.DoubleColumn) column, convertSets(literals, 
Double.class));
+                        (Operators.DoubleColumn) column, convertSets(literals, 
Double.class, type));
             } else if (column instanceof Operators.FloatColumn) {
                 return FilterApi.in(
-                        (Operators.FloatColumn) column, convertSets(literals, 
Float.class));
+                        (Operators.FloatColumn) column, convertSets(literals, 
Float.class, type));
             } else if (column instanceof Operators.BinaryColumn) {
                 return FilterApi.in(
-                        (Operators.BinaryColumn) column, convertSets(literals, 
Binary.class));
+                        (Operators.BinaryColumn) column, convertSets(literals, 
Binary.class, type));
             }
 
             throw new UnsupportedOperationException();
@@ -187,21 +196,22 @@ public class ParquetFilters {
         @Override
         public FilterPredicate visitNotIn(FieldRef fieldRef, List<Object> 
literals) {
             Operators.Column<?> column = toParquetColumn(fieldRef);
+            org.apache.paimon.types.DataType type = fieldRef.type();
             if (column instanceof Operators.LongColumn) {
                 return FilterApi.notIn(
-                        (Operators.LongColumn) column, convertSets(literals, 
Long.class));
+                        (Operators.LongColumn) column, convertSets(literals, 
Long.class, type));
             } else if (column instanceof Operators.IntColumn) {
                 return FilterApi.notIn(
-                        (Operators.IntColumn) column, convertSets(literals, 
Integer.class));
+                        (Operators.IntColumn) column, convertSets(literals, 
Integer.class, type));
             } else if (column instanceof Operators.DoubleColumn) {
                 return FilterApi.notIn(
-                        (Operators.DoubleColumn) column, convertSets(literals, 
Double.class));
+                        (Operators.DoubleColumn) column, convertSets(literals, 
Double.class, type));
             } else if (column instanceof Operators.FloatColumn) {
                 return FilterApi.notIn(
-                        (Operators.FloatColumn) column, convertSets(literals, 
Float.class));
+                        (Operators.FloatColumn) column, convertSets(literals, 
Float.class, type));
             } else if (column instanceof Operators.BinaryColumn) {
                 return FilterApi.notIn(
-                        (Operators.BinaryColumn) column, convertSets(literals, 
Binary.class));
+                        (Operators.BinaryColumn) column, convertSets(literals, 
Binary.class, type));
             }
 
             throw new UnsupportedOperationException();
@@ -213,10 +223,11 @@ public class ParquetFilters {
         }
     }
 
-    private static <T> Set<T> convertSets(List<Object> values, Class<T> 
kclass) {
+    private static <T> Set<T> convertSets(
+            List<Object> values, Class<T> kclass, 
org.apache.paimon.types.DataType type) {
         Set<T> converted = new HashSet<>();
         for (Object value : values) {
-            Comparable<?> cmp = toParquetObject(value);
+            Comparable<?> cmp = toParquetObject(value, type);
             if (kclass.isInstance(cmp)) {
                 converted.add((T) cmp);
             } else {
@@ -230,7 +241,8 @@ public class ParquetFilters {
         return fieldRef.type().accept(new 
ConvertToColumnTypeVisitor(fieldRef.name()));
     }
 
-    private static Comparable<?> toParquetObject(Object value) {
+    private static Comparable<?> toParquetObject(
+            Object value, org.apache.paimon.types.DataType type) {
         if (value == null) {
             return null;
         }
@@ -248,9 +260,19 @@ public class ParquetFilters {
             return Binary.fromString(value.toString());
         } else if (value instanceof byte[]) {
             return Binary.fromReusedByteArray((byte[]) value);
+        } else if (value instanceof Decimal) {
+            Decimal decimal = (Decimal) value;
+            int precision = decimal.precision();
+            if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+                return (int) decimal.toUnscaledLong();
+            } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+                return decimal.toUnscaledLong();
+            } else {
+                return Binary.fromConstantByteArray(decimal.toUnscaledBytes());
+            }
         }
 
-        // TODO Support Decimal and Timestamp
+        // TODO Support Timestamp
         throw new UnsupportedOperationException();
     }
 
@@ -328,13 +350,20 @@ public class ParquetFilters {
             return FilterApi.intColumn(name);
         }
 
-        // TODO we can support decimal and timestamp
-
         @Override
         public Operators.Column<?> visit(DecimalType decimalType) {
-            throw new UnsupportedOperationException();
+            int precision = decimalType.getPrecision();
+            if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+                return FilterApi.intColumn(name);
+            } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+                return FilterApi.longColumn(name);
+            } else {
+                return FilterApi.binaryColumn(name);
+            }
         }
 
+        // TODO we can support timestamp
+
         @Override
         public Operators.Column<?> visit(TimestampType timestampType) {
             throw new UnsupportedOperationException();
diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
index 402fc782d4..3cd98e1434 100644
--- a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
+++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetFiltersTest.java
@@ -18,10 +18,12 @@
 
 package org.apache.paimon.format.parquet;
 
+import org.apache.paimon.data.Decimal;
 import org.apache.paimon.predicate.Predicate;
 import org.apache.paimon.predicate.PredicateBuilder;
 import org.apache.paimon.types.BigIntType;
 import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DecimalType;
 import org.apache.paimon.types.DoubleType;
 import org.apache.paimon.types.FloatType;
 import org.apache.paimon.types.RowType;
@@ -35,6 +37,7 @@ import org.apache.parquet.filter2.predicate.ParquetFilters;
 import org.apache.parquet.io.api.Binary;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.stream.Collectors;
@@ -231,6 +234,182 @@ class ParquetFiltersTest {
                 true);
     }
 
+    @Test
+    public void testDecimal32Bit() {
+        // precision <= 9 uses INT32
+        int precision = 9;
+        int scale = 2;
+        PredicateBuilder builder =
+                new PredicateBuilder(
+                        new RowType(
+                                Collections.singletonList(
+                                        new DataField(
+                                                0,
+                                                "decimal1",
+                                                new DecimalType(precision, 
scale)))));
+
+        Decimal value = Decimal.fromBigDecimal(new BigDecimal("123.45"), 
precision, scale);
+        int expectedIntVal = (int) value.toUnscaledLong(); // 12345
+
+        test(builder.isNull(0), "eq(decimal1, null)", true);
+        test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+        test(builder.equal(0, value), "eq(decimal1, " + expectedIntVal + ")", 
true);
+        test(builder.notEqual(0, value), "noteq(decimal1, " + expectedIntVal + 
")", true);
+        test(builder.lessThan(0, value), "lt(decimal1, " + expectedIntVal + 
")", true);
+        test(builder.lessOrEqual(0, value), "lteq(decimal1, " + expectedIntVal 
+ ")", true);
+        test(builder.greaterThan(0, value), "gt(decimal1, " + expectedIntVal + 
")", true);
+        test(builder.greaterOrEqual(0, value), "gteq(decimal1, " + 
expectedIntVal + ")", true);
+    }
+
+    @Test
+    public void testDecimal64Bit() {
+        // 9 < precision <= 18 uses INT64
+        int precision = 18;
+        int scale = 4;
+        PredicateBuilder builder =
+                new PredicateBuilder(
+                        new RowType(
+                                Collections.singletonList(
+                                        new DataField(
+                                                0,
+                                                "decimal1",
+                                                new DecimalType(precision, 
scale)))));
+
+        Decimal value =
+                Decimal.fromBigDecimal(new BigDecimal("12345678901234.5678"), 
precision, scale);
+        long expectedLongVal = value.toUnscaledLong();
+
+        test(builder.isNull(0), "eq(decimal1, null)", true);
+        test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+        test(builder.equal(0, value), "eq(decimal1, " + expectedLongVal + ")", 
true);
+        test(builder.notEqual(0, value), "noteq(decimal1, " + expectedLongVal 
+ ")", true);
+        test(builder.lessThan(0, value), "lt(decimal1, " + expectedLongVal + 
")", true);
+        test(builder.lessOrEqual(0, value), "lteq(decimal1, " + 
expectedLongVal + ")", true);
+        test(builder.greaterThan(0, value), "gt(decimal1, " + expectedLongVal 
+ ")", true);
+        test(builder.greaterOrEqual(0, value), "gteq(decimal1, " + 
expectedLongVal + ")", true);
+    }
+
+    @Test
+    public void testDecimalBinary() {
+        // precision > 18 uses Binary
+        int precision = 38;
+        int scale = 10;
+        PredicateBuilder builder =
+                new PredicateBuilder(
+                        new RowType(
+                                Collections.singletonList(
+                                        new DataField(
+                                                0,
+                                                "decimal1",
+                                                new DecimalType(precision, 
scale)))));
+
+        Decimal value =
+                Decimal.fromBigDecimal(
+                        new BigDecimal("12345678901234567890.1234567890"), 
precision, scale);
+        Binary expectedBinary = 
Binary.fromConstantByteArray(value.toUnscaledBytes());
+
+        test(builder.isNull(0), "eq(decimal1, null)", true);
+        test(builder.isNotNull(0), "noteq(decimal1, null)", true);
+        test(
+                builder.equal(0, value),
+                FilterApi.eq(FilterApi.binaryColumn("decimal1"), 
expectedBinary),
+                true);
+        test(
+                builder.notEqual(0, value),
+                FilterApi.notEq(FilterApi.binaryColumn("decimal1"), 
expectedBinary),
+                true);
+        test(
+                builder.lessThan(0, value),
+                FilterApi.lt(FilterApi.binaryColumn("decimal1"), 
expectedBinary),
+                true);
+        test(
+                builder.greaterThan(0, value),
+                FilterApi.gt(FilterApi.binaryColumn("decimal1"), 
expectedBinary),
+                true);
+    }
+
+    @Test
+    public void testInFilterDecimal32Bit() {
+        int precision = 9;
+        int scale = 2;
+        PredicateBuilder builder =
+                new PredicateBuilder(
+                        new RowType(
+                                Collections.singletonList(
+                                        new DataField(
+                                                0,
+                                                "decimal1",
+                                                new DecimalType(precision, 
scale)))));
+
+        Decimal v1 = Decimal.fromBigDecimal(new BigDecimal("100.00"), 
precision, scale);
+        Decimal v2 = Decimal.fromBigDecimal(new BigDecimal("200.00"), 
precision, scale);
+        Decimal v3 = Decimal.fromBigDecimal(new BigDecimal("300.00"), 
precision, scale);
+
+        // For less than 21 elements, it expands to or(eq, eq, eq)
+        test(
+                builder.in(0, Arrays.asList(v1, v2, v3)),
+                "or(or(eq(decimal1, "
+                        + (int) v1.toUnscaledLong()
+                        + "), eq(decimal1, "
+                        + (int) v2.toUnscaledLong()
+                        + ")), eq(decimal1, "
+                        + (int) v3.toUnscaledLong()
+                        + "))",
+                true);
+
+        test(
+                builder.notIn(0, Arrays.asList(v1, v2, v3)),
+                "and(and(noteq(decimal1, "
+                        + (int) v1.toUnscaledLong()
+                        + "), noteq(decimal1, "
+                        + (int) v2.toUnscaledLong()
+                        + ")), noteq(decimal1, "
+                        + (int) v3.toUnscaledLong()
+                        + "))",
+                true);
+    }
+
+    @Test
+    public void testInFilterDecimal64Bit() {
+        int precision = 18;
+        int scale = 4;
+        PredicateBuilder builder =
+                new PredicateBuilder(
+                        new RowType(
+                                Collections.singletonList(
+                                        new DataField(
+                                                0,
+                                                "decimal1",
+                                                new DecimalType(precision, 
scale)))));
+
+        Decimal v1 = Decimal.fromBigDecimal(new 
BigDecimal("10000000000.0000"), precision, scale);
+        Decimal v2 = Decimal.fromBigDecimal(new 
BigDecimal("20000000000.0000"), precision, scale);
+        Decimal v3 = Decimal.fromBigDecimal(new 
BigDecimal("30000000000.0000"), precision, scale);
+
+        // For less than 21 elements, it expands to or(eq, eq, eq)
+        test(
+                builder.in(0, Arrays.asList(v1, v2, v3)),
+                "or(or(eq(decimal1, "
+                        + v1.toUnscaledLong()
+                        + "), eq(decimal1, "
+                        + v2.toUnscaledLong()
+                        + ")), eq(decimal1, "
+                        + v3.toUnscaledLong()
+                        + "))",
+                true);
+
+        test(
+                builder.notIn(0, Arrays.asList(v1, v2, v3)),
+                "and(and(noteq(decimal1, "
+                        + v1.toUnscaledLong()
+                        + "), noteq(decimal1, "
+                        + v2.toUnscaledLong()
+                        + ")), noteq(decimal1, "
+                        + v3.toUnscaledLong()
+                        + "))",
+                true);
+    }
+
     private void test(Predicate predicate, FilterPredicate parquetPredicate, 
boolean canPushDown) {
         FilterCompat.Filter filter = 
ParquetFilters.convert(PredicateBuilder.splitAnd(predicate));
         if (canPushDown) {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
index 1731231a48..747b140cb3 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
@@ -766,6 +766,47 @@ abstract class PaimonPushDownTestBase extends 
PaimonSparkTestBase with AdaptiveS
     }
   }
 
+  test(s"Paimon pushdown: parquet decimal filter") {
+    withTable("T") {
+      spark.sql(s"""
+                   |CREATE TABLE T (a DECIMAL(18,4), b STRING) using paimon 
TBLPROPERTIES
+                   |(
+                   |'file.format' = 'parquet',
+                   |'parquet.block.size' = '100',
+                   |'target-file-size' = '10g'
+                   |)
+                   |""".stripMargin)
+
+      spark.sql("""INSERT INTO T VALUES
+                  |(CAST(100.1234 AS DECIMAL(18,4)), 'a'),
+                  |(CAST(200.5678 AS DECIMAL(18,4)), 'b'),
+                  |(CAST(300.9999 AS DECIMAL(18,4)), 'c'),
+                  |(CAST(150.0000 AS DECIMAL(18,4)), 'd')
+                  |""".stripMargin)
+
+      // Test equals filter
+      checkAnswer(
+        spark.sql("SELECT * FROM T WHERE a = CAST(100.1234 AS DECIMAL(18,4))"),
+        Row(new java.math.BigDecimal("100.1234"), "a") :: Nil
+      )
+
+      // Test comparison filter
+      checkAnswer(
+        spark.sql("SELECT * FROM T WHERE a < CAST(200.0000 AS DECIMAL(18,4)) 
ORDER BY a"),
+        Row(new java.math.BigDecimal("100.1234"), "a") ::
+          Row(new java.math.BigDecimal("150.0000"), "d") :: Nil
+      )
+
+      // Test in filter
+      checkAnswer(
+        spark.sql(
+          "SELECT * FROM T WHERE a IN (CAST(100.1234 AS DECIMAL(18,4)), 
CAST(300.9999 AS DECIMAL(18,4))) ORDER BY a"),
+        Row(new java.math.BigDecimal("100.1234"), "a") ::
+          Row(new java.math.BigDecimal("300.9999"), "c") :: Nil
+      )
+    }
+  }
+
   private def getScanBuilder(tableName: String = "T"): ScanBuilder = {
     
SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
   }

Reply via email to