This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c63ba6c4f43 [SPARK-42454][SQL] SPJ: encapsulate all SPJ related parameters in BatchScanExec
c63ba6c4f43 is described below

commit c63ba6c4f43f5872ee6804361ee443c29a739d9d
Author: Szehon Ho <szehon.apa...@gmail.com>
AuthorDate: Fri Jul 14 21:45:37 2023 -0700

    [SPARK-42454][SQL] SPJ: encapsulate all SPJ related parameters in BatchScanExec
    
    ### What changes were proposed in this pull request?
    Pull the SPJ-related attributes of BatchScanExec out into a case class.
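    
    A condensed sketch of the resulting shape (class bodies and imports omitted; the full diff is below). The SPJ-specific constructor parameters of BatchScanExec are collapsed into a single `spjParams` field:
    
    ```scala
    // New holder for the SPJ knobs that used to be individual BatchScanExec parameters.
    case class StoragePartitionJoinParams(
        keyGroupedPartitioning: Option[Seq[Expression]] = None,
        commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
        applyPartialClustering: Boolean = false,
        replicatePartitions: Boolean = false)
    
    // BatchScanExec now carries them as one field with a default value.
    case class BatchScanExec(
        output: Seq[AttributeReference],
        @transient scan: Scan,
        runtimeFilters: Seq[Expression],
        ordering: Option[Seq[SortOrder]] = None,
        @transient table: Table,
        spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
      ) extends DataSourceV2ScanExecBase
    ```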
    
    ### Why are the changes needed?
    We plan to evolve the SPJ parameters further to support more SPJ features, so we want to stabilize the definition of BatchScanExec and avoid touching the many places in the code where it is pattern-matched/unapplied.
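    
    For illustration, the stability this buys (both fragments are taken from this diff): call sites that only destructure the scan keep a fixed six-field pattern, and SPJ-aware code such as EnsureRequirements only touches the encapsulated params, so a future SPJ parameter would extend StoragePartitionJoinParams without changing either site again.
    
    ```scala
    // Test code keeps a fixed-arity extractor:
    case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
    
    // EnsureRequirements populates the SPJ parameters through the nested case class:
    scan.copy(
      spjParams = scan.spjParams.copy(
        commonPartitionValues = Some(values),
        applyPartialClustering = applyPartialClustering,
        replicatePartitions = replicatePartitions))
    ```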
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit test to verify no behavior change.
    
    Closes #41990 from szehon-ho/spj_refactor.
    
    Authored-by: Szehon Ho <szehon.apa...@gmail.com>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../apache/spark/sql/avro/AvroRowReaderSuite.scala |  2 +-
 .../org/apache/spark/sql/avro/AvroSuite.scala      |  6 +--
 .../execution/datasources/v2/BatchScanExec.scala   | 63 +++++++++++++++-------
 .../datasources/v2/DataSourceV2Strategy.scala      |  3 +-
 .../execution/exchange/EnsureRequirements.scala    |  8 +--
 .../spark/sql/FileBasedDataSourceSuite.scala       |  4 +-
 .../PruneFileSourcePartitionsSuite.scala           |  2 +-
 .../datasources/PrunePartitionSuiteBase.scala      |  2 +-
 .../datasources/orc/OrcV2SchemaPruningSuite.scala  |  2 +-
 9 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 046ff4ef088..cc0e178c617 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -59,7 +59,7 @@ class AvroRowReaderSuite
 
       val df = spark.read.format("avro").load(dir.getCanonicalPath)
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       val filePath = fileScan.get.fileIndex.inputFiles(0)
       val fileSize = new File(new URI(filePath)).length
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index c5e52292caf..35e9f43289c 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2778,7 +2778,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       })
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2812,7 +2812,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       assert(filterCondition.isDefined)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.isEmpty)
@@ -2893,7 +2893,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
             .where("value = 'a'")
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+            case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           if (filtersPushdown) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index d43331d57c4..4b538197392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -37,23 +37,19 @@ case class BatchScanExec(
     output: Seq[AttributeReference],
     @transient scan: Scan,
     runtimeFilters: Seq[Expression],
-    keyGroupedPartitioning: Option[Seq[Expression]] = None,
     ordering: Option[Seq[SortOrder]] = None,
     @transient table: Table,
-    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
-    applyPartialClustering: Boolean = false,
-    replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
+    spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
+  ) extends DataSourceV2ScanExecBase {
 
-  @transient lazy val batch = if (scan == null) null else scan.toBatch
+  @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
 
  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
     case other: BatchScanExec =>
       this.batch != null && this.batch == other.batch &&
           this.runtimeFilters == other.runtimeFilters &&
-          this.commonPartitionValues == other.commonPartitionValues &&
-          this.replicatePartitions == other.replicatePartitions &&
-          this.applyPartialClustering == other.applyPartialClustering
+          this.spjParams == other.spjParams
     case _ =>
       false
   }
@@ -119,11 +115,11 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
         // We allow duplicated partition values if
        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-          Seq.fill(numSplits)(partValue)
+        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
+          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
         }
        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
       case p => p
@@ -148,15 +144,17 @@ case class BatchScanExec(
                  s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
                   "is enabled")
 
-            val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+            val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
+              groupSplits = true).get
 
            // This means the input partitions are not grouped by partition values. We'll need to
            // check `groupByPartitionValues` and decide whether to group and replicate splits
             // within a partition.
-            if (commonPartitionValues.isDefined && applyPartialClustering) {
+            if (spjParams.commonPartitionValues.isDefined &&
+              spjParams.applyPartialClustering) {
              // A mapping from the common partition values to how many splits the partition
              // should contain. Note this no longer maintain the partition key ordering.
-              val commonPartValuesMap = commonPartitionValues
+              val commonPartValuesMap = spjParams.commonPartitionValues
                 .get
                .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
                 .toMap
@@ -168,7 +166,7 @@ case class BatchScanExec(
                  assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
                       "common partition values from Spark plan")
 
-                  val newSplits = if (replicatePartitions) {
+                  val newSplits = if (spjParams.replicatePartitions) {
                    // We need to also replicate partitions according to the other side of join
                     Seq.fill(numSplits.get)(splits)
                   } else {
@@ -184,11 +182,12 @@ case class BatchScanExec(
 
               // Now fill missing partition keys with empty partitions
               val partitionMapping = nestGroupedPartitions.toMap
-              finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-                // Use empty partition for those partition values that are not present.
-                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
-                  Seq.fill(numSplits)(Seq.empty))
+              finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+                case (partValue, numSplits) =>
+                  // Use empty partition for those partition values that are not present.
+                  partitionMapping.getOrElse(
+                    InternalRowComparableWrapper(partValue, p.expressions),
+                    Seq.fill(numSplits)(Seq.empty))
               }
             } else {
                val partitionMapping = groupedPartitions.map { case (row, parts) =>
@@ -222,6 +221,9 @@ case class BatchScanExec(
     rdd
   }
 
+  override def keyGroupedPartitioning: Option[Seq[Expression]] =
+    spjParams.keyGroupedPartitioning
+
   override def doCanonicalize(): BatchScanExec = {
     this.copy(
       output = output.map(QueryPlan.normalizeExpressions(_, output)),
@@ -241,3 +243,24 @@ case class BatchScanExec(
     s"BatchScan ${table.name()}".trim
   }
 }
+
+case class StoragePartitionJoinParams(
+    keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    applyPartialClustering: Boolean = false,
+    replicatePartitions: Boolean = false) {
+  override def equals(other: Any): Boolean = other match {
+    case other: StoragePartitionJoinParams =>
+      this.commonPartitionValues == other.commonPartitionValues &&
+      this.replicatePartitions == other.replicatePartitions &&
+      this.applyPartialClustering == other.applyPartialClustering
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = Objects.hashCode(
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+    applyPartialClustering: java.lang.Boolean,
+    replicatePartitions: java.lang.Boolean)
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 542ac2e6748..abd70f322c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -142,7 +142,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         case _ => false
       }
      val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
-        relation.keyGroupedPartitioning, relation.ordering, relation.relation.table)
+        relation.ordering, relation.relation.table,
+        StoragePartitionJoinParams(relation.keyGroupedPartitioning))
      withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
 
     case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 457a9e0a868..42c880e7c62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -531,9 +531,11 @@ case class EnsureRequirements(
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
-        commonPartitionValues = Some(values),
-        applyPartialClustering = applyPartialClustering,
-        replicatePartitions = replicatePartitions
+        spjParams = scan.spjParams.copy(
+          commonPartitionValues = Some(values),
+          applyPartialClustering = applyPartialClustering,
+          replicatePartitions = replicatePartitions
+        )
       )
     case node =>
       node.mapChildren(child => populatePartitionValues(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d69a68f5726..93275487f29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -1014,7 +1014,7 @@ class FileBasedDataSourceSuite extends QueryTest
           })
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+            case BatchScanExec(_, f: FileScan, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1055,7 +1055,7 @@ class FileBasedDataSourceSuite extends QueryTest
           assert(filterCondition.isDefined)
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+            case BatchScanExec(_, f: FileScan, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           assert(fileScan.get.partitionFilters.isEmpty)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
index 1b30205a418..3a70bfc7f4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
@@ -170,7 +170,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: FileSourceScanExec => p.selectedPartitions.length
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) =>
+      case BatchScanExec(_, scan: FileScan, _, _, _, _) =>
        scan.fileIndex.listFiles(scan.partitionFilters, scan.dataFilters).length
     }.get
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
index 9a61e6517f7..430e9f848e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
@@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
     assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
 
    val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => scan.partitionFilters
+      case BatchScanExec(_, scan: FileScan, _, _, _, _) => scan.partitionFilters
     }
     val pushedDownPartitionFilters = plan.collectFirst(collectFn)
       .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 1fba772f5a8..8d503d64e30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -42,7 +42,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
  override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       collect(df.queryExecution.executedPlan) {
-        case BatchScanExec(_, scan: OrcScan, _, _, _, _, _, _, _) => scan.readDataSchema
+        case BatchScanExec(_, scan: OrcScan, _, _, _, _) => scan.readDataSchema
       }
     assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
       s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +

