This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 65a22b39af [spark] Introduce SparkCatalystPartitionPredicate (#6813)
65a22b39af is described below

commit 65a22b39af07fa6940e31c81127b8c7f0a6fb190
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Dec 16 08:49:56 2025 +0800

    [spark] Introduce SparkCatalystPartitionPredicate (#6813)
---
 .../filter/SparkCatalystPartitionPredicate.scala   | 102 ++++++++++
 .../SparkCatalystPartitionPredicateTest.scala      | 220 +++++++++++++++++++++
 2 files changed, 322 insertions(+)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
new file mode 100644
index 0000000000..b3128e0bc1
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.filter
+
+import org.apache.paimon.data.{BinaryRow, InternalArray, InternalRow => 
PaimonInternalRow}
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.SparkTypeUtils
+import org.apache.paimon.spark.data.SparkInternalRow
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, 
BasePredicate, BoundReference, Expression, Predicate, PythonUDF, 
SubqueryExpression}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.types.StructType
+
/**
 * A [[PartitionPredicate]] that applies Spark Catalyst partition filters to a partition
 * [[BinaryRow]].
 *
 * Attributes in `partitionFilter` are bound by name to the Spark partition schema derived from
 * `partitionRowType` (with CHAR/VARCHAR replaced by STRING), and the compiled Catalyst predicate is
 * evaluated against each partition row.
 *
 * @param partitionFilter
 *   Catalyst expression whose attributes must all be partition columns
 * @param partitionRowType
 *   Paimon row type of the partition columns
 */
case class SparkCatalystPartitionPredicate(
    partitionFilter: Expression,
    partitionRowType: RowType
) extends PartitionPredicate {

  private val partitionSchema: StructType = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
    SparkTypeUtils.fromPaimonRowType(partitionRowType))

  // These fields are `@transient`, so they must be `lazy`: Java deserialization does not re-run the
  // constructor, so a strict `val` would be left null on the receiving side and `test` would NPE.
  @transient private lazy val predicate: BasePredicate =
    new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate
  @transient private lazy val sparkPartitionRow: SparkInternalRow =
    SparkInternalRow.create(partitionRowType)

  /**
   * Returns whether `partition` satisfies the Catalyst filter.
   *
   * NOTE(review): a single `sparkPartitionRow` wrapper is reused for every call, so concurrent calls
   * on one instance are presumably unsafe — confirm callers use it from one thread at a time.
   */
  override def test(partition: BinaryRow): Boolean = {
    predicate.eval(sparkPartitionRow.replace(partition))
  }

  /** Statistics-based evaluation is not supported, so every candidate is conservatively kept. */
  override def test(
      rowCount: Long,
      minValues: PaimonInternalRow,
      maxValues: PaimonInternalRow,
      nullCounts: InternalArray): Boolean = true

  override def toString: String = partitionFilter.toString()

  /** Binds the attributes of `filter` to ordinals of `schema` and compiles the result. */
  private class StructExpressionFilters(filter: Expression, schema: StructType) {
    def toPredicate: BasePredicate = {
      Predicate.create(filter.transform {
        case a: Attribute =>
          toRef(a.name).getOrElse {
            throw new IllegalArgumentException(
              s"Attribute '${a.name}' in partition filter [$filter] is not a field of the " +
                s"partition schema ${schema.simpleString}")
          }
      })
    }

    // Finds a filter attribute in the schema and converts it to a `BoundReference`.
    private def toRef(attr: String): Option[BoundReference] = {
      // The names have been normalized, so case sensitivity is not a concern here.
      schema.getFieldIndex(attr).map {
        index =>
          val field = schema(index)
          BoundReference(index, field.dataType, field.nullable)
      }
    }
  }
}
+
object SparkCatalystPartitionPredicate {

  /**
   * Creates a single predicate by AND-ing all of `partitionFilters` together.
   *
   * Filters are ordered by how many attributes they reference, fewest first. The `And` tree is built
   * left-deep, so those filters are evaluated first (presumably so cheap checks short-circuit).
   *
   * @param partitionFilters
   *   non-empty partition-only Catalyst filters
   * @param partitionRowType
   *   Paimon row type of the partition columns
   * @throws IllegalArgumentException
   *   if `partitionFilters` is empty
   */
  def apply(
      partitionFilters: Seq[Expression],
      partitionRowType: RowType): SparkCatalystPartitionPredicate = {
    // `require` (unlike `assert`) cannot be elided at compile time, so this check always runs
    // instead of letting `reduce` fail on an empty Seq with an opaque error.
    require(partitionFilters.nonEmpty, "partitionFilters is empty")
    new SparkCatalystPartitionPredicate(
      partitionFilters.sortBy(_.references.size).reduce(And),
      partitionRowType)
  }

