This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 65a22b39af [spark] Introduce SparkCatalystPartitionPredicate (#6813)
65a22b39af is described below
commit 65a22b39af07fa6940e31c81127b8c7f0a6fb190
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Dec 16 08:49:56 2025 +0800
[spark] Introduce SparkCatalystPartitionPredicate (#6813)
---
.../filter/SparkCatalystPartitionPredicate.scala | 102 ++++++++++
.../SparkCatalystPartitionPredicateTest.scala | 220 +++++++++++++++++++++
2 files changed, 322 insertions(+)
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
new file mode 100644
index 0000000000..b3128e0bc1
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/filter/SparkCatalystPartitionPredicate.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.filter
+
+import org.apache.paimon.data.{BinaryRow, InternalArray, InternalRow => PaimonInternalRow}
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.SparkTypeUtils
+import org.apache.paimon.spark.data.SparkInternalRow
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BasePredicate, BoundReference, Expression, Predicate, PythonUDF, SubqueryExpression}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[PartitionPredicate]] that applies Spark Catalyst partition filters to a partition
+ * [[BinaryRow]].
+ */
+case class SparkCatalystPartitionPredicate(
+ partitionFilter: Expression,
+ partitionRowType: RowType
+) extends PartitionPredicate {
+
+ private val partitionSchema: StructType = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
+ SparkTypeUtils.fromPaimonRowType(partitionRowType))
+ @transient private val predicate: BasePredicate =
+ new StructExpressionFilters(partitionFilter, partitionSchema).toPredicate
+ @transient private val sparkPartitionRow: SparkInternalRow =
+ SparkInternalRow.create(partitionRowType)
+
+ override def test(partition: BinaryRow): Boolean = {
+ predicate.eval(sparkPartitionRow.replace(partition))
+ }
+
+ override def test(
+ rowCount: Long,
+ minValues: PaimonInternalRow,
+ maxValues: PaimonInternalRow,
+ nullCounts: InternalArray): Boolean = true
+
+ override def toString: String = partitionFilter.toString()
+
+ private class StructExpressionFilters(filter: Expression, schema: StructType) {
+ def toPredicate: BasePredicate = {
+ Predicate.create(filter.transform { case a: Attribute => toRef(a.name).get })
+ }
+
+ // Finds a filter attribute in the schema and converts it to a `BoundReference`
+ private def toRef(attr: String): Option[BoundReference] = {
+ // The names have been normalized and case sensitivity is not a concern here.
+ schema.getFieldIndex(attr).map {
+ index =>
+ val field = schema(index)
+ BoundReference(index, field.dataType, field.nullable)
+ }
+ }
+ }
+}
+
+object SparkCatalystPartitionPredicate {
+
+ def apply(
+ partitionFilters: Seq[Expression],
+ partitionRowType: RowType): SparkCatalystPartitionPredicate = {
+ assert(partitionFilters.nonEmpty, "partitionFilters is empty")
+ new SparkCatalystPartitionPredicate(
+ partitionFilters.sortBy(_.references.size).reduce(And),
+ partitionRowType)
+ }
+
+ /** Extracts supported partition filters from the given filters. */
+ def extractSupportedPartitionFilters(
+ filters: Seq[Expression],
+ partitionRowType: RowType): Seq[Expression] = {
+ val partitionSchema: StructType = SparkTypeUtils.fromPaimonRowType(partitionRowType)
+ val (deterministicFilters, _) = filters.partition(_.deterministic)
+ val (partitionFilters, _) =
+ DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
+ partitionFilters.filter {
+ f =>
+ // Python UDFs might exist because this rule is applied before `ExtractPythonUDFs`.
+ !SubqueryExpression.hasSubquery(f) && !f.exists(_.isInstanceOf[PythonUDF])
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
new file mode 100644
index 0000000000..18639f47ad
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/predicate/SparkCatalystPartitionPredicateTest.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.predicate
+
+import org.apache.paimon.data.DataFormatTestUtil.internalRowToString
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.table.DataTable
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.filter.SparkCatalystPartitionPredicate
+import org.apache.spark.sql.catalyst.filter.SparkCatalystPartitionPredicate.extractSupportedPartitionFilters
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.assertj.core.api.Assertions.assertThat
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
+class SparkCatalystPartitionPredicateTest extends PaimonSparkTestBase {
+
+ test("SparkCatalystPartitionPredicate: basic test") {
+ withTable("t") {
+ sql("""
+ |CREATE TABLE t (id INT, value INT, year STRING, month STRING, day STRING, hour STRING)
+ |PARTITIONED BY (year, month, day, hour)
+ |""".stripMargin)
+
+ sql("""
+ |INSERT INTO t values
+ |(1, 100, '2024', '07', '15', '21'),
+ |(3, 300, '2025', '07', '16', '22'),
+ |(4, 400, '2025', '07', '16', '23'),
+ |(5, 440, '2025', '07', '16', '25'),
+ |(6, 500, '2025', '07', '17', '00'),
+ |(7, 600, '2025', '07', '17', '02')
+ |""".stripMargin)
+
+ val q =
+ """
+ |SELECT * FROM t
+ |WHERE CONCAT_WS('-', year, month, day, hour)
+ |BETWEEN '2025-07-16-21' AND '2025-07-16-24'
+ |AND value > 400
+ |ORDER BY id
+ |""".stripMargin
+
+ val table = loadTable("t")
+ val partitionRowType = table.rowType().project(table.partitionKeys())
+
+ val filters = extractCatalystFilters(q)
+ assert(filters.size == 4)
+ val partitionFilters = extractSupportedPartitionFilters(filters, partitionRowType)
+ assert(partitionFilters.size == 2)
+
+ val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+ assertThat[String](getFilteredPartitions(table, partitionRowType, partitionPredicate))
+ .containsExactlyInAnyOrder("+I[2025, 07, 16, 22]", "+I[2025, 07, 16,
23]")
+ }
+ }
+
+ test("SparkCatalystPartitionPredicate: varchar partition column") {
+ withTable("t_varchar") {
+ sql("""
+ |CREATE TABLE t_varchar (id INT, value INT, region VARCHAR(20), city VARCHAR(20))
+ |PARTITIONED BY (region, city)
+ |""".stripMargin)
+
+ sql("""
+ |INSERT INTO t_varchar values
+ |(1, 100, 'north', 'beijing'),
+ |(2, 200, 'south', 'shanghai'),
+ |(3, 300, 'east', 'hangzhou'),
+ |(4, 400, 'west', 'chengdu'),
+ |(5, 500, 'north', 'tianjin')
+ |""".stripMargin)
+
+ val q =
+ """
+ |SELECT * FROM t_varchar
+ |WHERE CONCAT_WS('-', region, city)
+ |BETWEEN 'north-beijing' AND 'south-shanghai'
+ |ORDER BY id
+ |""".stripMargin
+
+ val table = loadTable("t_varchar")
+ val partitionRowType = table.rowType().project(table.partitionKeys())
+ val partitionFilters =
+ extractSupportedPartitionFilters(extractCatalystFilters(q), partitionRowType)
+
+ val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+ assertThat[String](getFilteredPartitions(table, partitionRowType, partitionPredicate))
+ .containsExactlyInAnyOrder(
+ "+I[north, beijing]",
+ "+I[south, shanghai]",
+ "+I[north, tianjin]")
+
+ // swap cols
+ val q2 =
+ """
+ |SELECT * FROM t_varchar
+ |WHERE CONCAT_WS('-', city, region)
+ |BETWEEN 'beijing-north' AND 'chengdu-west'
+ |ORDER BY id
+ |""".stripMargin
+
+ val partitionFilters2 =
+ extractSupportedPartitionFilters(extractCatalystFilters(q2), partitionRowType)
+ val partitionPredicate2 = SparkCatalystPartitionPredicate(partitionFilters2, partitionRowType)
+ assertThat[String](getFilteredPartitions(table, partitionRowType, partitionPredicate2))
+ .containsExactlyInAnyOrder("+I[north, beijing]", "+I[west, chengdu]")
+ }
+ }
+
+ test("SparkCatalystPartitionPredicate: cast") {
+ withTable("t") {
+ sql("""
+ |CREATE TABLE t (id int, value int, dt STRING)
+ |using paimon
+ |PARTITIONED BY (dt)
+ |""".stripMargin)
+
+ sql("""
+ |INSERT INTO t values
+ |(1, 100, '1'), (2, 111, '2')
+ |""".stripMargin)
+
+ val q = "SELECT * FROM t WHERE dt = 1"
+ val table = loadTable("t")
+ val partitionRowType = table.rowType().project(table.partitionKeys())
+ val partitionFilters =
+ extractSupportedPartitionFilters(extractCatalystFilters(q), partitionRowType)
+
+ val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+ assertThat[String](getFilteredPartitions(table, partitionRowType, partitionPredicate))
+ .containsExactlyInAnyOrder("+I[1]")
+ }
+ }
+
+ test("SparkCatalystPartitionPredicate: null partition") {
+ withTable("t") {
+ sql("""
+ |CREATE TABLE t (id INT, value INT, region STRING, city INT)
+ |PARTITIONED BY (region, city)
+ |""".stripMargin)
+
+ sql("INSERT INTO t values (1, 100, 'north', null)")
+
+ val table = loadTable("t")
+ val partitionRowType = table.rowType().project(table.partitionKeys())
+
+ val q =
+ """
+ |SELECT * FROM t
+ |WHERE CONCAT_WS('-', region, city) = 'north'
+ |""".stripMargin
+ checkAnswer(sql(q), Seq(Row(1, 100, "north", null)))
+
+ val partitionFilters =
+ extractSupportedPartitionFilters(extractCatalystFilters(q), partitionRowType)
+ val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
+ assertThat[String](getFilteredPartitions(table, partitionRowType, partitionPredicate))
+ .containsExactlyInAnyOrder("+I[north, NULL]")
+
+ val q2 =
+ """
+ |SELECT * FROM t
+ |WHERE CONCAT_WS('-', region, city) != 'north'
+ |""".stripMargin
+ checkAnswer(sql(q2), Seq())
+
+ val partitionFilters2 =
+ extractSupportedPartitionFilters(extractCatalystFilters(q2), partitionRowType)
+ val partitionPredicate2 = SparkCatalystPartitionPredicate(partitionFilters2, partitionRowType)
+ assert(getFilteredPartitions(table, partitionRowType, partitionPredicate2).isEmpty)
+ }
+ }
+
+ def extractCatalystFilters(sqlStr: String): Seq[Expression] = {
+ var filters: Seq[Expression] = Seq.empty
+ // Set ansi to false to make sure filters like `Cast` are not pushed down
+ withSparkSQLConf("spark.sql.ansi.enabled" -> "false") {
+ filters = sql(sqlStr).queryExecution.optimizedPlan
+ .collect { case Filter(condition, _) => condition }
+ .flatMap(splitConjunctivePredicates)
+ }
+ filters
+ }
+
+ def getFilteredPartitions(
+ table: DataTable,
+ partitionRowType: RowType,
+ partitionPredicate: PartitionPredicate): JList[String] = {
+ table
+ .newScan()
+ .withPartitionFilter(partitionPredicate)
+ .listPartitions()
+ .asScala
+ .map(r => internalRowToString(r, partitionRowType))
+ .asJava
+ }
+}
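
For reference, a condensed sketch of the end-to-end flow introduced by this commit, distilled from the basic test above. It assumes the helpers available in the test class of this patch (`loadTable` from `PaimonSparkTestBase`, plus `extractCatalystFilters` and `getFilteredPartitions` defined above) and is illustrative rather than a verbatim excerpt:

    // Project the table's row type down to its partition columns.
    val table = loadTable("t")
    val partitionRowType = table.rowType().project(table.partitionKeys())

    // Collect Catalyst filters from the optimized plan, then keep only the
    // deterministic partition filters without subqueries or Python UDFs.
    val filters = extractCatalystFilters("SELECT * FROM t WHERE year = '2025' AND value > 400")
    val partitionFilters =
      SparkCatalystPartitionPredicate.extractSupportedPartitionFilters(filters, partitionRowType)

    // Build the predicate (the remaining filters are AND-ed together) and hand it
    // to the scan, which uses it to prune partitions before reading data files.
    if (partitionFilters.nonEmpty) {
      val partitionPredicate = SparkCatalystPartitionPredicate(partitionFilters, partitionRowType)
      val prunedPartitions = table.newScan().withPartitionFilter(partitionPredicate).listPartitions()
    }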