This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eb8e99721714 [SPARK-47657][SQL] Implement collation filter push down support per file source
eb8e99721714 is described below

commit eb8e99721714eeac14978f0cb6a2dc35251a5d23
Author: Stefan Kandic <stefan.kan...@databricks.com>
AuthorDate: Mon Apr 8 12:17:38 2024 +0800

    [SPARK-47657][SQL] Implement collation filter push down support per file source
    
    ### What changes were proposed in this pull request?
    
    Previously, in #45262 we completely disabled filter pushdown for any
    expression referencing non-UTF8-binary collated columns. However, this
    should be more fine-grained, so that individual data sources can decide
    to push such filters down when they are able to support them.
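    
    As a rough illustration of how a source opts in (the class name below is
    hypothetical and not part of this change), a file format only needs to
    override the new `supportsCollationPushDown` flag; `DataSourceUtils.shouldPushFilter`
    then keeps filters over collated columns for that format:
    
    ```scala
    // Hypothetical example: a Parquet-based format that opts in to receiving
    // filters on non-UTF8-binary collated columns.
    import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
    
    class CollationAwareParquetFileFormat extends ParquetFileFormat {
      // FileFormat.supportsCollationPushDown defaults to false; returning true
      // makes DataSourceUtils.shouldPushFilter keep such filters for this format.
      override def supportsCollationPushDown: Boolean = true
    }
    ```
    
    On the DSv2 path, a `FileScanBuilder` implementation can opt in the same way
    by overriding the protected `supportsCollationPushDown` member added there.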
    
    ### Why are the changes needed?
    
    To enable collation filter push down for an individual data source.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    With the previously added unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45782 from stefankandic/newPushdownLogic.
    
    Authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/datasources/DataSourceUtils.scala    |  9 ++-
 .../sql/execution/datasources/FileFormat.scala     |  6 ++
 .../execution/datasources/FileSourceStrategy.scala |  3 +-
 .../datasources/PruneFileSourcePartitions.scala    |  4 +-
 .../execution/datasources/v2/FileScanBuilder.scala |  9 ++-
 .../spark/sql/FileBasedDataSourceSuite.scala       | 85 ++++++++++++----------
 6 files changed, 70 insertions(+), 46 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 38567c16fd1f..0db5de724340 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -284,12 +284,15 @@ object DataSourceUtils extends PredicateHelper {
   * Determines whether a filter should be pushed down to the data source or not.
    *
    * @param expression The filter expression to be evaluated.
+   * @param isCollationPushDownSupported Whether the data source supports collation push down.
   * @return A boolean indicating whether the filter should be pushed down or not.
    */
-  def shouldPushFilter(expression: Expression): Boolean = {
-    expression.deterministic && !expression.exists {
+  def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = {
+    if (!expression.deterministic) return false
+
+    isCollationPushDownSupported || !expression.exists {
       case childExpression @ (_: Attribute | _: GetStructField) =>
-        // don't push down filters for types with non-default collation
+        // don't push down filters for types with non-binary sortable collation
        // as it could lead to incorrect results
        SchemaUtils.hasNonBinarySortableCollatedString(childExpression.dataType)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 36c59950fe20..0785b0cbe9e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -223,6 +223,12 @@ trait FileFormat {
    */
   def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
     FileFormat.BASE_METADATA_EXTRACTORS
+
+  /**
+   * Returns whether the file format supports filter push down
+   * for non utf8 binary collated columns.
+   */
+  def supportsCollationPushDown: Boolean = false
 }
 
 object FileFormat {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index e4b66d72eaf8..f2dcbe26104f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -160,7 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       //  - filters that need to be evaluated again after the scan
       val filterSet = ExpressionSet(filters)
 
-      val filtersToPush = filters.filter(f => DataSourceUtils.shouldPushFilter(f))
+      val filtersToPush = filters.filter(f =>
+          DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown))
 
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
         filtersToPush, l.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 408da5dad768..b0431d1df398 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -63,8 +63,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
             _))
         if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f =>
-          !SubqueryExpression.hasSubquery(f) && DataSourceUtils.shouldPushFilter(f)),
+        filters.filter(f => !SubqueryExpression.hasSubquery(f) &&
+          DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)),
         logicalRelation.output)
       val (partitionKeyFilters, _) = DataSourceUtils
         .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 346bff980a96..7cd2779f86f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -70,7 +70,8 @@ abstract class FileScanBuilder(
   }
 
   override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
-    val (filtersToPush, filtersToRemain) = filters.partition(DataSourceUtils.shouldPushFilter)
+    val (filtersToPush, filtersToRemain) = filters.partition(
+      f => DataSourceUtils.shouldPushFilter(f, supportsCollationPushDown))
     val (partitionFilters, dataFilters) =
      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
     this.partitionFilters = partitionFilters
@@ -95,6 +96,12 @@ abstract class FileScanBuilder(
    */
  protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]
 
+  /**
+   * Returns whether the file scan builder supports filter pushdown
+   * for non utf8 binary collated columns.
+   */
+  protected def supportsCollationPushDown: Boolean = false
+
   private def createRequiredNameSet(): Set[String] =
    requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d991bc4094a8..8a092ab69cf1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -1245,45 +1245,52 @@ class FileBasedDataSourceSuite extends QueryTest
 
   test("disable filter pushdown for collated strings") {
     Seq("parquet").foreach { format =>
-      withTempPath { path =>
-        val collation = "'UTF8_BINARY_LCASE'"
-        val df = sql(
-          s"""SELECT
-             |  COLLATE(c, $collation) as c1,
-             |  struct(COLLATE(c, $collation)) as str,
-             |  named_struct('f1', named_struct('f2', COLLATE(c, $collation), 'f3', 1)) as namedstr,
-             |  array(COLLATE(c, $collation)) as arr,
-             |  map(COLLATE(c, $collation), 1) as map1,
-             |  map(1, COLLATE(c, $collation)) as map2
-             |FROM VALUES ('aaa'), ('AAA'), ('bbb')
-             |as data(c)
-             |""".stripMargin)
-
-        df.write.format(format).save(path.getAbsolutePath)
-
-        // filter and expected result
-        val filters = Seq(
-          ("==", Seq(Row("aaa"), Row("AAA"))),
-          ("!=", Seq(Row("bbb"))),
-          ("<", Seq()),
-          ("<=", Seq(Row("aaa"), Row("AAA"))),
-          (">", Seq(Row("bbb"))),
-          (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))
-
-        filters.foreach { filter =>
-          val readback = spark.read
-            .parquet(path.getAbsolutePath)
-            .where(s"c1 ${filter._1} collate('aaa', $collation)")
-            .where(s"str ${filter._1} struct(collate('aaa', $collation))")
-            .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
-            .where(s"arr ${filter._1} array(collate('aaa', $collation))")
-            .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
-            .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
-            .select("c1")
-
-          val explain = readback.queryExecution.explainString(ExplainMode.fromString("extended"))
-          assert(explain.contains("PushedFilters: []"))
-          checkAnswer(readback, filter._2)
+      Seq(format, "").foreach { conf =>
+        withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> conf) {
+          withTempPath { path =>
+            val collation = "'UTF8_BINARY_LCASE'"
+            val df = sql(
+              s"""SELECT
+                 |  COLLATE(c, $collation) as c1,
+                 |  struct(COLLATE(c, $collation)) as str,
+                 |  named_struct('f1', named_struct('f2',
+                 |    COLLATE(c, $collation), 'f3', 1)) as namedstr,
+                 |  array(COLLATE(c, $collation)) as arr,
+                 |  map(COLLATE(c, $collation), 1) as map1,
+                 |  map(1, COLLATE(c, $collation)) as map2
+                 |FROM VALUES ('aaa'), ('AAA'), ('bbb')
+                 |as data(c)
+                 |""".stripMargin)
+
+            df.write.format(format).save(path.getAbsolutePath)
+
+            // filter and expected result
+            val filters = Seq(
+              ("==", Seq(Row("aaa"), Row("AAA"))),
+              ("!=", Seq(Row("bbb"))),
+              ("<", Seq()),
+              ("<=", Seq(Row("aaa"), Row("AAA"))),
+              (">", Seq(Row("bbb"))),
+              (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))
+
+            filters.foreach { filter =>
+              val readback = spark.read
+                .format(format)
+                .load(path.getAbsolutePath)
+                .where(s"c1 ${filter._1} collate('aaa', $collation)")
+                .where(s"str ${filter._1} struct(collate('aaa', $collation))")
+                .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
+                .where(s"arr ${filter._1} array(collate('aaa', $collation))")
+                .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
+                .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
+                .select("c1")
+
+              val explain = readback.queryExecution.explainString(
+                ExplainMode.fromString("extended"))
+              assert(explain.contains("PushedFilters: []"))
+              checkAnswer(readback, filter._2)
+            }
+          }
         }
       }
     }

