This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new c1f7aa2  [SPARK-33482][SPARK-34756][SQL] Fix FileScan equality check
c1f7aa2 is described below

commit c1f7aa286a64f650f1dc9fc85bde33b683f9dd2e
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Tue Mar 23 17:01:16 2021 +0800

    [SPARK-33482][SPARK-34756][SQL] Fix FileScan equality check
    
    ### What changes were proposed in this pull request?
    
    This bug was introduced by SPARK-30428 in Apache Spark 3.0.0.
    This PR fixes `FileScan.equals()`.
    
    ### Why are the changes needed?
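    - Without this fix `FileScan.equals` doesn't take `fileIndex` and 
`readSchema` into account: the line comparing them lacked a trailing `&&`, so 
Scala's semicolon inference turned it into a separate statement whose result 
was discarded.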
    - Partition filters and data filters added to `FileScan` (in #27112 and 
#27157) caused the canonicalized forms of some `BatchScanExec` nodes not to 
match, which prevented some reuse opportunities.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Before this fix, incorrect reuse of `FileScan` (and therefore of 
`BatchScanExec`) could have happened, causing correctness issues.
    
    ### How was this patch tested?
    Added new unit tests.
    
    Closes #31848 from peter-toth/SPARK-34756-fix-filescan-equality-check.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 93a5d34f84c362110ef7d8853e59ce597faddad9)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/avro/AvroScanSuite.scala  |  30 ++
 .../sql/execution/datasources/v2/FileScan.scala    |  22 +-
 .../scala/org/apache/spark/sql/FileScanSuite.scala | 374 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  24 ++
 4 files changed, 446 insertions(+), 4 deletions(-)

diff --git 
a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala 
b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
new file mode 100644
index 0000000..98a7190
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.spark.sql.FileScanSuiteBase
+import org.apache.spark.sql.v2.avro.AvroScan
+
+class AvroScanSuite extends FileScanSuiteBase {
+  val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+    ("AvroScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, 
f, pf, df),
+      Seq.empty))
+
+  run(scanBuilders)
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 363dd15..ac63725 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -24,8 +24,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, 
ExpressionSet}
 import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, 
Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
@@ -84,11 +85,24 @@ trait FileScan extends Scan
 
   protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", 
"]")
 
+  private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
+    val output = readSchema().toAttributes
+    val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => 
a.name -> a).toMap
+    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> 
a).toMap
+    val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => partitionFilterAttributes.getOrElse(a.name, a)))))
+    val normalizedDataFilters = ExpressionSet(dataFilters.map(
+      QueryPlan.normalizeExpressions(_,
+        output.map(a => dataFiltersAttributes.getOrElse(a.name, a)))))
+    (normalizedPartitionFilters, normalizedDataFilters)
+  }
+
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan =>
-      fileIndex == f.fileIndex && readSchema == f.readSchema
-        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
-        ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
+      fileIndex == f.fileIndex && readSchema == f.readSchema &&
+        normalizedPartitionFilters == f.normalizedPartitionFilters &&
+        normalizedDataFilters == f.normalizedDataFilters
 
     case _ => false
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
new file mode 100644
index 0000000..4e7fe84
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import com.google.common.collect.ImmutableMap
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNull, 
LessThan}
+import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, 
PartitionSpec}
+import org.apache.spark.sql.execution.datasources.v2.FileScan
+import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.datasources.v2.text.TextScan
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+trait FileScanSuiteBase extends SharedSparkSession {
+  private def newPartitioningAwareFileIndex() = {
+    new PartitioningAwareFileIndex(spark, Map.empty, None) {
+      override def partitionSpec(): PartitionSpec = {
+        PartitionSpec.emptySpec
+      }
+
+      override protected def leafFiles: mutable.LinkedHashMap[Path, 
FileStatus] = {
+        mutable.LinkedHashMap.empty
+      }
+
+      override protected def leafDirToChildrenFiles: Map[Path, 
Array[FileStatus]] = {
+        Map.empty
+      }
+
+      override def rootPaths: Seq[Path] = {
+        Seq.empty
+      }
+
+      override def refresh(): Unit = {}
+    }
+  }
+
+  type ScanBuilder = (
+    SparkSession,
+      PartitioningAwareFileIndex,
+      StructType,
+      StructType,
+      StructType,
+      Array[Filter],
+      CaseInsensitiveStringMap,
+      Seq[Expression],
+      Seq[Expression]) => FileScan
+
+  def run(scanBuilders: Seq[(String, ScanBuilder, Seq[String])]): Unit = {
+    val dataSchema = StructType.fromDDL("data INT, partition INT, other INT")
+    val dataSchemaNotEqual = StructType.fromDDL("data INT, partition INT, 
other INT, new INT")
+    val readDataSchema = StructType.fromDDL("data INT")
+    val readDataSchemaNotEqual = StructType.fromDDL("data INT, other INT")
+    val readPartitionSchema = StructType.fromDDL("partition INT")
+    val readPartitionSchemaNotEqual = StructType.fromDDL("partition INT, other 
INT")
+    val pushedFilters =
+      Array[Filter](sources.And(sources.IsNull("data"), 
sources.LessThan("data", 0)))
+    val pushedFiltersNotEqual =
+      Array[Filter](sources.And(sources.IsNull("data"), 
sources.LessThan("data", 1)))
+    val optionsMap = ImmutableMap.of("key", "value")
+    val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap))
+    val optionsNotEqual =
+      new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", 
"value2")))
+    val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+    val partitionFiltersNotEqual = Seq(And(IsNull('data.int), 
LessThan('data.int, 1)))
+    val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
+    val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 
1)))
+
+    scanBuilders.foreach { case (name, scanBuilder, exclusions) =>
+      test(s"SPARK-33482: Test $name equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanEquals = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema.copy(),
+          readDataSchema.copy(),
+          readPartitionSchema.copy(),
+          pushedFilters.clone(),
+          new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap)),
+          Seq(partitionFilters: _*),
+          Seq(dataFilters: _*))
+
+        assert(scan === scanEquals)
+      }
+
+      test(s"SPARK-33482: Test $name fileIndex not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val partitioningAwareFileIndexNotEqual = 
newPartitioningAwareFileIndex()
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndexNotEqual,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      if (!exclusions.contains("dataSchema")) {
+        test(s"SPARK-33482: Test $name dataSchema not equals") {
+          val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+          val scan = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          val scanNotEqual = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchemaNotEqual,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          assert(scan !== scanNotEqual)
+        }
+      }
+
+      test(s"SPARK-33482: Test $name readDataSchema not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchemaNotEqual,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name readPartitionSchema not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchemaNotEqual,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      if (!exclusions.contains("pushedFilters")) {
+        test(s"SPARK-33482: Test $name pushedFilters not equals") {
+          val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+          val scan = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFilters,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          val scanNotEqual = scanBuilder(
+            spark,
+            partitioningAwareFileIndex,
+            dataSchema,
+            readDataSchema,
+            readPartitionSchema,
+            pushedFiltersNotEqual,
+            options,
+            partitionFilters,
+            dataFilters)
+
+          assert(scan !== scanNotEqual)
+        }
+      }
+
+      test(s"SPARK-33482: Test $name options not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          optionsNotEqual,
+          partitionFilters,
+          dataFilters)
+
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name partitionFilters not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFiltersNotEqual,
+          dataFilters)
+        assert(scan !== scanNotEqual)
+      }
+
+      test(s"SPARK-33482: Test $name dataFilters not equals") {
+        val partitioningAwareFileIndex = newPartitioningAwareFileIndex()
+
+        val scan = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFilters)
+
+        val scanNotEqual = scanBuilder(
+          spark,
+          partitioningAwareFileIndex,
+          dataSchema,
+          readDataSchema,
+          readPartitionSchema,
+          pushedFilters,
+          options,
+          partitionFilters,
+          dataFiltersNotEqual)
+        assert(scan !== scanNotEqual)
+      }
+    }
+  }
+}
+
+class FileScanSuite extends FileScanSuiteBase {
+  val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
+    ("ParquetScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) =>
+        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, 
pf, df),
+      Seq.empty),
+    ("OrcScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) =>
+        OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, 
df),
+      Seq.empty),
+    ("CSVScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, 
f, pf, df),
+      Seq.empty),
+    ("JsonScan",
+      (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, 
f, pf, df),
+      Seq.empty),
+    ("TextScan",
+      (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, 
df),
+      Seq("dataSchema", "pushedFilters")))
+
+  run(scanBuilders)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c29eac2..aa673dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -37,6 +37,7 @@ import 
org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupporte
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -3945,6 +3946,29 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-33482: Fix FileScan canonicalization") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark.range(5).toDF().write.mode("overwrite").parquet(path.toString)
+        withTempView("t") {
+          spark.read.parquet(path.toString).createOrReplaceTempView("t")
+          val df = sql(
+            """
+              |SELECT *
+              |FROM t AS t1
+              |JOIN t AS t2 ON t2.id = t1.id
+              |JOIN t AS t3 ON t3.id = t2.id
+              |""".stripMargin)
+          df.collect()
+          val reusedExchanges = collect(df.queryExecution.executedPlan) {
+            case r: ReusedExchangeExec => r
+          }
+          assert(reusedExchanges.size == 1)
+        }
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to