This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 98f4de6a62 [spark] Make SparkCatalystPartitionPredicate thread safe (#7135)
98f4de6a62 is described below

commit 98f4de6a623b798f5a692ac3dd6c41927fe5c104
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Jan 27 20:23:21 2026 +0800

    [spark] Make SparkCatalystPartitionPredicate thread safe (#7135)
---
 .../paimon/partition/PartitionPredicate.java       |  5 +-
 .../filter/SparkCatalystPartitionPredicate.scala   | 10 +--
 .../SparkCatalystPartitionPredicateTest.scala      | 84 ++++++++++++++++++++++
 3 files changed, 93 insertions(+), 6 deletions(-)

diff --git a/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java b/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
index d997ad2db7..023a7d89ca 100644
--- a/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
+++ b/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
@@ -35,6 +35,7 @@ import org.apache.paimon.utils.Preconditions;
 import org.apache.paimon.utils.RowDataToObjectArrayConverter;
 
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
 
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -56,10 +57,12 @@ import static org.apache.paimon.utils.Preconditions.checkArgument;
 import static org.apache.paimon.utils.Preconditions.checkNotNull;
 
 /**
- * A special predicate to filter partition only, just like {@link Predicate}.
+ * A special predicate to filter partition only, just like {@link Predicate}, this should be thread
+ * safe.
  *
  * @since 1.3.0
  */
+@ThreadSafe
 public interface PartitionPredicate extends Serializable {
 
     /**
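
Note: javax.annotation.concurrent.ThreadSafe is a documentation-only annotation from the JSR-305/JCIP annotation set; it is not enforced by the compiler or at runtime. It records the contract, spelled out in the updated javadoc, that PartitionPredicate implementations may have test() called concurrently from multiple threads.
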
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
index b3128e0bc1..f030a67d24 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
@@ -40,13 +40,13 @@ case class SparkCatalystPartitionPredicate(
 
   private val partitionSchema: StructType = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
     SparkTypeUtils.fromPaimonRowType(partitionRowType))
-  @transient private val predicate: BasePredicate =
-    new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate
-  @transient private val sparkPartitionRow: SparkInternalRow =
-    SparkInternalRow.create(partitionRowType)
+  @transient private lazy val predicateThreadLocal: ThreadLocal[BasePredicate] =
+    ThreadLocal.withInitial(
+      () => new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate)
 
   override def test(partition: BinaryRow): Boolean = {
-    predicate.eval(sparkPartitionRow.replace(partition))
+    val predicate = predicateThreadLocal.get()
+    predicate.eval(SparkInternalRow.create(partitionRowType).replace(partition))
   }
 
   override def test(
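
The rewrite above removes the two fields that were shared across threads: the catalyst BasePredicate and the reusable SparkInternalRow wrapper, both of which hold mutable evaluation state. The predicate now lives in a ThreadLocal, declared as a @transient lazy val so it is skipped during serialization and rebuilt on first use after the case class is deserialized on an executor, and a fresh SparkInternalRow is created on every test() call. A minimal, self-contained sketch of the ThreadLocal.withInitial pattern (the names below are hypothetical, not Paimon or Spark APIs):

    // StatefulEvaluator stands in for the generated catalyst BasePredicate,
    // whose per-instance mutable state makes it unsafe to share across threads.
    object ThreadLocalPredicateSketch {

      class StatefulEvaluator {
        private var lastInput: Int = 0 // mutable state => not shareable
        def eval(input: Int): Boolean = {
          lastInput = input
          lastInput % 2 == 0
        }
      }

      // withInitial builds one evaluator per thread, lazily on first use in that thread.
      private val evaluators: ThreadLocal[StatefulEvaluator] =
        ThreadLocal.withInitial(() => new StatefulEvaluator)

      def test(input: Int): Boolean = evaluators.get().eval(input)
    }

Each thread pays the evaluator construction cost once and then reuses its own instance; the row wrapper is simply created per call, which avoids having to make it thread-local as well.
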
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
index 18639f47ad..9ccfacdaac 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
@@ -32,8 +32,10 @@ import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.assertj.core.api.Assertions.assertThat
 
 import java.util.{List => JList}
+import java.util.concurrent.{CountDownLatch, Executors, ExecutorService}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {
 
@@ -194,6 +196,88 @@ class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {
     }
   }
 
+  test("SparkCatalystPartitionPredicate: thread safety") {
+    withTable("t") {
+      sql("""
+            |CREATE TABLE t (id INT, value INT, year STRING, month STRING, day STRING)
+            |PARTITIONED BY (year, month, day)
+            |""".stripMargin)
+
+      sql("""
+            |INSERT INTO t values
+            |(1, 100, '2024', '01', '01'),
+            |(2, 200, '2024', '01', '02'),
+            |(3, 300, '2024', '01', '03'),
+            |(4, 400, '2024', '02', '01'),
+            |(5, 500, '2024', '02', '02'),
+            |(6, 600, '2024', '03', '01')
+            |""".stripMargin)
+
+      val table = loadTable("t")
+      val partitionRowType = table.rowType().project(table.partitionKeys())
+
+      val q =
+        """
+          |SELECT * FROM t
+          |WHERE CONCAT_WS('-', year, month)
+          |BETWEEN '2024-01' AND '2024-01'
+          |""".stripMargin
+
+      val partitionFilters =
+        extractSupportedPartitionFilters(extractCatalystFilters(q), partitionRowType)
+      val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+
+      val allPartitions = table.newScan().listPartitions().asScala.toSeq
+
+      val threadCount = 10
+      val iterationsPerThread = 100
+      val executor: ExecutorService = Executors.newFixedThreadPool(threadCount)
+      val latch = new CountDownLatch(threadCount)
+      val errors = new ArrayBuffer[Throwable]()
+
+      try {
+        for (_ <- 0 until threadCount) {
+          executor.submit(new Runnable {
+            override def run(): Unit = {
+              try {
+                for (_ <- 0 until iterationsPerThread) {
+                  // Directly test the predicate.test() method which is the core of thread safety
+                  val matchedPartitions = allPartitions.filter(p => partitionPredicate.test(p))
+                  val results =
+                    matchedPartitions.map(r => internalRowToString(r, partitionRowType)).toSet
+
+                  // Verify that all results match expected partitions
+                  val expected = Set("+I[2024, 01, 01]", "+I[2024, 01, 02]", "+I[2024, 01, 03]")
+                  if (results != expected) {
+                    throw new AssertionError(
+                      s"Expected $expected but got $results in thread ${Thread.currentThread().getName}")
+                  }
+                }
+              } catch {
+                case e: Throwable =>
+                  errors.synchronized {
+                    errors += e
+                  }
+              } finally {
+                latch.countDown()
+              }
+            }
+          })
+        }
+
+        latch.await()
+
+        // Check if there were any errors
+        if (errors.nonEmpty) {
+          fail(s"Thread safety test failed with ${errors.size} errors: ${errors.head.getMessage}")
+        }
+      } finally {
+        executor.shutdown()
+        executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+      }
+    }
+  }
+
   def extractCatalystFilters(sqlStr: String): Seq[Expression] = {
     var filters: Seq[Expression] = Seq.empty
     // Set ansi false to make sure some filters like `Cast` not push down
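
One design note on the new test: assertions are not thrown directly from the worker threads. Failures are collected into a synchronized ArrayBuffer and re-raised on the test thread after the CountDownLatch releases, because an exception thrown inside a Runnable submitted to an ExecutorService is captured in the returned Future and never reaches the submitting thread unless Future.get() is called. A tiny, self-contained illustration (the names below are hypothetical, not part of the test):

    import java.util.concurrent.Executors

    object SwallowedWorkerFailure {
      def main(args: Array[String]): Unit = {
        val pool = Executors.newSingleThreadExecutor()
        val future = pool.submit(new Runnable {
          // This AssertionError is captured inside the returned Future.
          override def run(): Unit = throw new AssertionError("worker failed")
        })
        // Nothing is rethrown here; the failure would only surface via
        // future.get(), which throws an ExecutionException wrapping it.
        pool.shutdown()
      }
    }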