  /**
   * Extracts supported partition filters from the given filters.
   *
   * Only deterministic filters are considered. The partition/data split is delegated to Spark's
   * `DataSourceUtils.getPartitionFiltersAndDataFilters`. Partition filters that contain a subquery
   * or a Python UDF are then dropped.
   *
   * @param filters
   *   candidate Catalyst filters (e.g. conjuncts of a query's filter condition)
   * @param partitionRowType
   *   Paimon row type of the partition columns
   * @return
   *   the filters that can be evaluated by [[SparkCatalystPartitionPredicate]]
   */
  def extractSupportedPartitionFilters(
      filters: Seq[Expression],
      partitionRowType: RowType): Seq[Expression] = {
    val partitionSchema: StructType = SparkTypeUtils.fromPaimonRowType(partitionRowType)
    val deterministicFilters = filters.filter(_.deterministic)
    val (partitionFilters, _) =
      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
    partitionFilters.filterNot {
      f =>
        // Python UDFs might exist because these filters can be extracted before Spark's
        // `ExtractPythonUDFs` rule runs.
        SubqueryExpression.hasSubquery(f) || f.exists(_.isInstanceOf[PythonUDF])
    }
  }
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
new file mode 100644
index 0000000000..18639f47ad
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.predicate
+
+import org.apache.paimon.data.DataFormatTestUtil.internalRowToString
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.table.DataTable
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.filter.SparkCatalystPartitionPredicate
+import 
org.apache.spark.sql.catalyst.filter.SparkCatalystPartitionPredicate.extractSupportedPartitionFilters
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.assertj.core.api.Assertions.assertThat
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {

  test("SparkCatalystPartitionPredicate: basic test") {
    withTable("t") {
      sql("""
            |CREATE TABLE t (id INT, value INT, year STRING, month STRING, day STRING, hour STRING)
            |PARTITIONED BY (year, month, day, hour)
            |""".stripMargin)

      sql("""
            |INSERT INTO t values
            |(1, 100, '2024', '07', '15', '21'),
            |(3, 300, '2025', '07', '16', '22'),
            |(4, 400, '2025', '07', '16', '23'),
            |(5, 440, '2025', '07', '16', '25'),
            |(6, 500, '2025', '07', '17', '00'),
            |(7, 600, '2025', '07', '17', '02')
            |""".stripMargin)

      val query =
        """
          |SELECT * FROM t
          |WHERE CONCAT_WS('-', year, month, day, hour)
          |BETWEEN '2025-07-16-21' AND '2025-07-16-24'
          |AND value > 400
          |ORDER BY id
          |""".stripMargin

      val table = loadTable("t")
      val partitionType = partitionRowTypeOf(table)

      // The optimized plan yields four conjuncts, of which only two touch partition columns alone.
      val catalystFilters = extractCatalystFilters(query)
      assert(catalystFilters.size == 4)
      val supported = extractSupportedPartitionFilters(catalystFilters, partitionType)
      assert(supported.size == 2)

      val predicate = SparkCatalystPartitionPredicate(supported, partitionType)
      assertThat[String](getFilteredPartitions(table, partitionType, predicate))
        .containsExactlyInAnyOrder("+I[2025, 07, 16, 22]", "+I[2025, 07, 16, 23]")
    }
  }

  test("SparkCatalystPartitionPredicate: varchar partition column") {
    withTable("t_varchar") {
      sql("""
            |CREATE TABLE t_varchar (id INT, value INT, region VARCHAR(20), city VARCHAR(20))
            |PARTITIONED BY (region, city)
            |""".stripMargin)

      sql("""
            |INSERT INTO t_varchar values
            |(1, 100, 'north', 'beijing'),
            |(2, 200, 'south', 'shanghai'),
            |(3, 300, 'east', 'hangzhou'),
            |(4, 400, 'west', 'chengdu'),
            |(5, 500, 'north', 'tianjin')
            |""".stripMargin)

      val regionFirst =
        """
          |SELECT * FROM t_varchar
          |WHERE CONCAT_WS('-', region, city)
          |BETWEEN 'north-beijing' AND 'south-shanghai'
          |ORDER BY id
          |""".stripMargin

      val table = loadTable("t_varchar")
      val partitionType = partitionRowTypeOf(table)

      val byRegionCity = partitionPredicateFor(regionFirst, partitionType)
      assertThat[String](getFilteredPartitions(table, partitionType, byRegionCity))
        .containsExactlyInAnyOrder(
          "+I[north, beijing]",
          "+I[south, shanghai]",
          "+I[north, tianjin]")

      // Same kind of filter, but with the column order swapped inside CONCAT_WS.
      val cityFirst =
        """
          |SELECT * FROM t_varchar
          |WHERE CONCAT_WS('-', city, region)
          |BETWEEN 'beijing-north' AND 'chengdu-west'
          |ORDER BY id
          |""".stripMargin

      val byCityRegion = partitionPredicateFor(cityFirst, partitionType)
      assertThat[String](getFilteredPartitions(table, partitionType, byCityRegion))
        .containsExactlyInAnyOrder("+I[north, beijing]", "+I[west, chengdu]")
    }
  }

  test("SparkCatalystPartitionPredicate: cast") {
    withTable("t") {
      sql("""
            |CREATE TABLE t (id int, value int, dt STRING)
            |using paimon
            |PARTITIONED BY (dt)
            |""".stripMargin)

      sql("""
            |INSERT INTO t values
            |(1, 100, '1'), (2, 111, '2')
            |""".stripMargin)

      val table = loadTable("t")
      val partitionType = partitionRowTypeOf(table)

      val castPredicate = partitionPredicateFor("SELECT * FROM t WHERE dt = 1", partitionType)
      assertThat[String](getFilteredPartitions(table, partitionType, castPredicate))
        .containsExactlyInAnyOrder("+I[1]")
    }
  }

  test("SparkCatalystPartitionPredicate: null partition") {
    withTable("t") {
      sql("""
            |CREATE TABLE t (id INT, value INT, region STRING, city INT)
            |PARTITIONED BY (region, city)
            |""".stripMargin)

      sql("INSERT INTO t values (1, 100, 'north', null)")

      val table = loadTable("t")
      val partitionType = partitionRowTypeOf(table)

      val equalsNorth =
        """
          |SELECT * FROM t
          |WHERE CONCAT_WS('-', region, city) = 'north'
          |""".stripMargin
      checkAnswer(sql(equalsNorth), Seq(Row(1, 100, "north", null)))

      val equalsPredicate = partitionPredicateFor(equalsNorth, partitionType)
      assertThat[String](getFilteredPartitions(table, partitionType, equalsPredicate))
        .containsExactlyInAnyOrder("+I[north, NULL]")

      val notEqualsNorth =
        """
          |SELECT * FROM t
          |WHERE CONCAT_WS('-', region, city) != 'north'
          |""".stripMargin
      checkAnswer(sql(notEqualsNorth), Seq())

      val notEqualsPredicate = partitionPredicateFor(notEqualsNorth, partitionType)
      assert(getFilteredPartitions(table, partitionType, notEqualsPredicate).isEmpty)
    }
  }

  /**
   * Optimizes `sqlStr` with ANSI mode disabled and returns the individual conjuncts of every
   * `Filter` condition found in the optimized plan.
   */
  def extractCatalystFilters(sqlStr: String): Seq[Expression] = {
    var collected: Seq[Expression] = Seq.empty
    // Disable ANSI mode so that filters such as `Cast` are not pushed down.
    withSparkSQLConf("spark.sql.ansi.enabled" -> "false") {
      val optimized = sql(sqlStr).queryExecution.optimizedPlan
      val conditions = optimized.collect { case Filter(condition, _) => condition }
      collected = conditions.flatMap(splitConjunctivePredicates)
    }
    collected
  }

  /** Lists the partitions of `table` accepted by `partitionPredicate`, rendered as strings. */
  def getFilteredPartitions(
      table: DataTable,
      partitionRowType: RowType,
      partitionPredicate: PartitionPredicate): JList[String] = {
    val partitions = table.newScan().withPartitionFilter(partitionPredicate).listPartitions()
    partitions.asScala.map(row => internalRowToString(row, partitionRowType)).asJava
  }

  /** Row type containing only the partition columns of `table`, in partition-key order. */
  private def partitionRowTypeOf(table: DataTable): RowType =
    table.rowType().project(table.partitionKeys())

  /** Builds a partition predicate from the supported partition filters Spark derives for `query`. */
  private def partitionPredicateFor(query: String, partitionType: RowType): PartitionPredicate = {
    val supported = extractSupportedPartitionFilters(extractCatalystFilters(query), partitionType)
    SparkCatalystPartitionPredicate(supported, partitionType)
  }
}

Reply via email to