This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 597bbbb  [SPARK-30991] Refactor AQE readers and RDDs
597bbbb is described below

commit 597bbbbb5507448980e4fadbad85ffb104808081
Author: maryannxue <maryann...@apache.org>
AuthorDate: Mon Mar 2 16:04:00 2020 -0800

    [SPARK-30991] Refactor AQE readers and RDDs
    
    ### What changes were proposed in this pull request?
    This PR combines `CustomShuffledRowRDD` and `LocalShuffledRowRDD` into `ShuffledRowRDD`,
    and creates `CustomShuffleReaderExec` to unify and replace all existing AQE readers:
    `CoalescedShuffleReaderExec`, `LocalShuffleReaderExec` and `SkewJoinShuffleReaderExec`.
    
    ### Why are the changes needed?
    To reduce code redundancy.
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Passed existing UTs.
    
    Closes #27742 from maryannxue/aqe-readers.
    
    Authored-by: maryannxue <maryann...@apache.org>
    Signed-off-by: gatorsmile <gatorsm...@gmail.com>
    (cherry picked from commit 473a28c1d032993c7fa515b39f2cb1e3105d65d3)
    Signed-off-by: gatorsmile <gatorsm...@gmail.com>
---
 .../spark/sql/execution/ShuffledRowRDD.scala       | 142 ++++++++++++---------
 .../apache/spark/sql/execution/SparkPlanInfo.scala |   2 +-
 .../adaptive/CustomShuffleReaderExec.scala         |  81 ++++++++++++
 .../execution/adaptive/CustomShuffledRowRDD.scala  | 113 ----------------
 .../execution/adaptive/LocalShuffledRowRDD.scala   | 112 ----------------
 .../adaptive/OptimizeLocalShuffleReader.scala      |  88 +++----------
 .../execution/adaptive/OptimizeSkewedJoin.scala    |  72 ++---------
 .../adaptive/ReduceNumShufflePartitions.scala      |  49 ++-----
 .../adaptive/ShufflePartitionsCoalescer.scala      |  23 ++--
 .../execution/exchange/ShuffleExchangeExec.scala   |  12 +-
 .../ReduceNumShufflePartitionsSuite.scala          |  28 ++--
 .../ShufflePartitionsCoalescerSuite.scala          | 101 ++++++++++-----
 .../adaptive/AdaptiveQueryExecSuite.scala          |  23 ++--
 13 files changed, 317 insertions(+), 529 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 4c19f95..eb02259 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -26,17 +26,28 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 
+sealed trait ShufflePartitionSpec
+
+// A partition that reads data of one or more reducers, from `startReducerIndex` (inclusive) to
+// `endReducerIndex` (exclusive).
+case class CoalescedPartitionSpec(
+  startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec
+
+// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to
+// `endMapIndex` (exclusive).
+case class PartialReducerPartitionSpec(
+  reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends ShufflePartitionSpec
+
+// A partition that reads partial data of one mapper, from `startReducerIndex` (inclusive) to
+// `endReducerIndex` (exclusive).
+case class PartialMapperPartitionSpec(
+  mapIndex: Int, startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec
+
 /**
- * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition
- * (identified by `postShufflePartitionIndex`) contains a range of pre-shuffle partitions
- * (`startPreShufflePartitionIndex` to `endPreShufflePartitionIndex - 1`, inclusive).
+ * The [[Partition]] used by [[ShuffledRowRDD]].
  */
-private final class ShuffledRowRDDPartition(
-    val postShufflePartitionIndex: Int,
-    val startPreShufflePartitionIndex: Int,
-    val endPreShufflePartitionIndex: Int) extends Partition {
-  override val index: Int = postShufflePartitionIndex
-}
+private final case class ShuffledRowRDDPartition(
+  index: Int, spec: ShufflePartitionSpec) extends Partition
 
 /**
  * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
@@ -94,8 +105,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
  * interfaces / internals.
  *
  * This RDD takes a [[ShuffleDependency]] (`dependency`),
- * and an optional array of partition start indices as input arguments
- * (`specifiedPartitionStartIndices`).
+ * and an array of [[ShufflePartitionSpec]] as input arguments.
  *
  * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle
  * (i.e. map output). Elements of this RDD are (partitionId, Row) pairs.
@@ -103,79 +113,97 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
  * `dependency.partitioner` is the original partitioner used to partition
  * map output, and `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions
  * (i.e. the number of partitions of the map output).
- *
- * When `specifiedPartitionStartIndices` is defined, `specifiedPartitionStartIndices.length`
- * will be the number of post-shuffle partitions. For this case, the `i`th post-shuffle
- * partition includes `specifiedPartitionStartIndices[i]` to
- * `specifiedPartitionStartIndices[i+1] - 1` (inclusive).
- *
- * When `specifiedPartitionStartIndices` is not defined, there will be
- * `dependency.partitioner.numPartitions` post-shuffle partitions. For this case,
- * a post-shuffle partition is created for every pre-shuffle partition.
  */
 class ShuffledRowRDD(
     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
     metrics: Map[String, SQLMetric],
-    specifiedPartitionStartIndices: Option[Array[Int]] = None)
+    partitionSpecs: Array[ShufflePartitionSpec])
   extends RDD[InternalRow](dependency.rdd.context, Nil) {
 
+  def this(
+      dependency: ShuffleDependency[Int, InternalRow, InternalRow],
+      metrics: Map[String, SQLMetric]) = {
+    this(dependency, metrics,
+      Array.tabulate(dependency.partitioner.numPartitions)(i => CoalescedPartitionSpec(i, i + 1)))
+  }
+
   if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) {
     dependency.rdd.context.setLocalProperty(
       SortShuffleManager.FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY, "true")
   }
 
-  private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions
-
-  private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match {
-    case Some(indices) => indices
-    case None =>
-      // When specifiedPartitionStartIndices is not defined, every post-shuffle partition
-      // corresponds to a pre-shuffle partition.
-      (0 until numPreShufflePartitions).toArray
-  }
-
-  private[this] val part: Partitioner =
-    new CoalescedPartitioner(dependency.partitioner, partitionStartIndices)
-
   override def getDependencies: Seq[Dependency[_]] = List(dependency)
 
-  override val partitioner: Option[Partitioner] = Some(part)
+  override val partitioner: Option[Partitioner] =
+    if (partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec])) {
+      val indices = partitionSpecs.map(_.asInstanceOf[CoalescedPartitionSpec].startReducerIndex)
+      // TODO this check is based on assumptions of callers' behavior but is sufficient for now.
+      if (indices.toSet.size == partitionSpecs.length) {
+        Some(new CoalescedPartitioner(dependency.partitioner, indices))
+      } else {
+        None
+      }
+    } else {
+      None
+    }
 
   override def getPartitions: Array[Partition] = {
-    assert(partitionStartIndices.length == part.numPartitions)
-    Array.tabulate[Partition](partitionStartIndices.length) { i =>
-      val startIndex = partitionStartIndices(i)
-      val endIndex =
-        if (i < partitionStartIndices.length - 1) {
-          partitionStartIndices(i + 1)
-        } else {
-          numPreShufflePartitions
-        }
-      new ShuffledRowRDDPartition(i, startIndex, endIndex)
+    Array.tabulate[Partition](partitionSpecs.length) { i =>
+      ShuffledRowRDDPartition(i, partitionSpecs(i))
     }
   }
 
   override def getPreferredLocations(partition: Partition): Seq[String] = {
     val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    val dep = dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-    tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    partition.asInstanceOf[ShuffledRowRDDPartition].spec match {
+      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+        // TODO order by partition size.
+        startReducerIndex.until(endReducerIndex).flatMap { reducerIndex =>
+          tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
+        }
+
+      case PartialReducerPartitionSpec(_, startMapIndex, endMapIndex) =>
+        tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
+
+      case PartialMapperPartitionSpec(mapIndex, _, _) =>
+        tracker.getMapLocation(dependency, mapIndex, mapIndex + 1)
+    }
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-    val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
     val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
     // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
     // as well as the `tempMetrics` for basic shuffle metrics.
     val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
-    // The range of pre-shuffle partitions that we are fetching at here is
-    // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1].
-    val reader =
-      SparkEnv.get.shuffleManager.getReader(
-        dependency.shuffleHandle,
-        shuffledRowPartition.startPreShufflePartitionIndex,
-        shuffledRowPartition.endPreShufflePartitionIndex,
-        context,
-        sqlMetricsReporter)
+    val reader = split.asInstanceOf[ShuffledRowRDDPartition].spec match {
+      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+        SparkEnv.get.shuffleManager.getReader(
+          dependency.shuffleHandle,
+          startReducerIndex,
+          endReducerIndex,
+          context,
+          sqlMetricsReporter)
+
+      case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
+        SparkEnv.get.shuffleManager.getReaderForRange(
+          dependency.shuffleHandle,
+          startMapIndex,
+          endMapIndex,
+          reducerIndex,
+          reducerIndex + 1,
+          context,
+          sqlMetricsReporter)
+
+      case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
+        SparkEnv.get.shuffleManager.getReaderForRange(
+          dependency.shuffleHandle,
+          mapIndex,
+          mapIndex + 1,
+          startReducerIndex,
+          endReducerIndex,
+          context,
+          sqlMetricsReporter)
+    }
     reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 5b72ec0..357820a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
LocalShuffleReaderExec, QueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
QueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
new file mode 100644
index 0000000..be372bb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
+
+
+/**
+ * A wrapper of shuffle query stage, which follows the given partition arrangement.
+ *
+ * @param child           It is usually `ShuffleQueryStageExec`, but can be the shuffle exchange
+ *                        node during canonicalization.
+ * @param partitionSpecs  The partition specs that defines the arrangement.
+ * @param description     The string description of this shuffle reader.
+ */
+case class CustomShuffleReaderExec private(
+    child: SparkPlan,
+    partitionSpecs: Array[ShufflePartitionSpec],
+    description: String) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] = child.output
+  override lazy val outputPartitioning: Partitioning = {
+    // If it is a local shuffle reader with one mapper per task, then the output partitioning is
+    // the same as the plan before shuffle.
+    // TODO this check is based on assumptions of callers' behavior but is sufficient for now.
+    if (partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]) &&
+        partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
+          partitionSpecs.length) {
+      child match {
+        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
+          s.child.outputPartitioning
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
+          s.child.outputPartitioning match {
+            case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
+            case other => other
+          }
+        case _ =>
+          throw new IllegalStateException("operating on canonicalization plan")
+      }
+    } else {
+      UnknownPartitioning(partitionSpecs.length)
+    }
+  }
+
+  override def stringArgs: Iterator[Any] = Iterator(description)
+
+  private var cachedShuffleRDD: RDD[InternalRow] = null
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (cachedShuffleRDD == null) {
+      cachedShuffleRDD = child match {
+        case stage: ShuffleQueryStageExec =>
+          new ShuffledRowRDD(
+            stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs)
+        case _ =>
+          throw new IllegalStateException("operating on canonicalization plan")
+      }
+    }
+    cachedShuffleRDD
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala
deleted file mode 100644
index 5aba574..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, 
ShuffleDependency, SparkEnv, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.metric.{SQLMetric, 
SQLShuffleReadMetricsReporter}
-
-sealed trait ShufflePartitionSpec
-
-// A partition that reads data of one reducer.
-case class SinglePartitionSpec(reducerIndex: Int) extends ShufflePartitionSpec
-
-// A partition that reads data of multiple reducers, from `startReducerIndex` 
(inclusive) to
-// `endReducerIndex` (exclusive).
-case class CoalescedPartitionSpec(
-    startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec
-
-// A partition that reads partial data of one reducer, from `startMapIndex` 
(inclusive) to
-// `endMapIndex` (exclusive).
-case class PartialPartitionSpec(
-    reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends 
ShufflePartitionSpec
-
-private final case class CustomShufflePartition(
-    index: Int, spec: ShufflePartitionSpec) extends Partition
-
-// TODO: merge this with `ShuffledRowRDD`, and replace `LocalShuffledRowRDD` 
with this RDD.
-class CustomShuffledRowRDD(
-    var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
-    metrics: Map[String, SQLMetric],
-    partitionSpecs: Array[ShufflePartitionSpec])
-  extends RDD[InternalRow](dependency.rdd.context, Nil) {
-
-  override def getDependencies: Seq[Dependency[_]] = List(dependency)
-
-  override def clearDependencies() {
-    super.clearDependencies()
-    dependency = null
-  }
-
-  override def getPartitions: Array[Partition] = {
-    Array.tabulate[Partition](partitionSpecs.length) { i =>
-      CustomShufflePartition(i, partitionSpecs(i))
-    }
-  }
-
-  override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val tracker = 
SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    partition.asInstanceOf[CustomShufflePartition].spec match {
-      case SinglePartitionSpec(reducerIndex) =>
-        tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
-
-      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
-        startReducerIndex.until(endReducerIndex).flatMap { reducerIndex =>
-          tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
-        }
-
-      case PartialPartitionSpec(_, startMapIndex, endMapIndex) =>
-        tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
-    }
-  }
-
-  override def compute(split: Partition, context: TaskContext): 
Iterator[InternalRow] = {
-    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
-    // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL 
exchange operator,
-    // as well as the `tempMetrics` for basic shuffle metrics.
-    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, 
metrics)
-    val reader = split.asInstanceOf[CustomShufflePartition].spec match {
-      case SinglePartitionSpec(reducerIndex) =>
-        SparkEnv.get.shuffleManager.getReader(
-          dependency.shuffleHandle,
-          reducerIndex,
-          reducerIndex + 1,
-          context,
-          sqlMetricsReporter)
-
-      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
-        SparkEnv.get.shuffleManager.getReader(
-          dependency.shuffleHandle,
-          startReducerIndex,
-          endReducerIndex,
-          context,
-          sqlMetricsReporter)
-
-      case PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
-        SparkEnv.get.shuffleManager.getReaderForRange(
-          dependency.shuffleHandle,
-          startMapIndex,
-          endMapIndex,
-          reducerIndex,
-          reducerIndex + 1,
-          context,
-          sqlMetricsReporter)
-    }
-    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
deleted file mode 100644
index 19b78f5..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.metric.{SQLMetric, 
SQLShuffleReadMetricsReporter}
-
-/**
- * The [[Partition]] used by [[LocalShuffledRowRDD]].
- * @param mapIndex the index of mapper.
- * @param startPartition the start partition ID in mapIndex mapper.
- * @param endPartition the end partition ID in mapIndex mapper.
- */
-private final class LocalShuffledRowRDDPartition(
-    override val index: Int,
-    val mapIndex: Int,
-    val startPartition: Int,
-    val endPartition: Int) extends Partition {
-}
-
-/**
- * This is a specialized version of 
[[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used
- * in Spark SQL adaptive execution when a shuffle join is converted to 
broadcast join at runtime
- * because the map output of one input table is small enough for broadcast. 
This RDD represents the
- * data of another input table of the join that reads from shuffle. Each 
partition of the RDD reads
- * the whole data from just one mapper output locally. So actually there is no 
data transferred
- * from the network.
- *
- * This RDD takes a [[ShuffleDependency]] (`dependency`).
- *
- * The `dependency` has the parent RDD of this RDD, which represents the 
dataset before shuffle
- * (i.e. map output). Elements of this RDD are (partitionId, Row) pairs.
- * Partition ids should be in the range [0, numPartitions - 1].
- * `dependency.partitioner.numPartitions` is the number of pre-shuffle 
partitions. (i.e. the number
- * of partitions of the map output). The post-shuffle partition number is the 
same to the parent
- * RDD's partition number.
- *
- * `partitionStartIndicesPerMapper` specifies how to split the shuffle blocks 
of each mapper into
- * one or more partitions. For a mapper `i`, the `j`th partition includes 
shuffle blocks from
- * `partitionStartIndicesPerMapper[i][j]` to 
`partitionStartIndicesPerMapper[i][j+1]` (exclusive).
- */
-class LocalShuffledRowRDD(
-     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
-     metrics: Map[String, SQLMetric],
-     partitionStartIndicesPerMapper: Array[Array[Int]])
-  extends RDD[InternalRow](dependency.rdd.context, Nil) {
-
-  private[this] val numReducers = dependency.partitioner.numPartitions
-  private[this] val numMappers = dependency.rdd.partitions.length
-
-  override def getDependencies: Seq[Dependency[_]] = List(dependency)
-
-  override def getPartitions: Array[Partition] = {
-    val partitions = ArrayBuffer[LocalShuffledRowRDDPartition]()
-    for (mapIndex <- 0 until numMappers) {
-      (partitionStartIndicesPerMapper(mapIndex) :+ numReducers).sliding(2, 
1).foreach {
-        case Array(start, end) =>
-          partitions += new LocalShuffledRowRDDPartition(partitions.length, 
mapIndex, start, end)
-      }
-    }
-    partitions.toArray
-  }
-
-  override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val tracker = 
SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    tracker.getMapLocation(dependency, partition.index, partition.index + 1)
-  }
-
-  override def compute(split: Partition, context: TaskContext): 
Iterator[InternalRow] = {
-    val localRowPartition = split.asInstanceOf[LocalShuffledRowRDDPartition]
-    val mapIndex = localRowPartition.mapIndex
-    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
-    // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL 
exchange operator,
-    // as well as the `tempMetrics` for basic shuffle metrics.
-    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, 
metrics)
-
-    val reader = SparkEnv.get.shuffleManager.getReaderForRange(
-      dependency.shuffleHandle,
-      mapIndex,
-      mapIndex + 1,
-      localRowPartition.startPartition,
-      localRowPartition.endPartition,
-      context,
-      sqlMetricsReporter)
-    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
-  }
-
-  override def clearDependencies() {
-    super.clearDependencies()
-    dependency = null
-  }
-}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index e95441e..e441763 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -17,13 +17,9 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, 
BuildRight, BuildSide}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -67,28 +63,35 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
-  private def createLocalReader(plan: SparkPlan): LocalShuffleReaderExec = {
+  private def createLocalReader(plan: SparkPlan): CustomShuffleReaderExec = {
     plan match {
-      case c @ CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) =>
-        LocalShuffleReaderExec(
-          s, getPartitionStartIndices(s, Some(c.partitionStartIndices.length)))
+      case c @ CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
+        CustomShuffleReaderExec(
+          s, getPartitionSpecs(s, Some(c.partitionSpecs.length)), 
LOCAL_SHUFFLE_READER_DESCRIPTION)
       case s: ShuffleQueryStageExec =>
-        LocalShuffleReaderExec(s, getPartitionStartIndices(s, None))
+        CustomShuffleReaderExec(s, getPartitionSpecs(s, None), 
LOCAL_SHUFFLE_READER_DESCRIPTION)
     }
   }
 
   // TODO: this method assumes all shuffle blocks are the same data size. We 
should calculate the
   //       partition start indices based on block size to avoid data skew.
-  private def getPartitionStartIndices(
+  private def getPartitionSpecs(
       shuffleStage: ShuffleQueryStageExec,
-      advisoryParallelism: Option[Int]): Array[Array[Int]] = {
+      advisoryParallelism: Option[Int]): Array[ShufflePartitionSpec] = {
     val shuffleDep = shuffleStage.shuffle.shuffleDependency
     val numReducers = shuffleDep.partitioner.numPartitions
     val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
     val numMappers = shuffleDep.rdd.getNumPartitions
-    Array.fill(numMappers) {
+    val splitPoints = if (numMappers == 0) {
+      Array.empty
+    } else {
       equallyDivide(numReducers, math.max(1, expectedParallelism / 
numMappers)).toArray
     }
+    (0 until numMappers).flatMap { mapIndex =>
+      (splitPoints :+ numReducers).sliding(2).map {
+        case Array(start, end) => PartialMapperPartitionSpec(mapIndex, start, 
end)
+      }
+    }.toArray
   }
 
   /**
@@ -120,6 +123,8 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
 
 object OptimizeLocalShuffleReader {
 
+  val LOCAL_SHUFFLE_READER_DESCRIPTION: String = "local"
+
   object BroadcastJoinWithShuffleLeft {
     def unapply(plan: SparkPlan): Option[(SparkPlan, BuildSide)] = plan match {
       case join: BroadcastHashJoinExec if canUseLocalShuffleReader(join.left) 
=>
@@ -138,60 +143,7 @@ object OptimizeLocalShuffleReader {
 
   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec => s.shuffle.canChangeNumPartitions
-    case CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) => 
s.shuffle.canChangeNumPartitions
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) => 
s.shuffle.canChangeNumPartitions
     case _ => false
   }
 }
-
-/**
- * A wrapper of shuffle query stage, which submits one or more reduce tasks 
per mapper to read the
- * shuffle files written by one mapper. By doing this, it's very likely to 
read the shuffle files
- * locally, as the shuffle files that a reduce task needs to read are in one 
node.
- *
- * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle 
exchange node during
- *              canonicalization.
- * @param partitionStartIndicesPerMapper A mapper usually writes many shuffle 
blocks, and it's
- *                                       better to launch multiple tasks to 
read shuffle blocks of
- *                                       one mapper. This array contains the 
partition start
- *                                       indices for each mapper.
- */
-case class LocalShuffleReaderExec(
-    child: SparkPlan,
-    partitionStartIndicesPerMapper: Array[Array[Int]]) extends UnaryExecNode {
-
-  override def output: Seq[Attribute] = child.output
-
-  override lazy val outputPartitioning: Partitioning = {
-    // when we read one mapper per task, then the output partitioning is the 
same as the plan
-    // before shuffle.
-    if (partitionStartIndicesPerMapper.forall(_.length == 1)) {
-      child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
-          s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: 
ShuffleExchangeExec)) =>
-          s.child.outputPartitioning match {
-            case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
-            case other => other
-          }
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
-    } else {
-      UnknownPartitioning(partitionStartIndicesPerMapper.map(_.length).sum)
-    }
-  }
-
-  private var cachedShuffleRDD: RDD[InternalRow] = null
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          stage.shuffle.createLocalShuffleRDD(partitionStartIndicesPerMapper)
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
-    }
-    cachedShuffleRDD
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index dd238fe..979fee1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -23,11 +23,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ShuffleExchangeExec}
@@ -221,7 +217,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
               getMapStartIndices(left, partitionIndex, leftTargetSize),
               getNumMappers(left))
           } else {
-            Seq(SinglePartitionSpec(partitionIndex))
+            Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
           }
 
           val rightParts = if (isRightSkew) {
@@ -231,7 +227,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
               getMapStartIndices(right, partitionIndex, rightTargetSize),
               getNumMappers(right))
           } else {
-            Seq(SinglePartitionSpec(partitionIndex))
+            Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1))
           }
 
           for {
@@ -259,9 +255,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
       logDebug("number of skewed partitions: " +
         s"left ${leftSkewDesc.numPartitions}, right 
${rightSkewDesc.numPartitions}")
       if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) {
-        val newLeft = SkewJoinShuffleReaderExec(
+        val newLeft = CustomShuffleReaderExec(
           left, leftSidePartitions.toArray, leftSkewDesc.toString)
-        val newRight = SkewJoinShuffleReaderExec(
+        val newRight = CustomShuffleReaderExec(
           right, rightSidePartitions.toArray, rightSkewDesc.toString)
         smj.copy(
           left = s1.copy(child = newLeft), right = s2.copy(child = newRight), 
isSkewJoin = true)
@@ -277,36 +273,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     assert(nonSkewPartitionIndices.nonEmpty)
     val shouldCoalesce = 
conf.getConf(SQLConf.REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED)
     if (!shouldCoalesce || nonSkewPartitionIndices.length == 1) {
-      Seq(SinglePartitionSpec(nonSkewPartitionIndices.head))
+      nonSkewPartitionIndices.map(i => CoalescedPartitionSpec(i, i + 1))
     } else {
-      val startIndices = ShufflePartitionsCoalescer.coalescePartitions(
+      ShufflePartitionsCoalescer.coalescePartitions(
         Array(leftStats, rightStats),
         firstPartitionIndex = nonSkewPartitionIndices.head,
         // `lastPartitionIndex` is exclusive.
         lastPartitionIndex = nonSkewPartitionIndices.last + 1,
         advisoryTargetSize = conf.targetPostShuffleInputSize)
-      startIndices.indices.map { i =>
-        val startIndex = startIndices(i)
-        val endIndex = if (i == startIndices.length - 1) {
-          // `endIndex` is exclusive.
-          nonSkewPartitionIndices.last + 1
-        } else {
-          startIndices(i + 1)
-        }
-        // Do not create `CoalescedPartitionSpec` if only need to read a singe 
partition.
-        if (startIndex + 1 == endIndex) {
-          SinglePartitionSpec(startIndex)
-        } else {
-          CoalescedPartitionSpec(startIndex, endIndex)
-        }
-      }
     }
   }
 
   private def createSkewPartitions(
       reducerIndex: Int,
       mapStartIndices: Array[Int],
-      numMappers: Int): Seq[PartialPartitionSpec] = {
+      numMappers: Int): Seq[PartialReducerPartitionSpec] = {
     mapStartIndices.indices.map { i =>
       val startMapIndex = mapStartIndices(i)
       val endMapIndex = if (i == mapStartIndices.length - 1) {
@@ -314,7 +295,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
       } else {
         mapStartIndices(i + 1)
       }
-      PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex)
+      PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex)
     }
   }
 
@@ -387,40 +368,3 @@ private class SkewDesc {
     }
   }
 }
-
-/**
- * A wrapper of shuffle query stage, which follows the given partition 
arrangement.
- *
- * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle 
exchange node during
- *              canonicalization.
- * @param partitionSpecs The partition specs that defines the arrangement.
- * @param skewDesc The description of the skewed partitions.
- */
-case class SkewJoinShuffleReaderExec(
-    child: SparkPlan,
-    partitionSpecs: Array[ShufflePartitionSpec],
-    skewDesc: String) extends UnaryExecNode {
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputPartitioning: Partitioning = {
-    UnknownPartitioning(partitionSpecs.length)
-  }
-
-  override def stringArgs: Iterator[Any] = Iterator(skewDesc)
-
-  private var cachedShuffleRDD: RDD[InternalRow] = null
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          new CustomShuffledRowRDD(
-            stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, 
partitionSpecs)
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
-    }
-    cachedShuffleRDD
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index 5bbcb14..767a4b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -18,12 +18,8 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.MapOutputStatistics
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, 
UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -31,12 +27,14 @@ import org.apache.spark.sql.internal.SQLConf
  * avoid many small reduce tasks that hurt performance.
  */
 case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
+  import ReduceNumShufflePartitions._
 
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.reducePostShufflePartitionsEnabled) {
       return plan
     }
-    if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])) {
+    if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])
+        || plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined) {
       // If not all leaf nodes are query stages, it's not safe to reduce the 
number of
       // shuffle partitions, because we may break the assumption that all 
children of a spark plan
       // have same number of output partitions.
@@ -44,8 +42,6 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
     }
 
     def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = 
plan match {
-      case _: LocalShuffleReaderExec => Nil
-      case _: SkewJoinShuffleReaderExec => Nil
       case stage: ShuffleQueryStageExec => Seq(stage)
       case _ => plan.children.flatMap(collectShuffleStages)
     }
@@ -70,7 +66,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       val distinctNumPreShufflePartitions =
         validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
       if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 
1) {
-        val partitionStartIndices = 
ShufflePartitionsCoalescer.coalescePartitions(
+        val partitionSpecs = ShufflePartitionsCoalescer.coalescePartitions(
           validMetrics.toArray,
           firstPartitionIndex = 0,
           lastPartitionIndex = distinctNumPreShufflePartitions.head,
@@ -83,7 +79,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
           // `partitionStartIndices`, so that all the leaf shuffles in a stage 
have the same
           // number of output partitions.
           case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
-            CoalescedShuffleReaderExec(stage, partitionStartIndices)
+            CustomShuffleReaderExec(stage, partitionSpecs, 
COALESCED_SHUFFLE_READER_DESCRIPTION)
         }
       } else {
         plan
@@ -92,35 +88,6 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
   }
 }
 
-/**
- * A wrapper of shuffle query stage, which submits fewer reduce task as one 
reduce task may read
- * multiple shuffle partitions. This can avoid many small reduce tasks that 
hurt performance.
- *
- * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle 
exchange node during
- *              canonicalization.
- * @param partitionStartIndices The start partition indices for the coalesced 
partitions.
- */
-case class CoalescedShuffleReaderExec(
-    child: SparkPlan,
-    partitionStartIndices: Array[Int]) extends UnaryExecNode {
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputPartitioning: Partitioning = {
-    UnknownPartitioning(partitionStartIndices.length)
-  }
-
-  private var cachedShuffleRDD: ShuffledRowRDD = null
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child match {
-        case stage: ShuffleQueryStageExec =>
-          stage.shuffle.createShuffledRDD(Some(partitionStartIndices))
-        case _ =>
-          throw new IllegalStateException("operating on canonicalization plan")
-      }
-    }
-    cachedShuffleRDD
-  }
+object ReduceNumShufflePartitions {
+  val COALESCED_SHUFFLE_READER_DESCRIPTION = "coalesced"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala
index 18f0585..c3b8bf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala
@@ -21,11 +21,12 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.MapOutputStatistics
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.{CoalescedPartitionSpec, 
ShufflePartitionSpec}
 
 object ShufflePartitionsCoalescer extends Logging {
 
   /**
-   * Coalesce the same range of partitions (`firstPartitionIndex`` to 
`lastPartitionIndex`, the
+   * Coalesce the same range of partitions (`firstPartitionIndex` to 
`lastPartitionIndex`, the
    * start is inclusive and the end is exclusive) from multiple shuffles. This 
method assumes that
    * all the shuffles have the same number of partitions, and the partitions 
of same index will be
    * read together by one task.
@@ -46,15 +47,17 @@ object ShufflePartitionsCoalescer extends Logging {
    *  - coalesced partition 2: shuffle partition 2 (size 170 MiB)
    *  - coalesced partition 3: shuffle partition 3 and 4 (size 50 MiB)
    *
-   *  @return An array of partition indices which represents the coalesced 
partitions. For example,
-   *          [0, 2, 3] means 3 coalesced partitions: [0, 2), [2, 3), [3, 
lastPartitionIndex]
+   *  @return An array of [[CoalescedPartitionSpec]]s. For example, if 
partitions [0, 1, 2, 3, 4]
+   *          split at indices [0, 2, 3], the returned partition specs will be:
+   *          CoalescedPartitionSpec(0, 2), CoalescedPartitionSpec(2, 3) and
+   *          CoalescedPartitionSpec(3, 5).
    */
   def coalescePartitions(
       mapOutputStatistics: Array[MapOutputStatistics],
       firstPartitionIndex: Int,
       lastPartitionIndex: Int,
       advisoryTargetSize: Long,
-      minNumPartitions: Int = 1): Array[Int] = {
+      minNumPartitions: Int = 1): Array[ShufflePartitionSpec] = {
     // If `minNumPartitions` is very large, it is possible that we need to use 
a value less than
     // `advisoryTargetSize` as the target size of a coalesced task.
     val totalPostShuffleInputSize = 
mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
@@ -82,8 +85,8 @@ object ShufflePartitionsCoalescer extends Logging {
       "There should be only one distinct value of the number of shuffle 
partitions " +
         "among registered Exchange operators.")
 
-    val splitPoints = ArrayBuffer[Int]()
-    splitPoints += firstPartitionIndex
+    val partitionSpecs = ArrayBuffer[CoalescedPartitionSpec]()
+    var latestSplitPoint = firstPartitionIndex
     var coalescedSize = 0L
     var i = firstPartitionIndex
     while (i < lastPartitionIndex) {
@@ -97,8 +100,9 @@ object ShufflePartitionsCoalescer extends Logging {
 
       // If including the `totalSizeOfCurrentPartition` would exceed the 
target size, then start a
       // new coalesced partition.
-      if (i > firstPartitionIndex && coalescedSize + 
totalSizeOfCurrentPartition > targetSize) {
-        splitPoints += i
+      if (i > latestSplitPoint && coalescedSize + totalSizeOfCurrentPartition 
> targetSize) {
+        partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, i)
+        latestSplitPoint = i
         // reset postShuffleInputSize.
         coalescedSize = totalSizeOfCurrentPartition
       } else {
@@ -106,7 +110,8 @@ object ShufflePartitionsCoalescer extends Logging {
       }
       i += 1
     }
+    partitionSpecs += CoalescedPartitionSpec(latestSplitPoint, 
lastPartitionIndex)
 
-    splitPoints.toArray
+    partitionSpecs.toArray
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 4b08da0..b06742e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Div
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.LocalShuffledRowRDD
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -88,15 +87,6 @@ case class ShuffleExchangeExec(
       writeMetrics)
   }
 
-  def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledRowRDD = {
-    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices)
-  }
-
-  def createLocalShuffleRDD(
-      partitionStartIndicesPerMapper: Array[Array[Int]]): LocalShuffledRowRDD = {
-    new LocalShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndicesPerMapper)
-  }
-
   /**
    * Caches the created ShuffleRowRDD so we can reuse that.
    */
@@ -105,7 +95,7 @@ case class ShuffleExchangeExec(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
     if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = createShuffledRDD(None)
+      cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics)
     }
     cachedShuffleRDD
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 5565a0d..a32b684 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -23,7 +23,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.adaptive._
-import org.apache.spark.sql.execution.adaptive.CoalescedShuffleReaderExec
+import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
+import 
org.apache.spark.sql.execution.adaptive.ReduceNumShufflePartitions.COALESCED_SHUFFLE_READER_DESCRIPTION
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -107,7 +108,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         val finalPlan = agg.queryExecution.executedPlan
           .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
         val shuffleReaders = finalPlan.collect {
-          case reader: CoalescedShuffleReaderExec => reader
+          case r @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => r
         }
         assert(shuffleReaders.length === 1)
         minNumPostShufflePartitions match {
@@ -154,7 +155,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         val finalPlan = join.queryExecution.executedPlan
           .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
         val shuffleReaders = finalPlan.collect {
-          case reader: CoalescedShuffleReaderExec => reader
+          case r @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => r
         }
         assert(shuffleReaders.length === 2)
         minNumPostShufflePartitions match {
@@ -206,7 +207,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         val finalPlan = join.queryExecution.executedPlan
           .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
         val shuffleReaders = finalPlan.collect {
-          case reader: CoalescedShuffleReaderExec => reader
+          case r @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => r
         }
         assert(shuffleReaders.length === 2)
         minNumPostShufflePartitions match {
@@ -258,7 +259,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         val finalPlan = join.queryExecution.executedPlan
           .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
         val shuffleReaders = finalPlan.collect {
-          case reader: CoalescedShuffleReaderExec => reader
+          case r @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => r
         }
         assert(shuffleReaders.length === 2)
         minNumPostShufflePartitions match {
@@ -301,7 +302,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
           val finalPlan = join.queryExecution.executedPlan
             .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
           val shuffleReaders = finalPlan.collect {
-            case reader: CoalescedShuffleReaderExec => reader
+            case r @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => r
           }
           assert(shuffleReaders.length === 0)
         } finally {
@@ -329,7 +330,10 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       assert(finalPlan.collect {
         case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
       }.length == 2)
-      assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p 
}.length == 3)
+      assert(
+        finalPlan.collect {
+          case p @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => p
+        }.length == 3)
 
 
       // test case 2: a query stage has 2 parent stages.
@@ -377,7 +381,10 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         Seq(0, 1, 2).map(i => Row(i)))
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-      assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p 
}.length == 0)
+      assert(
+        finalPlan.collect {
+          case p @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => p
+        }.isEmpty)
     }
     withSparkSession(test, 200, None)
   }
@@ -395,7 +402,10 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
       // As the pre-shuffle partition number are different, we will skip 
reducing
       // the shuffle partition numbers.
-      assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p 
}.length == 0)
+      assert(
+        finalPlan.collect {
+          case p @ CustomShuffleReaderExec(_, _, 
COALESCED_SHUFFLE_READER_DESCRIPTION) => p
+        }.isEmpty)
     }
     withSparkSession(test, 100, None)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala
index fcfde83..0befa06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala
@@ -24,7 +24,7 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
 
   private def checkEstimation(
       bytesByPartitionIdArray: Array[Array[Long]],
-      expectedPartitionStartIndices: Array[Int],
+      expectedPartitionStartIndices: Array[CoalescedPartitionSpec],
       targetSize: Long,
       minNumPartitions: Int = 1): Unit = {
     val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map {
@@ -46,44 +46,54 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
     {
       // All bytes per partition are 0.
       val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(Array(bytesByPartitionId), 
expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, 
targetSize)
     }
 
     {
       // Some bytes per partition are 0 and total size is less than the target 
size.
       // 1 coalesced partition is expected.
       val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
-      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
 
     {
       // 2 coalesced partitions are expected.
       val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0)
-      val expectedPartitionStartIndices = Array[Int](0, 3)
-      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 3), CoalescedPartitionSpec(3, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
 
     {
       // There are a few large shuffle partitions.
       val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 3),
+        CoalescedPartitionSpec(3, 4),
+        CoalescedPartitionSpec(4, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
 
     {
       // All shuffle partitions are larger than the targeted size.
       val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
-      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 3),
+        CoalescedPartitionSpec(3, 4),
+        CoalescedPartitionSpec(4, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
 
     {
       // The last shuffle partition is in a single coalesced partition.
       val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 4)
-      checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 4), CoalescedPartitionSpec(4, 5))
+      checkEstimation(Array(bytesByPartitionId), expectedPartitionSpecs, targetSize)
     }
   }
 
@@ -104,10 +114,10 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // All bytes per partition are 0.
       val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -116,10 +126,10 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // 1 coalesced partition is expected.
       val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20)
-      val expectedPartitionStartIndices = Array[Int](0)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -127,10 +137,13 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // 2 coalesced partition are expected.
       val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 2, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 2),
+        CoalescedPartitionSpec(2, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -138,10 +151,14 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // 4 coalesced partition are expected.
       val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -149,10 +166,14 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // 2 coalesced partition are needed.
       val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -160,10 +181,15 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // There are a few large shuffle partitions.
       val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 3),
+        CoalescedPartitionSpec(3, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
 
@@ -171,10 +197,15 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // All pairs of shuffle partitions are larger than the targeted size.
       val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0)
       val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 2),
+        CoalescedPartitionSpec(2, 3),
+        CoalescedPartitionSpec(3, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize)
     }
   }
@@ -188,10 +219,10 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // the size of data is 0.
       val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0)
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
-      val expectedPartitionStartIndices = Array[Int](0)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize, minNumPartitions)
     }
 
@@ -199,10 +230,10 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // The minimal number of coalesced partitions is enforced.
       val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20)
       val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5)
-      val expectedPartitionStartIndices = Array[Int](0, 3)
+      val expectedPartitionSpecs = Array(CoalescedPartitionSpec(0, 3), CoalescedPartitionSpec(3, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize, minNumPartitions)
     }
 
@@ -210,10 +241,14 @@ class ShufflePartitionsCoalescerSuite extends SparkFunSuite {
       // The number of coalesced partitions is determined by the algorithm.
       val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20)
       val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
-      val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
+      val expectedPartitionSpecs = Array(
+        CoalescedPartitionSpec(0, 1),
+        CoalescedPartitionSpec(1, 3),
+        CoalescedPartitionSpec(3, 4),
+        CoalescedPartitionSpec(4, 5))
       checkEstimation(
         Array(bytesByPartitionId1, bytesByPartitionId2),
-        expectedPartitionStartIndices,
+        expectedPartitionSpecs,
         targetSize, minNumPartitions)
     }
   }
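The expectations above all follow the same greedy packing: per-reducer sizes are summed across the shuffle stages, and adjacent reducer partitions are merged until the running total would exceed the target size, with the target capped so that at least `minNumPartitions` coalesced partitions come out. Below is a simplified sketch of that estimation, for illustration only; it is not the `ShufflePartitionsCoalescer` implementation and assumes the `CoalescedPartitionSpec(startReducerIndex, endReducerIndex)` case class added by this patch.

```scala
// Illustrative sketch only -- not the ShufflePartitionsCoalescer implementation.
// Packs reducer partitions left to right until the accumulated size would exceed
// the target, emitting one CoalescedPartitionSpec per packed range.
def estimateCoalescedSpecs(
    bytesByPartitionIdArray: Array[Array[Long]],
    advisoryTargetSize: Long,
    minNumPartitions: Int = 1): Seq[CoalescedPartitionSpec] = {
  val numReducers = bytesByPartitionIdArray.head.length
  // Total bytes each reducer partition would read, summed over all shuffle stages.
  val sizePerReducer = Array.tabulate(numReducers)(i => bytesByPartitionIdArray.map(_(i)).sum)
  // Cap the target so that at least minNumPartitions coalesced partitions are produced.
  val targetSize = math.min(
    advisoryTargetSize,
    math.ceil(sizePerReducer.sum / minNumPartitions.toDouble).toLong)

  val specs = Seq.newBuilder[CoalescedPartitionSpec]
  var start = 0
  var currentSize = 0L
  (0 until numReducers).foreach { i =>
    if (i > start && currentSize + sizePerReducer(i) > targetSize) {
      specs += CoalescedPartitionSpec(start, i)
      start = i
      currentSize = 0L
    }
    currentSize += sizePerReducer(i)
  }
  specs += CoalescedPartitionSpec(start, numReducers)
  specs.result()
}
```

For example, with a single stage of sizes `Array(10, 0, 90, 20, 0)` and a 100-byte target, the sketch yields `CoalescedPartitionSpec(0, 3)` and `CoalescedPartitionSpec(3, 5)`, the same grouping the "2 coalesced partitions are expected" case above asserts.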
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 64566af..17f6b29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -22,8 +22,9 @@ import java.net.URI
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.OptimizeLocalShuffleReader.LOCAL_SHUFFLE_READER_DESCRIPTION
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.internal.SQLConf
@@ -110,7 +111,7 @@ class AdaptiveQueryExecSuite
     }.length
 
     val numLocalReaders = collect(plan) {
-      case reader: LocalShuffleReaderExec => reader
+      case reader @ CustomShuffleReaderExec(_, _, LOCAL_SHUFFLE_READER_DESCRIPTION) => reader
     }.length
 
     assert(numShuffles === (numLocalReaders + numShufflesWithoutLocalReader))
@@ -142,11 +143,11 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       val localReaders = collect(adaptivePlan) {
-        case reader: LocalShuffleReaderExec => reader
+        case reader @ CustomShuffleReaderExec(_, _, LOCAL_SHUFFLE_READER_DESCRIPTION) => reader
       }
       assert(localReaders.length == 2)
-      val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[LocalShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[LocalShuffledRowRDD]
+      val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
+      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // The pre-shuffle partition size is [0, 0, 0, 72, 0]
       // And the partitionStartIndices is [0, 3, 4], so advisoryParallelism = 3.
       // the final parallelism is
@@ -174,11 +175,11 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       val localReaders = collect(adaptivePlan) {
-        case reader: LocalShuffleReaderExec => reader
+        case reader @ CustomShuffleReaderExec(_, _, LOCAL_SHUFFLE_READER_DESCRIPTION) => reader
       }
       assert(localReaders.length == 2)
-      val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[LocalShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[LocalShuffledRowRDD]
+      val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
+      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
       // and the partitions length is 2 * numMappers = 4
       assert(localShuffleRDD0.getPartitions.length == 4)
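The comments just above spell out the arithmetic the local shuffle reader relies on: each mapper's output is split into `math.max(1, numReducers / numMappers)` slices, so the resulting `ShuffledRowRDD` exposes `numMappers` times that many partitions. Below is a minimal, hypothetical sketch of that calculation (not an API in this patch).

```scala
// Minimal sketch of the partition-count arithmetic the local-reader tests rely on;
// this helper is hypothetical and not part of the patch.
def localReaderPartitionCount(numMappers: Int, numReducers: Int): Int = {
  // Each mapper's output is read as at least one slice.
  val slicesPerMapper = math.max(1, numReducers / numMappers)
  numMappers * slicesPerMapper
}

// With 5 reducer partitions and 2 mappers, as in the test above:
// math.max(1, 5 / 2) = 2 slices per mapper, hence 2 * 2 = 4 partitions.
assert(localReaderPartitionCount(numMappers = 2, numReducers = 5) == 4)
```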
@@ -622,10 +623,10 @@ class AdaptiveQueryExecSuite
          def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = {
           assert(joins.size == 1 && joins.head.isSkewJoin)
           assert(joins.head.left.collect {
-            case r: SkewJoinShuffleReaderExec => r
+            case r: CustomShuffleReaderExec => r
           }.head.partitionSpecs.length == expectedNumPartitions)
           assert(joins.head.right.collect {
-            case r: SkewJoinShuffleReaderExec => r
+            case r: CustomShuffleReaderExec => r
           }.head.partitionSpecs.length == expectedNumPartitions)
         }
 

