This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new ba21d6bf40 [core][flink][spark] Fix column pruning with row filter (#7150)
ba21d6bf40 is described below

commit ba21d6bf4066508f91def1b3d64d39878cd0d038
Author: Jiajia Li <[email protected]>
AuthorDate: Thu Jan 29 16:30:37 2026 +0800

    [core][flink][spark] Fix column pruning with row filter (#7150)
---
 .../table/format/predicate/PredicateUtils.java     |   6 +-
 .../paimon/flink/source/FlinkSourceBuilder.java    |  54 ++++++-
 .../org/apache/paimon/flink/RESTCatalogITCase.java | 168 +++++++++++++++++++--
 .../paimon/spark/PaimonBaseScanBuilder.scala       |  26 +++-
 .../paimon/spark/SparkCatalogWithRestTest.java     | 133 +++++++++++++++-
 5 files changed, 366 insertions(+), 21 deletions(-)

diff --git a/paimon-core/src/main/java/org/apache/paimon/table/format/predicate/PredicateUtils.java b/paimon-core/src/main/java/org/apache/paimon/table/format/predicate/PredicateUtils.java
index 93ad6b9010..aac792f196 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/format/predicate/PredicateUtils.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/format/predicate/PredicateUtils.java
@@ -50,7 +50,7 @@ public class PredicateUtils {
 
         for (Predicate sub : predicates) {
             // Collect all field names referenced by this predicate
-            Set<String> referencedFields = sub.visit(new FieldNameCollector());
+            Set<String> referencedFields = collectFieldNames(sub);
             Optional<Predicate> transformed = transformFieldMapping(sub, fieldMap);
             if (transformed.isPresent() && referencedFields.size() == 1) {
                 Predicate child = transformed.get();
@@ -70,6 +70,10 @@ public class PredicateUtils {
         return result;
     }
 
+    public static Set<String> collectFieldNames(Predicate predicate) {
+        return predicate.visit(new FieldNameCollector());
+    }
+
     /** A visitor that collects all field names referenced by a predicate. */
     private static class FieldNameCollector implements PredicateVisitor<Set<String>> {
 
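
Note on the core change: the patch exposes the existing FieldNameCollector visitor
through a public collectFieldNames helper so that the Flink and Spark scan paths
(changed below) can discover which columns a row-filter predicate references. A
minimal usage sketch, assuming the predicate factories that appear in the tests
further down (LeafPredicate.of, FieldTransform, FieldRef):

    // Sketch only, not part of the patch: which columns does a predicate touch?
    Predicate ageGe28 =
            LeafPredicate.of(
                    new FieldTransform(new FieldRef(2, "age", DataTypes.INT())),
                    GreaterOrEqual.INSTANCE,
                    Collections.singletonList(28));
    Set<String> referenced = PredicateUtils.collectFieldNames(ageGe28);
    // referenced == {"age"}: a projection of (id, name) must be widened with
    // "age" before the row filter can be evaluated.
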
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java
index 3e96dec1ea..6497f28184 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkSourceBuilder.java
@@ -20,6 +20,7 @@ package org.apache.paimon.flink.source;
 
 import org.apache.paimon.CoreOptions;
 import org.apache.paimon.annotation.VisibleForTesting;
+import org.apache.paimon.catalog.TableQueryAuthResult;
 import org.apache.paimon.flink.FlinkConnectorOptions;
 import org.apache.paimon.flink.NestedProjectedRowData;
 import org.apache.paimon.flink.Projection;
@@ -33,7 +34,10 @@ import org.apache.paimon.predicate.Predicate;
 import org.apache.paimon.table.BucketMode;
 import org.apache.paimon.table.FileStoreTable;
 import org.apache.paimon.table.Table;
+import org.apache.paimon.table.format.predicate.PredicateUtils;
 import org.apache.paimon.table.source.ReadBuilder;
+import org.apache.paimon.table.source.TableQueryAuth;
+import org.apache.paimon.types.DataField;
 import org.apache.paimon.utils.StringUtils;
 
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
@@ -57,8 +61,11 @@ import org.apache.flink.types.Row;
 
 import javax.annotation.Nullable;
 
+import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
 import static org.apache.paimon.flink.FlinkConnectorOptions.SOURCE_OPERATOR_UID_SUFFIX;
@@ -189,7 +196,7 @@ public class FlinkSourceBuilder {
     private ReadBuilder createReadBuilder(@Nullable org.apache.paimon.types.RowType readType) {
         ReadBuilder readBuilder = table.newReadBuilder();
         if (readType != null) {
-            readBuilder.withReadType(readType);
+            readBuilder.withReadType(readTypeWithAuth(readType));
         }
         if (predicate != null) {
             readBuilder.withFilter(predicate);
@@ -390,4 +397,49 @@ public class FlinkSourceBuilder {
                 "The align mode of paimon source currently only supports 
EXACTLY_ONCE checkpoint mode. Please set "
                         + "execution.checkpointing.mode to exactly-once");
     }
+
+    private org.apache.paimon.types.RowType readTypeWithAuth(
+            org.apache.paimon.types.RowType readType) {
+        if (!(table instanceof FileStoreTable)) {
+            return readType;
+        }
+
+        FileStoreTable fileStoreTable = (FileStoreTable) table;
+        TableQueryAuth auth =
+                fileStoreTable.catalogEnvironment().tableQueryAuth(new CoreOptions(conf.toMap()));
+        if (auth == null) {
+            return readType;
+        }
+
+        List<String> requiredFieldNames = readType.getFieldNames();
+        TableQueryAuthResult result = auth.auth(requiredFieldNames);
+        if (result == null) {
+            return readType;
+        }
+
+        Predicate authPredicate = result.extractPredicate();
+        if (authPredicate == null) {
+            return readType;
+        }
+
+        Set<String> authFieldNames = PredicateUtils.collectFieldNames(authPredicate);
+        Set<String> requiredFieldNameSet = new HashSet<>(requiredFieldNames);
+
+        List<DataField> additionalFields = new ArrayList<>();
+        for (DataField field : table.rowType().getFields()) {
+            if (authFieldNames.contains(field.name())
+                    && !requiredFieldNameSet.contains(field.name())) {
+                additionalFields.add(field);
+            }
+        }
+
+        if (additionalFields.isEmpty()) {
+            return readType;
+        }
+
+        // Create new read type with additional fields
+        List<DataField> newFields = new ArrayList<>(readType.getFields());
+        newFields.addAll(additionalFields);
+        return new org.apache.paimon.types.RowType(newFields);
+    }
 }
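
In effect, readTypeWithAuth widens the projected row type with every column that
the catalog's row-filter predicate references but the query did not select; this
is the actual fix for the pruning bug on the Flink side. A minimal sketch of the
widening step in isolation (field ids and types are illustrative; the RowType and
DataField constructors are the ones used in the hunk above):

    // Sketch only: projection (id, name) plus an auth predicate on "age".
    List<DataField> fields = new ArrayList<>();
    fields.add(new DataField(0, "id", DataTypes.INT()));
    fields.add(new DataField(1, "name", DataTypes.STRING()));
    // "age" is pulled from table.rowType() because the row filter references it:
    fields.add(new DataField(2, "age", DataTypes.INT()));
    RowType widened = new RowType(fields); // the reader now fetches (id, name, age)
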
diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java
index 6a0bed06a9..de93679e1a 100644
--- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java
+++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/RESTCatalogITCase.java
@@ -500,6 +500,90 @@ class RESTCatalogITCase extends RESTCatalogITCaseBase {
         assertThat(joinResult.get(0)).isEqualTo(Row.of(2, "Bob", 30, 60000.0));
         assertThat(joinResult.get(1)).isEqualTo(Row.of(3, "Charlie", 35, 70000.0));
 
+        // Test column pruning with row filter
+        Predicate ageGe28Predicate =
+                LeafPredicate.of(
+                        new FieldTransform(new FieldRef(2, "age", DataTypes.INT())),
+                        GreaterOrEqual.INSTANCE,
+                        Collections.singletonList(28));
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create(DATABASE_NAME, filterTable),
+                Collections.singletonList(ageGe28Predicate));
+
+        // Query only id and name, but age column should be automatically included for filtering
+        List<Row> pruneResult =
+                batchSql(
+                        String.format(
+                                "SELECT id, name FROM %s.%s ORDER BY id",
+                                DATABASE_NAME, filterTable));
+        assertThat(pruneResult.size()).isEqualTo(3);
+        assertThat(pruneResult.get(0).getField(0)).isEqualTo(2);
+        assertThat(pruneResult.get(0).getField(1)).isEqualTo("Bob");
+        assertThat(pruneResult.get(1).getField(0)).isEqualTo(3);
+        assertThat(pruneResult.get(1).getField(1)).isEqualTo("Charlie");
+        assertThat(pruneResult.get(2).getField(0)).isEqualTo(4);
+        assertThat(pruneResult.get(2).getField(1)).isEqualTo("David");
+
+        // Test with complex AND predicate - query only id
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create(DATABASE_NAME, filterTable),
+                Collections.singletonList(combinedPredicate));
+
+        pruneResult =
+                batchSql(
+                        String.format(
+                                "SELECT id FROM %s.%s ORDER BY id", DATABASE_NAME, filterTable));
+        assertThat(pruneResult.size()).isEqualTo(1);
+        assertThat(pruneResult.get(0).getField(0)).isEqualTo(3);
+
+        // Test aggregate functions with row filter
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create(DATABASE_NAME, filterTable),
+                Collections.singletonList(ageGe30Predicate));
+
+        // Test COUNT(*) with row filter
+        assertThat(
+                        batchSql(
+                                String.format(
+                                        "SELECT COUNT(*) FROM %s.%s", DATABASE_NAME, filterTable)))
+                .containsExactlyInAnyOrder(Row.of(2L));
+
+        // Test COUNT(column) with row filter
+        assertThat(
+                        batchSql(
+                                String.format(
+                                        "SELECT COUNT(name) FROM %s.%s",
+                                        DATABASE_NAME, filterTable)))
+                .containsExactlyInAnyOrder(Row.of(2L));
+
+        // Test GROUP BY with row filter
+        List<Row> groupByResult =
+                batchSql(
+                        String.format(
+                                "SELECT department, COUNT(*) FROM %s.%s GROUP BY department ORDER BY department",
+                                DATABASE_NAME, filterTable));
+        assertThat(groupByResult.size()).isEqualTo(2);
+        assertThat(groupByResult.get(0).getField(0)).isEqualTo("HR");
+        assertThat(groupByResult.get(0).getField(1)).isEqualTo(1L);
+        assertThat(groupByResult.get(1).getField(0)).isEqualTo("IT");
+        assertThat(groupByResult.get(1).getField(1)).isEqualTo(1L);
+
+        // Test HAVING clause with row filter
+        List<Row> havingResult =
+                batchSql(
+                        String.format(
+                                "SELECT department, COUNT(*) as cnt FROM %s.%s GROUP BY department HAVING COUNT(*) >= 1 ORDER BY department",
+                                DATABASE_NAME, filterTable));
+        assertThat(havingResult.size()).isEqualTo(2);
+
+        // Test COUNT DISTINCT with row filter
+        assertThat(
+                        batchSql(
+                                String.format(
+                                        "SELECT COUNT(DISTINCT department) FROM %s.%s",
+                                        DATABASE_NAME, filterTable)))
+                .containsExactlyInAnyOrder(Row.of(2L));
+
         // Clear row filter and verify original data
         restCatalogServer.setRowFilterAuth(Identifier.create(DATABASE_NAME, filterTable), null);
 
@@ -551,16 +635,6 @@ class RESTCatalogITCase extends RESTCatalogITCaseBase {
         assertThat(combinedResult.get(0).getField(3)).isEqualTo(25); // age not masked
         assertThat(combinedResult.get(0).getField(4)).isEqualTo("IT"); // department not masked
 
-        // Test must read with row filter columns
-        assertThatThrownBy(
-                        () ->
-                                batchSql(
-                                        String.format(
-                                                "SELECT id, name FROM %s.%s WHERE age > 30 ORDER BY id",
-                                                DATABASE_NAME, combinedTable)))
-                .rootCause()
-                .hasMessageContaining("Unable to read data without column department");
-
         // Test WHERE clause with both features
         assertThat(
                         batchSql(
@@ -569,6 +643,80 @@ class RESTCatalogITCase extends RESTCatalogITCaseBase {
                                         DATABASE_NAME, combinedTable)))
                 .containsExactlyInAnyOrder(Row.of(3, "***", "IT"));
 
+        // Test column pruning with both column masking and row filter
+        Predicate ageGe30Predicate =
+                LeafPredicate.of(
+                        new FieldTransform(new FieldRef(3, "age", DataTypes.INT())),
+                        GreaterOrEqual.INSTANCE,
+                        Collections.singletonList(30));
+        columnMasking.clear();
+        columnMasking.put("salary", salaryMaskTransform);
+        restCatalogServer.setColumnMaskingAuth(
+                Identifier.create(DATABASE_NAME, combinedTable), columnMasking);
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create(DATABASE_NAME, combinedTable),
+                Collections.singletonList(ageGe30Predicate));
+
+        // Query only id, name and salary (masked)
+        List<Row> pruneResult =
+                batchSql(
+                        String.format(
+                                "SELECT id, name, salary FROM %s.%s ORDER BY id",
+                                DATABASE_NAME, combinedTable));
+        assertThat(pruneResult.size()).isEqualTo(2);
+        assertThat(pruneResult.get(0).getField(0)).isEqualTo(2);
+        assertThat(pruneResult.get(0).getField(1)).isEqualTo("Bob");
+        assertThat(pruneResult.get(0).getField(2)).isEqualTo("***"); // salary is masked
+        assertThat(pruneResult.get(1).getField(0)).isEqualTo(3);
+        assertThat(pruneResult.get(1).getField(1)).isEqualTo("Charlie");
+        assertThat(pruneResult.get(1).getField(2)).isEqualTo("***"); // salary is masked
+
+        // Test aggregate functions with column masking and row filter
+        assertThat(
+                        batchSql(
+                                String.format(
+                                        "SELECT COUNT(*) FROM %s.%s",
+                                        DATABASE_NAME, combinedTable)))
+                .containsExactlyInAnyOrder(Row.of(2L));
+        assertThat(
+                        batchSql(
+                                String.format(
+                                        "SELECT COUNT(name) FROM %s.%s",
+                                        DATABASE_NAME, combinedTable)))
+                .containsExactlyInAnyOrder(Row.of(2L));
+
+        // Test aggregation on non-masked columns with row filter
+        List<Row> deptAggResult =
+                batchSql(
+                        String.format(
+                                "SELECT department, COUNT(*) FROM %s.%s GROUP BY department ORDER BY department",
+                                DATABASE_NAME, combinedTable));
+        assertThat(deptAggResult.size()).isEqualTo(2);
+        assertThat(deptAggResult.get(0).getField(0)).isEqualTo("HR");
+        assertThat(deptAggResult.get(0).getField(1)).isEqualTo(1L);
+        assertThat(deptAggResult.get(1).getField(0)).isEqualTo("IT");
+        assertThat(deptAggResult.get(1).getField(1)).isEqualTo(1L);
+
+        // Test with non-existent column as row filter
+        Predicate nonExistentPredicate =
+                LeafPredicate.of(
+                        new FieldTransform(
+                                new FieldRef(10, "non_existent_column", DataTypes.STRING())),
+                        Equal.INSTANCE,
+                        Collections.singletonList(BinaryString.fromString("value")));
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create(DATABASE_NAME, combinedTable),
+                Collections.singletonList(nonExistentPredicate));
+
+        assertThatThrownBy(
+                        () ->
+                                batchSql(
+                                        String.format(
+                                                "SELECT id, name FROM %s.%s WHERE age > 30 ORDER BY id",
+                                                DATABASE_NAME, combinedTable)))
+                .rootCause()
+                .hasMessageContaining("Unable to read data without column non_existent_column");
+
         // Clear both column masking and row filter
         restCatalogServer.setColumnMaskingAuth(
                 Identifier.create(DATABASE_NAME, combinedTable), new HashMap<>());
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index 47723171e4..4bca5bad11 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -22,7 +22,8 @@ import org.apache.paimon.CoreOptions
 import org.apache.paimon.partition.PartitionPredicate
 import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates
 import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, TopN, VectorSearch}
-import org.apache.paimon.table.{SpecialFields, Table}
+import org.apache.paimon.table.{FileStoreTable, SpecialFields, Table}
+import org.apache.paimon.table.format.predicate.PredicateUtils
 import org.apache.paimon.types.RowType
 
 import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate}
@@ -58,6 +59,29 @@ abstract class PaimonBaseScanBuilder
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
     this.requiredSchema = requiredSchema
+    pruneColumnsWithAuth(requiredSchema)
+  }
+
+  private def pruneColumnsWithAuth(requiredSchema: StructType): Unit = {
+    val auth = table match {
+      case fileStoreTable: FileStoreTable =>
+        fileStoreTable.catalogEnvironment().tableQueryAuth(coreOptions)
+      case _ =>
+        return
+    }
+
+    val result = auth.auth(requiredSchema.fieldNames.toList.asJava)
+    if (result != null) {
+      val predicate = result.extractPredicate()
+      if (predicate != null) {
+        val names = PredicateUtils.collectFieldNames(predicate)
+        val fullType = SparkTypeUtils.fromPaimonRowType(table.rowType())
+        val requiredFieldNames = requiredSchema.fieldNames.toSet
+        val addFields = fullType.fields.filter(
+          field => names.contains(field.name) && !requiredFieldNames.contains(field.name))
+        this.requiredSchema = StructType(requiredSchema.fields ++ addFields)
+      }
+    }
   }
 
   override def pushPredicates(predicates: Array[SparkPredicate]): Array[SparkPredicate] = {
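
The Spark side mirrors the Flink change: after Spark calls pruneColumns, the new
pruneColumnsWithAuth re-extends requiredSchema with any fields referenced by the
auth predicate. A condensed sketch of how the tests below drive this path,
assuming the restCatalogServer fixture from the test class:

    // Condensed from the tests below; restCatalogServer is the test fixture.
    Predicate ageGe28 =
            LeafPredicate.of(
                    new FieldTransform(new FieldRef(2, "age", DataTypes.INT())),
                    GreaterOrEqual.INSTANCE,
                    Collections.singletonList(28));
    restCatalogServer.setRowFilterAuth(
            Identifier.create("db2", "t_row_filter"),
            Collections.singletonList(ageGe28));
    // Selecting only (id, name) still filters rows on "age", because the scan
    // builder adds "age" back into the required schema before building the scan.
    spark.sql("SELECT id, name FROM t_row_filter ORDER BY id").collectAsList();
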
diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java
index cdf6384da3..4a8f43e61b 100644
--- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java
+++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkCatalogWithRestTest.java
@@ -465,6 +465,75 @@ public class SparkCatalogWithRestTest {
         assertThat(joinResult.get(0).toString()).isEqualTo("[2,Bob,30,60000.0]");
         assertThat(joinResult.get(1).toString()).isEqualTo("[3,Charlie,35,70000.0]");
 
+        // Test column pruning with row filter
+        Predicate ageGe28Predicate =
+                LeafPredicate.of(
+                        new FieldTransform(new FieldRef(2, "age", DataTypes.INT())),
+                        GreaterOrEqual.INSTANCE,
+                        Collections.singletonList(28));
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create("db2", "t_row_filter"),
+                Collections.singletonList(ageGe28Predicate));
+
+        // Query only id and name, but age column should be automatically included for filtering
+        List<Row> pruneResult =
+                spark.sql("SELECT id, name FROM t_row_filter ORDER BY id").collectAsList();
+        assertThat(pruneResult.size()).isEqualTo(3);
+        assertThat(pruneResult.get(0).getInt(0)).isEqualTo(2);
+        assertThat(pruneResult.get(0).getString(1)).isEqualTo("Bob");
+        assertThat(pruneResult.get(1).getInt(0)).isEqualTo(3);
+        assertThat(pruneResult.get(1).getString(1)).isEqualTo("Charlie");
+        assertThat(pruneResult.get(2).getInt(0)).isEqualTo(4);
+        assertThat(pruneResult.get(2).getString(1)).isEqualTo("David");
+
+        // Test with complex AND predicate
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create("db2", "t_row_filter"),
+                Collections.singletonList(combinedPredicate));
+
+        // Query only id
+        pruneResult = spark.sql("SELECT id FROM t_row_filter ORDER BY id").collectAsList();
+        assertThat(pruneResult.size()).isEqualTo(1);
+        assertThat(pruneResult.get(0).getInt(0)).isEqualTo(3);
+
+        // Test aggregate functions with row filter
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create("db2", "t_row_filter"),
+                Collections.singletonList(ageGe30Predicate));
+
+        // Test COUNT(*) with row filter
+        assertThat(spark.sql("SELECT COUNT(*) FROM t_row_filter").collectAsList().toString())
+                .isEqualTo("[[2]]");
+
+        // Test COUNT(column) with row filter
+        assertThat(spark.sql("SELECT COUNT(name) FROM t_row_filter").collectAsList().toString())
+                .isEqualTo("[[2]]");
+
+        // Test GROUP BY with row filter
+        List<Row> groupByResult =
+                spark.sql(
+                                "SELECT department, COUNT(*) FROM t_row_filter GROUP BY department ORDER BY department")
+                        .collectAsList();
+        assertThat(groupByResult.size()).isEqualTo(2);
+        assertThat(groupByResult.get(0).getString(0)).isEqualTo("HR");
+        assertThat(groupByResult.get(0).getLong(1)).isEqualTo(1);
+        assertThat(groupByResult.get(1).getString(0)).isEqualTo("IT");
+        assertThat(groupByResult.get(1).getLong(1)).isEqualTo(1);
+
+        // Test HAVING clause with row filter
+        List<Row> havingResult =
+                spark.sql(
+                                "SELECT department, COUNT(*) as cnt FROM t_row_filter GROUP BY department HAVING cnt >= 1 ORDER BY department")
+                        .collectAsList();
+        assertThat(havingResult.size()).isEqualTo(2);
+
+        // Test COUNT DISTINCT with row filter
+        assertThat(
+                        spark.sql("SELECT COUNT(DISTINCT department) FROM t_row_filter")
+                                .collectAsList()
+                                .toString())
+                .isEqualTo("[[2]]");
+
         // Clear row filter and verify original data
         restCatalogServer.setRowFilterAuth(Identifier.create("db2", "t_row_filter"), null);
 
@@ -512,14 +581,6 @@ public class SparkCatalogWithRestTest {
         assertThat(combinedResult.get(0).getInt(3)).isEqualTo(25); // age not masked
         assertThat(combinedResult.get(0).getString(4)).isEqualTo("IT"); // department not masked
 
-        // Test must read with row filter columns
-        assertThatThrownBy(
-                        () ->
-                                spark.sql(
-                                                "SELECT id, name FROM t_combined WHERE age > 30 ORDER BY id")
-                                        .collectAsList())
-                .hasMessageContaining("Unable to read data without column department");
-
         // Test WHERE clause with both features
         assertThat(
                         spark.sql(
@@ -528,6 +589,62 @@ public class SparkCatalogWithRestTest {
                                 .toString())
                 .isEqualTo("[[3,***,IT]]");
 
+        // Test column pruning with both column masking and row filter
+        columnMasking.clear();
+        columnMasking.put("salary", salaryMaskTransform);
+        restCatalogServer.setColumnMaskingAuth(
+                Identifier.create("db2", "t_combined"), columnMasking);
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create("db2", "t_combined"),
+                Collections.singletonList(ageGe30Predicate));
+
+        // Query only id, name and salary (masked)
+        List<Row> pruneResult =
+                spark.sql("SELECT id, name, salary FROM t_combined ORDER BY id").collectAsList();
+        assertThat(pruneResult.size()).isEqualTo(2);
+        assertThat(pruneResult.get(0).getInt(0)).isEqualTo(2);
+        assertThat(pruneResult.get(0).getString(1)).isEqualTo("Bob");
+        assertThat(pruneResult.get(0).getString(2)).isEqualTo("***"); // salary is masked
+        assertThat(pruneResult.get(1).getInt(0)).isEqualTo(3);
+        assertThat(pruneResult.get(1).getString(1)).isEqualTo("Charlie");
+        assertThat(pruneResult.get(1).getString(2)).isEqualTo("***"); // salary is masked
+
+        // Test aggregate functions with column masking and row filter
+        assertThat(spark.sql("SELECT COUNT(*) FROM t_combined").collectAsList().toString())
+                .isEqualTo("[[2]]");
+        assertThat(spark.sql("SELECT COUNT(name) FROM t_combined").collectAsList().toString())
+                .isEqualTo("[[2]]");
+
+        // Test aggregation on non-masked columns with row filter
+        List<Row> deptAggResult =
+                spark.sql(
+                                "SELECT department, COUNT(*) FROM t_combined GROUP BY department ORDER BY department")
+                        .collectAsList();
+        assertThat(deptAggResult.size()).isEqualTo(2);
+        assertThat(deptAggResult.get(0).getString(0)).isEqualTo("HR");
+        assertThat(deptAggResult.get(0).getLong(1)).isEqualTo(1);
+        assertThat(deptAggResult.get(1).getString(0)).isEqualTo("IT");
+        assertThat(deptAggResult.get(1).getLong(1)).isEqualTo(1);
+
+        // Test with non-existent column as row filter
+        Predicate nonExistentPredicate =
+                LeafPredicate.of(
+                        new FieldTransform(
+                                new FieldRef(10, "non_existent_column", DataTypes.STRING())),
+                        Equal.INSTANCE,
+                        Collections.singletonList(BinaryString.fromString("value")));
+        restCatalogServer.setRowFilterAuth(
+                Identifier.create("db2", "t_combined"),
+                Collections.singletonList(nonExistentPredicate));
+
+        // Test must read with row filter columns
+        assertThatThrownBy(
+                        () ->
+                                spark.sql(
+                                                "SELECT id, name FROM t_combined WHERE age > 30 ORDER BY id")
+                                        .collectAsList())
+                .hasMessageContaining("Unable to read data without column non_existent_column");
+
         // Clear both column masking and row filter
         restCatalogServer.setColumnMaskingAuth(
                 Identifier.create("db2", "t_combined"), new HashMap<>());
