This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
     new 98f4de6a62 [spark] Make SparkCatalystPartitionPredicate thread safe (#7135)
98f4de6a62 is described below
commit 98f4de6a623b798f5a692ac3dd6c41927fe5c104
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Jan 27 20:23:21 2026 +0800
[spark] Make SparkCatalystPartitionPredicate thread safe (#7135)
---
 .../paimon/partition/PartitionPredicate.java       |  5 +-
 .../filter/SparkCatalystPartitionPredicate.scala   | 10 +--
 .../SparkCatalystPartitionPredicateTest.scala      | 84 ++++++++++++++++++++++
 3 files changed, 93 insertions(+), 6 deletions(-)
diff --git a/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java b/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
index d997ad2db7..023a7d89ca 100644
--- a/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
+++ b/paimon-core/src/main/java/org/apache/paimon/partition/PartitionPredicate.java
@@ -35,6 +35,7 @@ import org.apache.paimon.utils.Preconditions;
 import org.apache.paimon.utils.RowDataToObjectArrayConverter;
 
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
 
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -56,10 +57,12 @@ import static org.apache.paimon.utils.Preconditions.checkArgument;
 import static org.apache.paimon.utils.Preconditions.checkNotNull;
 
 /**
- * A special predicate to filter partition only, just like {@link Predicate}.
+ * A special predicate to filter partition only, just like {@link Predicate}, this should be thread
+ * safe.
  *
  * @since 1.3.0
  */
+@ThreadSafe
 public interface PartitionPredicate extends Serializable {
 
     /**
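
For context: the new @ThreadSafe contract means a single PartitionPredicate instance may be evaluated by many scan threads at once, so implementations cannot keep shared mutable scratch state. A minimal Scala sketch of the failure mode and the per-thread fix; UnsafeEvaluator and SafeEvaluator are illustrative names, not part of the codebase:

    // UNSAFE: one shared scratch buffer; concurrent test() calls interleave
    // writes and may evaluate against another thread's partition values.
    class UnsafeEvaluator(width: Int) {
      private val scratch = new Array[AnyRef](width)
      def test(values: Array[AnyRef]): Boolean = {
        System.arraycopy(values, 0, scratch, 0, width)
        scratch(0) != null
      }
    }

    // SAFE: ThreadLocal.withInitial gives each thread its own buffer on
    // first access, so test() never shares mutable state across threads.
    class SafeEvaluator(width: Int) {
      private val scratch: ThreadLocal[Array[AnyRef]] =
        ThreadLocal.withInitial(() => new Array[AnyRef](width))
      def test(values: Array[AnyRef]): Boolean = {
        val local = scratch.get()
        System.arraycopy(values, 0, local, 0, width)
        local(0) != null
      }
    }
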
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
index b3128e0bc1..f030a67d24 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
@@ -40,13 +40,13 @@ case class SparkCatalystPartitionPredicate(
   private val partitionSchema: StructType =
     CharVarcharUtils.replaceCharVarcharWithStringInSchema(
       SparkTypeUtils.fromPaimonRowType(partitionRowType))
-  @transient private val predicate: BasePredicate =
-    new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate
-  @transient private val sparkPartitionRow: SparkInternalRow =
-    SparkInternalRow.create(partitionRowType)
+  @transient private lazy val predicateThreadLocal: ThreadLocal[BasePredicate] =
+    ThreadLocal.withInitial(
+      () => new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate)
 
   override def test(partition: BinaryRow): Boolean = {
-    predicate.eval(sparkPartitionRow.replace(partition))
+    val predicate = predicateThreadLocal.get()
+    predicate.eval(SparkInternalRow.create(partitionRowType).replace(partition))
   }
 
   override def test(
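
The change above is the standard per-thread-instance idiom: Spark's BasePredicate is not guaranteed to be thread safe, so ThreadLocal.withInitial builds one per thread on first access, and the mutable SparkInternalRow wrapper, previously a shared field, is now created fresh on every test() call. The same idiom with a familiar non-thread-safe JDK class standing in for the predicate (a sketch; PerThreadFormatter is an illustrative name):

    import java.text.SimpleDateFormat
    import java.util.Date

    // SimpleDateFormat is mutable and not thread safe, much like a generated
    // Spark predicate; the ThreadLocal hands out one instance per thread
    // instead of one instance shared by all threads.
    object PerThreadFormatter {
      private val fmt: ThreadLocal[SimpleDateFormat] =
        ThreadLocal.withInitial(() => new SimpleDateFormat("yyyy-MM-dd"))

      def format(millis: Long): String = fmt.get().format(new Date(millis))
    }

The trade-off is one predicate instance per thread plus a small per-call allocation for the row wrapper, in exchange for correct results under concurrent scans.
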
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
index 18639f47ad..9ccfacdaac 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
@@ -32,8 +32,10 @@ import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.assertj.core.api.Assertions.assertThat
 
 import java.util.{List => JList}
+import java.util.concurrent.{CountDownLatch, Executors, ExecutorService}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {
 
@@ -194,6 +196,88 @@ class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {
     }
   }
 
+ test("SparkCatalystPartitionPredicate: thread safety") {
+ withTable("t") {
+ sql("""
+ |CREATE TABLE t (id INT, value INT, year STRING, month STRING, day
STRING)
+ |PARTITIONED BY (year, month, day)
+ |""".stripMargin)
+
+ sql("""
+ |INSERT INTO t values
+ |(1, 100, '2024', '01', '01'),
+ |(2, 200, '2024', '01', '02'),
+ |(3, 300, '2024', '01', '03'),
+ |(4, 400, '2024', '02', '01'),
+ |(5, 500, '2024', '02', '02'),
+ |(6, 600, '2024', '03', '01')
+ |""".stripMargin)
+
+ val table = loadTable("t")
+ val partitionRowType = table.rowType().project(table.partitionKeys())
+
+ val q =
+ """
+ |SELECT * FROM t
+ |WHERE CONCAT_WS('-', year, month)
+ |BETWEEN '2024-01' AND '2024-01'
+ |""".stripMargin
+
+ val partitionFilters =
+ extractSupportedPartitionFilters(extractCatalystFilters(q),
partitionRowType)
+ val partitionPredicate =
SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+
+ val allPartitions = table.newScan().listPartitions().asScala.toSeq
+
+ val threadCount = 10
+ val iterationsPerThread = 100
+ val executor: ExecutorService = Executors.newFixedThreadPool(threadCount)
+ val latch = new CountDownLatch(threadCount)
+ val errors = new ArrayBuffer[Throwable]()
+
+ try {
+ for (_ <- 0 until threadCount) {
+ executor.submit(new Runnable {
+ override def run(): Unit = {
+ try {
+ for (_ <- 0 until iterationsPerThread) {
+ // Directly test the predicate.test() method which is the
core of thread safety
+ val matchedPartitions = allPartitions.filter(p =>
partitionPredicate.test(p))
+ val results =
+ matchedPartitions.map(r => internalRowToString(r,
partitionRowType)).toSet
+
+ // Verify that all results match expected partitions
+ val expected = Set("+I[2024, 01, 01]", "+I[2024, 01, 02]",
"+I[2024, 01, 03]")
+ if (results != expected) {
+ throw new AssertionError(
+ s"Expected $expected but got $results in thread
${Thread.currentThread().getName}")
+ }
+ }
+ } catch {
+ case e: Throwable =>
+ errors.synchronized {
+ errors += e
+ }
+ } finally {
+ latch.countDown()
+ }
+ }
+ })
+ }
+
+ latch.await()
+
+ // Check if there were any errors
+ if (errors.nonEmpty) {
+ fail(s"Thread safety test failed with ${errors.size} errors:
${errors.head.getMessage}")
+ }
+ } finally {
+ executor.shutdown()
+ executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+ }
+ }
+ }
+
def extractCatalystFilters(sqlStr: String): Seq[Expression] = {
var filters: Seq[Expression] = Seq.empty
// Set ansi false to make sure some filters like `Cast` not push down