This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new eb8e99721714  [SPARK-47657][SQL] Implement collation filter push down support per file source
eb8e99721714 is described below

commit eb8e99721714eeac14978f0cb6a2dc35251a5d23
Author: Stefan Kandic <stefan.kan...@databricks.com>
AuthorDate: Mon Apr 8 12:17:38 2024 +0800

    [SPARK-47657][SQL] Implement collation filter push down support per file source

    ### What changes were proposed in this pull request?

    Previously, in #45262, we completely disabled filter pushdown for any
    expression referencing non-UTF8_BINARY collated columns. However, this
    should be more fine-grained, so that individual data sources can decide
    whether they support pushing down such filters.

    ### Why are the changes needed?

    To enable collation filter push down per individual data source.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    With the previously added unit test, now run against both the v1 and v2
    data source paths.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #45782 from stefankandic/newPushdownLogic.

    Authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/datasources/DataSourceUtils.scala    |  9 ++-
 .../sql/execution/datasources/FileFormat.scala     |  6 ++
 .../execution/datasources/FileSourceStrategy.scala |  3 +-
 .../datasources/PruneFileSourcePartitions.scala    |  4 +-
 .../execution/datasources/v2/FileScanBuilder.scala |  9 ++-
 .../spark/sql/FileBasedDataSourceSuite.scala       | 85 ++++++++++++----------
 6 files changed, 70 insertions(+), 46 deletions(-)
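The core of the change is a new capability flag that each file source can
override. A minimal sketch of the opt-in surface (the surrounding trait
members are elided and the subclass name is invented for illustration):

    // By default a source does not receive pushed-down filters that touch
    // non-UTF8_BINARY collated columns.
    trait FileFormat {
      def supportsCollationPushDown: Boolean = false
    }

    // A hypothetical format that can evaluate collation-aware predicates
    // itself opts in by overriding the flag.
    class CollationAwareFormat extends FileFormat {
      override def supportsCollationPushDown: Boolean = true
    }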
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 38567c16fd1f..0db5de724340 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -284,12 +284,15 @@ object DataSourceUtils extends PredicateHelper {
    * Determines whether a filter should be pushed down to the data source or not.
    *
    * @param expression The filter expression to be evaluated.
+   * @param isCollationPushDownSupported Whether the data source supports collation push down.
    * @return A boolean indicating whether the filter should be pushed down or not.
    */
-  def shouldPushFilter(expression: Expression): Boolean = {
-    expression.deterministic && !expression.exists {
+  def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = {
+    if (!expression.deterministic) return false
+
+    isCollationPushDownSupported || !expression.exists {
       case childExpression @ (_: Attribute | _: GetStructField) =>
-        // don't push down filters for types with non-default collation
+        // don't push down filters for types with non-binary sortable collation
         // as it could lead to incorrect results
         SchemaUtils.hasNonBinarySortableCollatedString(childExpression.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 36c59950fe20..0785b0cbe9e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -223,6 +223,12 @@ trait FileFormat {
    */
   def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
     FileFormat.BASE_METADATA_EXTRACTORS
+
+  /**
+   * Returns whether the file format supports filter push down
+   * for non utf8 binary collated columns.
+   */
+  def supportsCollationPushDown: Boolean = false
 }

 object FileFormat {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index e4b66d72eaf8..f2dcbe26104f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -160,7 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       // - filters that need to be evaluated again after the scan
       val filterSet = ExpressionSet(filters)

-      val filtersToPush = filters.filter(f => DataSourceUtils.shouldPushFilter(f))
+      val filtersToPush = filters.filter(f =>
+        DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown))
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
         filtersToPush, l.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 408da5dad768..b0431d1df398 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -63,8 +63,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
             _))
         if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f =>
-          !SubqueryExpression.hasSubquery(f) && DataSourceUtils.shouldPushFilter(f)),
+        filters.filter(f => !SubqueryExpression.hasSubquery(f) &&
+          DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)),
         logicalRelation.output)
       val (partitionKeyFilters, _) = DataSourceUtils
         .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
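The v1 call sites above all funnel through the same predicate. A
self-contained paraphrase of the new DataSourceUtils.shouldPushFilter logic
(the boolean inputs stand in for the Catalyst expression-tree checks the real
code performs via SchemaUtils.hasNonBinarySortableCollatedString):

    def shouldPush(
        deterministic: Boolean,
        referencesNonBinarySortableCollation: Boolean,
        collationPushDownSupported: Boolean): Boolean = {
      // Non-deterministic filters are never pushed down.
      if (!deterministic) return false
      // Push unless the filter references a non-binary-sortable collated
      // column and the source has not opted in to handling it.
      collationPushDownSupported || !referencesNonBinarySortableCollation
    }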
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 346bff980a96..7cd2779f86f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -70,7 +70,8 @@ abstract class FileScanBuilder(
   }

   override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
-    val (filtersToPush, filtersToRemain) = filters.partition(DataSourceUtils.shouldPushFilter)
+    val (filtersToPush, filtersToRemain) = filters.partition(
+      f => DataSourceUtils.shouldPushFilter(f, supportsCollationPushDown))
     val (partitionFilters, dataFilters) =
       DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
     this.partitionFilters = partitionFilters
@@ -95,6 +96,12 @@ abstract class FileScanBuilder(
    */
   protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]

+  /**
+   * Returns whether the file scan builder supports filter pushdown
+   * for non utf8 binary collated columns.
+   */
+  protected def supportsCollationPushDown: Boolean = false
+
   private def createRequiredNameSet(): Set[String] =
     requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
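The v2 path mirrors the v1 hook with a protected method on the scan builder.
Sketched below with the builder's constructor and other members elided (the
subclass name is invented):

    abstract class FileScanBuilder {
      // Concrete builders override this to opt in, exactly as file formats
      // do on the v1 path.
      protected def supportsCollationPushDown: Boolean = false
    }

    class CollationAwareScanBuilder extends FileScanBuilder {
      override protected def supportsCollationPushDown: Boolean = true
    }

The test below asserts the default (opted-out) behavior on both paths; the
same check can be made interactively, e.g. (path and data illustrative):

    spark.read.parquet("/tmp/collated")
      .where("c1 == collate('aaa', 'UTF8_BINARY_LCASE')")
      .explain("extended")  // expect "PushedFilters: []" for a non-opted-in source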
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d991bc4094a8..8a092ab69cf1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -1245,45 +1245,52 @@ class FileBasedDataSourceSuite extends QueryTest
   test("disable filter pushdown for collated strings") {
     Seq("parquet").foreach { format =>
-      withTempPath { path =>
-        val collation = "'UTF8_BINARY_LCASE'"
-        val df = sql(
-          s"""SELECT
-             |  COLLATE(c, $collation) as c1,
-             |  struct(COLLATE(c, $collation)) as str,
-             |  named_struct('f1', named_struct('f2', COLLATE(c, $collation), 'f3', 1)) as namedstr,
-             |  array(COLLATE(c, $collation)) as arr,
-             |  map(COLLATE(c, $collation), 1) as map1,
-             |  map(1, COLLATE(c, $collation)) as map2
-             |FROM VALUES ('aaa'), ('AAA'), ('bbb')
-             |as data(c)
-             |""".stripMargin)
-
-        df.write.format(format).save(path.getAbsolutePath)
-
-        // filter and expected result
-        val filters = Seq(
-          ("==", Seq(Row("aaa"), Row("AAA"))),
-          ("!=", Seq(Row("bbb"))),
-          ("<", Seq()),
-          ("<=", Seq(Row("aaa"), Row("AAA"))),
-          (">", Seq(Row("bbb"))),
-          (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))
-
-        filters.foreach { filter =>
-          val readback = spark.read
-            .parquet(path.getAbsolutePath)
-            .where(s"c1 ${filter._1} collate('aaa', $collation)")
-            .where(s"str ${filter._1} struct(collate('aaa', $collation))")
-            .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
-            .where(s"arr ${filter._1} array(collate('aaa', $collation))")
-            .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
-            .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
-            .select("c1")
-
-          val explain = readback.queryExecution.explainString(ExplainMode.fromString("extended"))
-          assert(explain.contains("PushedFilters: []"))
-          checkAnswer(readback, filter._2)
+      Seq(format, "").foreach { conf =>
+        withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> conf) {
+          withTempPath { path =>
+            val collation = "'UTF8_BINARY_LCASE'"
+            val df = sql(
+              s"""SELECT
+                 |  COLLATE(c, $collation) as c1,
+                 |  struct(COLLATE(c, $collation)) as str,
+                 |  named_struct('f1', named_struct('f2',
+                 |  COLLATE(c, $collation), 'f3', 1)) as namedstr,
+                 |  array(COLLATE(c, $collation)) as arr,
+                 |  map(COLLATE(c, $collation), 1) as map1,
+                 |  map(1, COLLATE(c, $collation)) as map2
+                 |FROM VALUES ('aaa'), ('AAA'), ('bbb')
+                 |as data(c)
+                 |""".stripMargin)
+
+            df.write.format(format).save(path.getAbsolutePath)
+
+            // filter and expected result
+            val filters = Seq(
+              ("==", Seq(Row("aaa"), Row("AAA"))),
+              ("!=", Seq(Row("bbb"))),
+              ("<", Seq()),
+              ("<=", Seq(Row("aaa"), Row("AAA"))),
+              (">", Seq(Row("bbb"))),
+              (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))
+
+            filters.foreach { filter =>
+              val readback = spark.read
+                .format(format)
+                .load(path.getAbsolutePath)
+                .where(s"c1 ${filter._1} collate('aaa', $collation)")
+                .where(s"str ${filter._1} struct(collate('aaa', $collation))")
+                .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
+                .where(s"arr ${filter._1} array(collate('aaa', $collation))")
+                .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
+                .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
+                .select("c1")
+
+              val explain = readback.queryExecution.explainString(
+                ExplainMode.fromString("extended"))
+              assert(explain.contains("PushedFilters: []"))
+              checkAnswer(readback, filter._2)
+            }
+          }
         }
       }
     }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org