This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c79f471  [SPARK-23128][SQL] A new approach to do adaptive execution in Spark SQL
c79f471 is described below

commit c79f471d0475fd98ddeb1e6281551e42684837d2
Author: maryannxue <maryann...@apache.org>
AuthorDate: Sat Jun 15 11:27:15 2019 +0200

    [SPARK-23128][SQL] A new approach to do adaptive execution in Spark SQL
    
    ## What changes were proposed in this pull request?
    
    Implemented a new SparkPlan that executes the query adaptively. It splits the query
    plan into independent stages and executes them in order according to their
    dependencies. Each query stage materializes its output at the end. When one stage
    completes, the data statistics of the materialized output will be used to optimize
    the remainder of the query.
    
    The adaptive mode is off by default. When it is turned on, users can see
    "AdaptiveSparkPlan" as the top node of a query or sub-query. The inner plan of
    "AdaptiveSparkPlan" is subject to change during query execution but becomes final
    once the execution is complete. Whether the inner plan is final is included in the
    EXPLAIN string. Below is an example of the EXPLAIN plan before and after execution:
    
    Query:
    ```
    SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'
    ```
    
    Before execution:
    ```
    == Physical Plan ==
    AdaptiveSparkPlan(isFinalPlan=false)
    +- SortMergeJoin [key#13], [a#23], Inner
       :- Sort [key#13 ASC NULLS FIRST], false, 0
       :  +- Exchange hashpartitioning(key#13, 5)
       :     +- Filter (isnotnull(value#14) AND (value#14 = 1))
       :        +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).value, true, false) AS value#14]
       :           +- Scan[obj#12]
       +- Sort [a#23 ASC NULLS FIRST], false, 0
          +- Exchange hashpartitioning(a#23, 5)
             +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24]
                +- Scan[obj#22]
    ```
    
    After execution:
    ```
    == Physical Plan ==
    AdaptiveSparkPlan(isFinalPlan=true)
    +- *(1) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
       :- BroadcastQueryStage 2
       :  +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)))
       :     +- ShuffleQueryStage 0
       :        +- Exchange hashpartitioning(key#13, 5)
       :           +- *(1) Filter (isnotnull(value#14) AND (value#14 = 1))
       :              +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).value, true, false) AS value#14]
       :                 +- Scan[obj#12]
       +- ShuffleQueryStage 1
          +- Exchange hashpartitioning(a#23, 5)
             +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24]
                +- Scan[obj#22]
    ```
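    
    As a minimal illustration (not part of this patch), the feature can be tried out in
    spark-shell by setting the new config introduced here; the sample data below is made
    up, the exact EXPLAIN output may vary, and the flag name may change in later versions:
    ```
    import org.apache.spark.sql.functions.col

    spark.conf.set("spark.sql.runtime.reoptimization.enabled", "true")

    val df = spark.range(0, 1000).toDF("key")
      .join(spark.range(0, 10).toDF("a"), col("key") === col("a"))

    df.explain()   // top node: AdaptiveSparkPlan(isFinalPlan=false)
    df.collect()
    df.explain()   // the inner plan is now final: AdaptiveSparkPlan(isFinalPlan=true)
    ```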
    
    Credit also goes to carsonwang and cloud-fan
    
    ## How was this patch tested?
    
    Added new UT.
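    
    For illustration only, a hypothetical test in the spirit of the added
    AdaptiveQueryExecSuite (the real suite's helpers and assertions may differ):
    ```
    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") {
      val df = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'")
      df.collect()
      val plan = df.queryExecution.executedPlan
      assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
      // After execution the adaptive plan should have settled on a broadcast join,
      // as in the EXPLAIN output above.
      val finalPlan = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
      assert(finalPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    }
    ```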
    
    Closes #24706 from maryannxue/aqe.
    
    Authored-by: maryannxue <maryann...@apache.org>
    Signed-off-by: herman <her...@databricks.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  11 +-
 .../spark/sql/execution/QueryExecution.scala       |  12 +-
 .../org/apache/spark/sql/execution/SparkPlan.scala |  35 +-
 .../apache/spark/sql/execution/SparkPlanInfo.scala |   3 +
 .../apache/spark/sql/execution/SparkPlanner.scala  |   2 +
 .../spark/sql/execution/SparkStrategies.scala      |   5 +-
 .../sql/execution/WholeStageCodegenExec.scala      |  96 +++--
 .../execution/adaptive/AdaptiveSparkPlanExec.scala | 425 +++++++++++++++++++++
 .../adaptive/InsertAdaptiveSparkPlan.scala         | 131 +++++++
 .../sql/execution/adaptive/LogicalQueryStage.scala |  56 +++
 .../adaptive/LogicalQueryStageStrategy.scala       |  63 +++
 .../adaptive/PlanAdaptiveSubqueries.scala          |  36 ++
 .../sql/execution/adaptive/QueryStageExec.scala    | 210 ++++++++++
 .../execution/exchange/BroadcastExchangeExec.scala |  28 +-
 .../execution/exchange/ExchangeCoordinator.scala   |   2 +-
 .../execution/exchange/ShuffleExchangeExec.scala   |  11 +-
 .../org/apache/spark/sql/execution/subquery.scala  |   3 +
 .../sql/execution/ui/SQLAppStatusListener.scala    |  54 ++-
 .../spark/sql/execution/ui/SQLListener.scala       |   7 +
 .../spark/sql/execution/ui/SparkPlanGraph.scala    |   9 +
 .../adaptive/AdaptiveQueryExecSuite.scala          | 316 +++++++++++++++
 21 files changed, 1430 insertions(+), 85 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fee7172..1eb39f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -292,6 +292,12 @@ object SQLConf {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(64 * 1024 * 1024)
 
+  val RUNTIME_REOPTIMIZATION_ENABLED =
+    buildConf("spark.sql.runtime.reoptimization.enabled")
+      .doc("When true, enable runtime query re-optimization.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled")
     .doc("When true, enable adaptive query execution.")
     .booleanConf
@@ -1889,7 +1895,10 @@ class SQLConf extends Serializable with Logging {
   def targetPostShuffleInputSize: Long =
     getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE)
 
-  def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
+  def runtimeReoptimizationEnabled: Boolean = getConf(RUNTIME_REOPTIMIZATION_ENABLED)
+
+  def adaptiveExecutionEnabled: Boolean =
+    getConf(ADAPTIVE_EXECUTION_ENABLED) && !getConf(RUNTIME_REOPTIMIZATION_ENABLED)
 
   def minNumPostShufflePartitions: Int =
     getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5d2710bd..6f0b489 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -30,6 +30,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, 
StringConcat}
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, 
ReuseExchange}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
@@ -74,9 +75,15 @@ class QueryExecution(
 
   lazy val sparkPlan: SparkPlan = 
tracker.measurePhase(QueryPlanningTracker.PLANNING) {
     SparkSession.setActiveSession(sparkSession)
+    // Runtime re-optimization requires a unique instance of every node in the logical plan.
+    val logicalPlan = if (sparkSession.sessionState.conf.runtimeReoptimizationEnabled) {
+      optimizedPlan.clone()
+    } else {
+      optimizedPlan
+    }
     // TODO: We use next(), i.e. take the first plan returned by the planner, 
here for now,
     //       but we will implement to choose the best plan.
-    planner.plan(ReturnAnswer(optimizedPlan)).next()
+    planner.plan(ReturnAnswer(logicalPlan)).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
@@ -107,6 +114,9 @@ class QueryExecution(
 
   /** A sequence of rules that will be applied in order to the physical plan 
before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be
+    // no-ops as the original plan is hidden behind `AdaptiveSparkPlanExec`.
+    InsertAdaptiveSparkPlan(sparkSession),
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ddcf61b..fbe8e50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -40,9 +40,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
-  // a TreeNode tag in SparkPlan, to carry its original logical plan. The 
planner will add this tag
-  // when converting a logical plan to a physical plan.
+  /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */
   val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
+
+  /** The [[LogicalPlan]] inherited from its ancestor. */
+  val LOGICAL_PLAN_INHERITED_TAG = TreeNodeTag[LogicalPlan]("logical_plan_inherited")
 }
 
 /**
@@ -80,6 +82,35 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with 
Logging with Serializ
   }
 
   /**
+   * @return The logical plan this plan is linked to.
+   */
+  def logicalLink: Option[LogicalPlan] =
+    getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+      .orElse(getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG))
+
+  /**
+   * Set logical plan link recursively if unset.
+   */
+  def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
+    setLogicalLink(logicalPlan, false)
+  }
+
+  private def setLogicalLink(logicalPlan: LogicalPlan, inherited: Boolean = 
false): Unit = {
+    // Stop at a descendant which is the root of a sub-tree transformed from 
another logical node.
+    if (inherited && getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isDefined) {
+      return
+    }
+
+    val tag = if (inherited) {
+      SparkPlan.LOGICAL_PLAN_INHERITED_TAG
+    } else {
+      SparkPlan.LOGICAL_PLAN_TAG
+    }
+    setTagValue(tag, logicalPlan)
+    children.foreach(_.setLogicalLink(logicalPlan, true))
+  }
+
+  /**
    * @return All metrics containing metrics of this SparkPlan.
    */
   def metrics: Map[String, SQLMetric] = Map.empty
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 3cd02b9..8c7752c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
QueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 import org.apache.spark.sql.internal.SQLConf
@@ -53,6 +54,8 @@ private[execution] object SparkPlanInfo {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
       case ReusedSubqueryExec(child) => child :: Nil
+      case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
+      case stage: QueryStageExec => stage.plan :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 2a4a1c8..dc7fb77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, 
FileSourceStrategy}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
 import org.apache.spark.sql.internal.SQLConf
@@ -36,6 +37,7 @@ class SparkPlanner(
   override def strategies: Seq[Strategy] =
     experimentalMethods.extraStrategies ++
       extraPlanningStrategies ++ (
+      LogicalQueryStageStrategy ::
       PythonEvals ::
       DataSourceV2Strategy ::
       FileSourceStrategy ::
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c403149..faf2fdd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.adaptive.LogicalQueryStage
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, 
InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -58,6 +59,8 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
   protected override def doExecute(): RDD[InternalRow] = {
     throw new UnsupportedOperationException()
   }
+
+  override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {}
 }
 
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -69,7 +72,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
+      p.setLogicalLink(logicalPlan)
       p
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 99dcca8..92e80dc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.io.Writer
 import java.util.Locale
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable
 import scala.util.control.NonFatal
@@ -551,56 +552,6 @@ object WholeStageCodegenExec {
   }
 }
 
-object WholeStageCodegenId {
-  // codegenStageId: ID for codegen stages within a query plan.
-  // It does not affect equality, nor does it participate in destructuring 
pattern matching
-  // of WholeStageCodegenExec.
-  //
-  // This ID is used to help differentiate between codegen stages. It is 
included as a part
-  // of the explain output for physical plans, e.g.
-  //
-  // == Physical Plan ==
-  // *(5) SortMergeJoin [x#3L], [y#9L], Inner
-  // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0
-  // :  +- Exchange hashpartitioning(x#3L, 200)
-  // :     +- *(1) Project [(id#0L % 2) AS x#3L]
-  // :        +- *(1) Filter isnotnull((id#0L % 2))
-  // :           +- *(1) Range (0, 5, step=1, splits=8)
-  // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0
-  //    +- Exchange hashpartitioning(y#9L, 200)
-  //       +- *(3) Project [(id#6L % 2) AS y#9L]
-  //          +- *(3) Filter isnotnull((id#6L % 2))
-  //             +- *(3) Range (0, 5, step=1, splits=8)
-  //
-  // where the ID makes it obvious that not all adjacent codegen'd plan 
operators are of the
-  // same codegen stage.
-  //
-  // The codegen stage ID is also optionally included in the name of the 
generated classes as
-  // a suffix, so that it's easier to associate a generated class back to the 
physical operator.
-  // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName
-  //
-  // The ID is also included in various log messages.
-  //
-  // Within a query, a codegen stage in a plan starts counting from 1, in 
"insertion order".
-  // WholeStageCodegenExec operators are inserted into a plan in depth-first 
post-order.
-  // See CollapseCodegenStages.insertWholeStageCodegen for the definition of 
insertion order.
-  //
-  // 0 is reserved as a special ID value to indicate a temporary 
WholeStageCodegenExec object
-  // is created, e.g. for special fallback handling when an existing 
WholeStageCodegenExec
-  // failed to generate/compile code.
-
-  private val codegenStageCounter: ThreadLocal[Integer] = 
ThreadLocal.withInitial(() => 1)
-
-  def resetPerQuery(): Unit = codegenStageCounter.set(1)
-
-  def getNextStageId(): Int = {
-    val counter = codegenStageCounter
-    val id = counter.get()
-    counter.set(id + 1)
-    id
-  }
-}
-
 /**
  * WholeStageCodegen compiles a subtree of plans that support codegen together 
into single Java
  * function.
@@ -824,8 +775,48 @@ case class WholeStageCodegenExec(child: SparkPlan)(val 
codegenStageId: Int)
 
 /**
  * Find the chained plans that support codegen, collapse them together as 
WholeStageCodegen.
+ *
+ * The `codegenStageCounter` generates ID for codegen stages within a query 
plan.
+ * It does not affect equality, nor does it participate in destructuring 
pattern matching
+ * of WholeStageCodegenExec.
+ *
+ * This ID is used to help differentiate between codegen stages. It is 
included as a part
+ * of the explain output for physical plans, e.g.
+ *
+ * == Physical Plan ==
+ * *(5) SortMergeJoin [x#3L], [y#9L], Inner
+ * :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0
+ * :  +- Exchange hashpartitioning(x#3L, 200)
+ * :     +- *(1) Project [(id#0L % 2) AS x#3L]
+ * :        +- *(1) Filter isnotnull((id#0L % 2))
+ * :           +- *(1) Range (0, 5, step=1, splits=8)
+ * +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0
+ *    +- Exchange hashpartitioning(y#9L, 200)
+ *       +- *(3) Project [(id#6L % 2) AS y#9L]
+ *          +- *(3) Filter isnotnull((id#6L % 2))
+ *             +- *(3) Range (0, 5, step=1, splits=8)
+ *
+ * where the ID makes it obvious that not all adjacent codegen'd plan 
operators are of the
+ * same codegen stage.
+ *
+ * The codegen stage ID is also optionally included in the name of the 
generated classes as
+ * a suffix, so that it's easier to associate a generated class back to the 
physical operator.
+ * This is controlled by SQLConf: spark.sql.codegen.useIdInClassName
+ *
+ * The ID is also included in various log messages.
+ *
+ * Within a query, a codegen stage in a plan starts counting from 1, in 
"insertion order".
+ * WholeStageCodegenExec operators are inserted into a plan in depth-first 
post-order.
+ * See CollapseCodegenStages.insertWholeStageCodegen for the definition of 
insertion order.
+ *
+ * 0 is reserved as a special ID value to indicate a temporary 
WholeStageCodegenExec object
+ * is created, e.g. for special fallback handling when an existing 
WholeStageCodegenExec
+ * failed to generate/compile code.
  */
-case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
+case class CollapseCodegenStages(
+    conf: SQLConf,
+    codegenStageCounter: AtomicInteger = new AtomicInteger(0))
+  extends Rule[SparkPlan] {
 
   private def supportCodegen(e: Expression): Boolean = e match {
     case e: LeafExpression => true
@@ -869,14 +860,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends 
Rule[SparkPlan] {
     case plan if plan.output.length == 1 && 
plan.output.head.dataType.isInstanceOf[ObjectType] =>
       plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
     case plan: CodegenSupport if supportCodegen(plan) =>
-      
WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
+      
WholeStageCodegenExec(insertInputAdapter(plan))(codegenStageCounter.incrementAndGet())
     case other =>
       other.withNewChildren(other.children.map(insertWholeStageCodegen))
   }
 
   def apply(plan: SparkPlan): SparkPlan = {
     if (conf.wholeStageEnabled) {
-      WholeStageCodegenId.resetPerQuery()
       insertWholeStageCodegen(plan)
     } else {
       plan
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
new file mode 100644
index 0000000..606fbd8
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import java.util
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.JavaConverters._
+import scala.collection.concurrent.TrieMap
+import scala.collection.mutable
+import scala.concurrent.ExecutionContext
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
+import org.apache.spark.sql.execution.exchange._
+import 
org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A root node to execute the query plan adaptively. It splits the query plan into independent
+ * stages and executes them in order according to their dependencies. The query stage
+ * materializes its output at the end. When one stage completes, the data statistics of the
+ * materialized output will be used to optimize the remainder of the query.
+ *
+ * To create query stages, we traverse the query tree bottom up. When we hit an exchange node,
+ * and if all the child query stages of this exchange node are materialized, we create a new
+ * query stage for this exchange node. The new stage is then materialized asynchronously once it
+ * is created.
+ *
+ * When one query stage finishes materialization, the rest of the query is re-optimized and
+ * planned based on the latest statistics provided by all materialized stages. Then we traverse
+ * the query plan again and create more stages if possible. After all stages have been
+ * materialized, we execute the rest of the plan.
+ */
+case class AdaptiveSparkPlanExec(
+    initialPlan: SparkPlan,
+    @transient session: SparkSession,
+    @transient subqueryMap: Map[Long, ExecSubqueryExpression],
+    @transient stageCache: TrieMap[SparkPlan, QueryStageExec])
+  extends LeafExecNode {
+
+  @transient private val lock = new Object()
+
+  // The logical plan optimizer for re-optimizing the current logical plan.
+  @transient private val optimizer = new RuleExecutor[LogicalPlan] {
+    // TODO add more optimization rules
+    override protected def batches: Seq[Batch] = Seq()
+  }
+
+  // A list of physical plan rules to be applied before creation of query 
stages. The physical
+  // plan should reach a final status of query stages (i.e., no more addition 
or removal of
+  // Exchange nodes) after running these rules.
+  @transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = 
Seq(
+    PlanAdaptiveSubqueries(subqueryMap),
+    EnsureRequirements(conf)
+  )
+
+  // A list of physical optimizer rules to be applied to a new stage before 
its execution. These
+  // optimizations should be stage-independent.
+  @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
+    CollapseCodegenStages(conf)
+  )
+
+  @volatile private var currentPhysicalPlan = initialPlan
+
+  private var isFinalPlan = false
+
+  private var currentStageId = 0
+
+  /**
+   * Return type for `createQueryStages`
+   * @param newPlan the new plan with created query stages.
+   * @param allChildStagesMaterialized whether all child stages have been 
materialized.
+   * @param newStages the newly created query stages, including new reused 
query stages.
+   */
+  private case class CreateStageResult(
+    newPlan: SparkPlan,
+    allChildStagesMaterialized: Boolean,
+    newStages: Seq[(Exchange, QueryStageExec)])
+
+  def executedPlan: SparkPlan = currentPhysicalPlan
+
+  override def conf: SQLConf = session.sessionState.conf
+
+  override def output: Seq[Attribute] = initialPlan.output
+
+  override def doCanonicalize(): SparkPlan = initialPlan.canonicalized
+
+  override def doExecute(): RDD[InternalRow] = lock.synchronized {
+    if (isFinalPlan) {
+      currentPhysicalPlan.execute()
+    } else {
+      val executionId = Option(
+        
session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(_.toLong)
+      var currentLogicalPlan = currentPhysicalPlan.logicalLink.get
+      var result = createQueryStages(currentPhysicalPlan)
+      val events = new LinkedBlockingQueue[StageMaterializationEvent]()
+      val errors = new mutable.ArrayBuffer[SparkException]()
+      while (!result.allChildStagesMaterialized) {
+        currentPhysicalPlan = result.newPlan
+        currentLogicalPlan = updateLogicalPlan(currentLogicalPlan, 
result.newStages)
+        currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, 
currentLogicalPlan)
+        executionId.foreach(onUpdatePlan)
+
+        // Start materialization of all new stages.
+        result.newStages.map(_._2).foreach { stage =>
+          stage.materialize().onComplete { res =>
+            if (res.isSuccess) {
+              events.offer(StageSuccess(stage, res.get))
+            } else {
+              events.offer(StageFailure(stage, res.failed.get))
+            }
+          }(AdaptiveSparkPlanExec.executionContext)
+        }
+
+        // Wait on the next completed stage, which indicates new stats are 
available and probably
+        // new stages can be created. There might be other stages that finish 
at around the same
+        // time, so we process those stages too in order to reduce re-planning.
+        val nextMsg = events.take()
+        val rem = new util.ArrayList[StageMaterializationEvent]()
+        events.drainTo(rem)
+        (Seq(nextMsg) ++ rem.asScala).foreach {
+          case StageSuccess(stage, res) =>
+            stage.resultOption = Some(res)
+          case StageFailure(stage, ex) =>
+            errors.append(
+              new SparkException(s"Failed to materialize query stage: 
${stage.treeString}", ex))
+        }
+
+        // In case of errors, we cancel all running stages and throw exception.
+        if (errors.nonEmpty) {
+          cleanUpAndThrowException(errors)
+        }
+
+        // Do re-planning and try creating new stages on the new physical plan.
+        val (newPhysicalPlan, newLogicalPlan) = reOptimize(currentLogicalPlan)
+        currentPhysicalPlan = newPhysicalPlan
+        currentLogicalPlan = newLogicalPlan
+        result = createQueryStages(currentPhysicalPlan)
+      }
+
+      // Run the final plan when there are no more unfinished stages.
+      currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
+      currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, 
currentLogicalPlan)
+      isFinalPlan = true
+      logDebug(s"Final plan: $currentPhysicalPlan")
+      executionId.foreach(onUpdatePlan)
+
+      currentPhysicalPlan.execute()
+    }
+  }
+
+  override def verboseString(maxFields: Int): String = simpleString(maxFields)
+
+  override def simpleString(maxFields: Int): String =
+    s"AdaptiveSparkPlan(isFinalPlan=$isFinalPlan)"
+
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      append: String => Unit,
+      verbose: Boolean,
+      prefix: String = "",
+      addSuffix: Boolean = false,
+      maxFields: Int): Unit = {
+    super.generateTreeString(depth, lastChildren, append, verbose, prefix, 
addSuffix, maxFields)
+    currentPhysicalPlan.generateTreeString(
+      depth + 1, lastChildren :+ true, append, verbose, "", addSuffix = false, 
maxFields)
+  }
+
+  /**
+   * This method is called recursively to traverse the plan tree bottom-up and 
create a new query
+   * stage or try reusing an existing stage if the current node is an 
[[Exchange]] node and all of
+   * its child stages have been materialized.
+   *
+   * With each call, it returns:
+   * 1) The new plan replaced with [[QueryStageExec]] nodes where new stages 
are created.
+   * 2) Whether the child query stages (if any) of the current node have all 
been materialized.
+   * 3) A list of the new query stages that have been created.
+   */
+  private def createQueryStages(plan: SparkPlan): CreateStageResult = plan 
match {
+    case e: Exchange =>
+      // First have a quick check in the `stageCache` without having to 
traverse down the node.
+      stageCache.get(e.canonicalized) match {
+        case Some(existingStage) if conf.exchangeReuseEnabled =>
+          val reusedStage = reuseQueryStage(existingStage, e.output)
+          // When reusing a stage, we treat it as a new stage regardless of whether the
+          // existing stage has been materialized or not. Thus we won't skip re-optimization
+          // for a reused stage.
+          CreateStageResult(newPlan = reusedStage,
+            allChildStagesMaterialized = false, newStages = Seq((e, 
reusedStage)))
+
+        case _ =>
+          val result = createQueryStages(e.child)
+          val newPlan = 
e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
+          // Create a query stage only when all the child query stages are 
ready.
+          if (result.allChildStagesMaterialized) {
+            var newStage = newQueryStage(newPlan)
+            if (conf.exchangeReuseEnabled) {
+              // Check the `stageCache` again for reuse. If a match is found, 
ditch the new stage
+              // and reuse the existing stage found in the `stageCache`, 
otherwise update the
+              // `stageCache` with the new stage.
+              val queryStage = stageCache.getOrElseUpdate(e.canonicalized, 
newStage)
+              if (queryStage.ne(newStage)) {
+                newStage = reuseQueryStage(queryStage, e.output)
+              }
+            }
+
+            // We've created a new stage, which is obviously not ready yet.
+            CreateStageResult(newPlan = newStage,
+              allChildStagesMaterialized = false, newStages = Seq((e, 
newStage)))
+          } else {
+            CreateStageResult(newPlan = newPlan,
+              allChildStagesMaterialized = false, newStages = result.newStages)
+          }
+      }
+
+    case q: QueryStageExec =>
+      CreateStageResult(newPlan = q,
+        allChildStagesMaterialized = q.resultOption.isDefined, newStages = 
Seq.empty)
+
+    case _ =>
+      if (plan.children.isEmpty) {
+        CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, 
newStages = Seq.empty)
+      } else {
+        val results = plan.children.map(createQueryStages)
+        CreateStageResult(
+          newPlan = plan.withNewChildren(results.map(_.newPlan)),
+          allChildStagesMaterialized = 
results.forall(_.allChildStagesMaterialized),
+          newStages = results.flatMap(_.newStages))
+      }
+  }
+
+  private def newQueryStage(e: Exchange): QueryStageExec = {
+    val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
+    val queryStage = e match {
+      case s: ShuffleExchangeExec =>
+        ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
+      case b: BroadcastExchangeExec =>
+        BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
+    }
+    currentStageId += 1
+    queryStage
+  }
+
+  private def reuseQueryStage(s: QueryStageExec, output: Seq[Attribute]): 
QueryStageExec = {
+    val queryStage = ReusedQueryStageExec(currentStageId, s, output)
+    currentStageId += 1
+    queryStage
+  }
+
+  /**
+   * Returns the updated logical plan after new query stages have been created 
and the physical
+   * plan has been updated with the newly created stages.
+   * 1. If the new query stage can be mapped to an integral logical sub-tree, 
replace the
+   *    corresponding logical sub-tree with a leaf node [[LogicalQueryStage]] 
referencing the new
+   *    query stage. For example:
+   *        Join                   SMJ                      SMJ
+   *      /     \                /    \                   /    \
+   *    r1      r2    =>    Xchg1     Xchg2    =>    Stage1     Stage2
+   *                          |        |
+   *                          r1       r2
+   *    The updated plan node will be:
+   *                               Join
+   *                             /     \
+   *    LogicalQueryStage1(Stage1)     LogicalQueryStage2(Stage2)
+   *
+   * 2. Otherwise (which means the new query stage can only be mapped to part of a logical
+   *    sub-tree), replace the corresponding logical sub-tree with a leaf node
+   *    [[LogicalQueryStage]] referencing the top physical node into which this logical node
+   *    is transformed during physical planning. For example:
+   *     Agg           HashAgg          HashAgg
+   *      |               |                |
+   *    child    =>     Xchg      =>     Stage1
+   *                      |
+   *                   HashAgg
+   *                      |
+   *                    child
+   *    The updated plan node will be:
+   *    LogicalQueryStage(HashAgg - Stage1)
+   */
+  private def updateLogicalPlan(
+      logicalPlan: LogicalPlan,
+      newStages: Seq[(Exchange, QueryStageExec)]): LogicalPlan = {
+    var currentLogicalPlan = logicalPlan
+    newStages.foreach {
+      case (exchange, stage) =>
+        // Get the corresponding logical node for `exchange`. If `exchange` 
has been transformed
+        // from a `Repartition`, it should have `logicalLink` available by 
itself; otherwise
+        // traverse down to find the first node that is not generated by 
`EnsureRequirements`.
+        val logicalNodeOpt = exchange.logicalLink.orElse(exchange.collectFirst 
{
+          case p if p.logicalLink.isDefined => p.logicalLink.get
+        })
+        assert(logicalNodeOpt.isDefined)
+        val logicalNode = logicalNodeOpt.get
+        val physicalNode = currentPhysicalPlan.collectFirst {
+          case p if p.eq(stage) || p.logicalLink.exists(logicalNode.eq) => p
+        }
+        assert(physicalNode.isDefined)
+        // Replace the corresponding logical node with LogicalQueryStage
+        val newLogicalNode = LogicalQueryStage(logicalNode, physicalNode.get)
+        val newLogicalPlan = currentLogicalPlan.transformDown {
+          case p if p.eq(logicalNode) => newLogicalNode
+        }
+        assert(newLogicalPlan != currentLogicalPlan,
+          s"logicalNode: $logicalNode; " +
+            s"logicalPlan: $currentLogicalPlan " +
+            s"physicalPlan: $currentPhysicalPlan" +
+            s"stage: $stage")
+        currentLogicalPlan = newLogicalPlan
+    }
+    currentLogicalPlan
+  }
+
+  /**
+   * Re-optimize and run physical planning on the current logical plan based 
on the latest stats.
+   */
+  private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = 
{
+    logicalPlan.invalidateStatsCache()
+    val optimized = optimizer.execute(logicalPlan)
+    SparkSession.setActiveSession(session)
+    val sparkPlan = 
session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+    val newPlan = applyPhysicalRules(sparkPlan, queryStagePreparationRules)
+    (newPlan, optimized)
+  }
+
+  /**
+   * Notify the listeners of the physical plan change.
+   */
+  private def onUpdatePlan(executionId: Long): Unit = {
+    
session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate(
+      executionId,
+      SQLExecution.getQueryExecution(executionId).toString,
+      SparkPlanInfo.fromSparkPlan(this)))
+  }
+
+  /**
+   * Cancel all running stages with best effort and throw an Exception 
containing all stage
+   * materialization errors and stage cancellation errors.
+   */
+  private def cleanUpAndThrowException(errors: Seq[SparkException]): Unit = {
+    val runningStages = currentPhysicalPlan.collect {
+      case s: QueryStageExec => s
+    }
+    val cancelErrors = new mutable.ArrayBuffer[SparkException]()
+    try {
+      runningStages.foreach { s =>
+        try {
+          s.cancel()
+        } catch {
+          case NonFatal(t) =>
+            cancelErrors.append(
+              new SparkException(s"Failed to cancel query stage: 
${s.treeString}", t))
+        }
+      }
+    } finally {
+      val ex = new SparkException(
+        "Adaptive execution failed due to stage materialization failures.", 
errors.head)
+      errors.tail.foreach(ex.addSuppressed)
+      cancelErrors.foreach(ex.addSuppressed)
+      throw ex
+    }
+  }
+}
+
+object AdaptiveSparkPlanExec {
+  private val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16))
+
+  /**
+   * Creates the list of physical plan rules to be applied before creation of 
query stages.
+   */
+  def createQueryStagePreparationRules(
+      conf: SQLConf,
+      subqueryMap: Map[Long, ExecSubqueryExpression]): Seq[Rule[SparkPlan]] = {
+    Seq(
+      PlanAdaptiveSubqueries(subqueryMap),
+      EnsureRequirements(conf))
+  }
+
+  /**
+   * Apply a list of physical operator rules on a [[SparkPlan]].
+   */
+  def applyPhysicalRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): 
SparkPlan = {
+    rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+  }
+}
+
+/**
+ * The event type for stage materialization.
+ */
+sealed trait StageMaterializationEvent
+
+/**
+ * The materialization of a query stage completed with success.
+ */
+case class StageSuccess(stage: QueryStageExec, result: Any) extends 
StageMaterializationEvent
+
+/**
+ * The materialization of a query stage hit an error and failed.
+ */
+case class StageFailure(stage: QueryStageExec, error: Throwable) extends 
StageMaterializationEvent
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
new file mode 100644
index 0000000..a1b0e29
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.concurrent.TrieMap
+import scala.collection.mutable
+
+import org.apache.spark.sql.{execution, SparkSession}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This rule wraps the query plan with an [[AdaptiveSparkPlanExec]], which executes the query
+ * plan and re-optimizes it during execution based on runtime data statistics.
+ *
+ * Note that this rule is stateful and thus should not be reused across query 
executions.
+ */
+case class InsertAdaptiveSparkPlan(session: SparkSession) extends 
Rule[SparkPlan] {
+
+  private val conf = session.sessionState.conf
+
+  // Exchange-reuse is shared across the entire query, including sub-queries.
+  private val stageCache = new TrieMap[SparkPlan, QueryStageExec]()
+
+  override def apply(plan: SparkPlan): SparkPlan = plan match {
+    case _: ExecutedCommandExec => plan
+    case _ if conf.runtimeReoptimizationEnabled && supportAdaptive(plan) =>
+      try {
+        // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse.
+        // Fall back to non-adaptive mode if adaptive execution is not supported in any of the
+        // sub-queries.
+        val subqueryMap = buildSubqueryMap(plan)
+        // Run preparation rules.
+        val preparations = 
AdaptiveSparkPlanExec.createQueryStagePreparationRules(
+          session.sessionState.conf, subqueryMap)
+        val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, 
preparations)
+        logDebug(s"Adaptive execution enabled for plan: $plan")
+        AdaptiveSparkPlanExec(newPlan, session, subqueryMap, stageCache)
+      } catch {
+        case SubqueryAdaptiveNotSupportedException(subquery) =>
+          logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is 
enabled " +
+            s"but is not supported for sub-query: $subquery.")
+          plan
+      }
+    case _ =>
+      if (conf.runtimeReoptimizationEnabled) {
+        logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is enabled 
" +
+          s"but is not supported for query: $plan.")
+      }
+      plan
+  }
+
+  private def supportAdaptive(plan: SparkPlan): Boolean = {
+    sanityCheck(plan) &&
+      !plan.logicalLink.exists(_.isStreaming) &&
+      plan.children.forall(supportAdaptive)
+  }
+
+  private def sanityCheck(plan: SparkPlan): Boolean =
+    plan.logicalLink.isDefined
+
+  /**
+   * Returns an expression-id-to-execution-plan map for all the sub-queries.
+   * For each sub-query, generate the adaptive execution plan by applying this rule, or reuse
+   * the execution plan from another sub-query with the same semantics if possible.
+   */
+  private def buildSubqueryMap(plan: SparkPlan): Map[Long, 
ExecSubqueryExpression] = {
+    val subqueryMapBuilder = mutable.HashMap.empty[Long, 
ExecSubqueryExpression]
+    plan.foreach(_.expressions.foreach(_.foreach {
+      case expressions.ScalarSubquery(p, _, exprId)
+          if !subqueryMapBuilder.contains(exprId.id) =>
+        val executedPlan = getExecutedPlan(p)
+        val scalarSubquery = execution.ScalarSubquery(
+          SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
+        subqueryMapBuilder.put(exprId.id, scalarSubquery)
+      case _ =>
+    }))
+
+    // Reuse subqueries
+    if (session.sessionState.conf.subqueryReuseEnabled) {
+      // Build a hash map using schema of subqueries to avoid O(N*N) 
sameResult calls.
+      val reuseMap = mutable.HashMap[StructType, 
mutable.ArrayBuffer[BaseSubqueryExec]]()
+      subqueryMapBuilder.keySet.foreach { exprId =>
+        val sub = subqueryMapBuilder(exprId)
+        val sameSchema =
+          reuseMap.getOrElseUpdate(sub.plan.schema, mutable.ArrayBuffer.empty)
+        val sameResult = sameSchema.find(_.sameResult(sub.plan))
+        if (sameResult.isDefined) {
+          val newExpr = sub.withNewPlan(ReusedSubqueryExec(sameResult.get))
+          subqueryMapBuilder.update(exprId, newExpr)
+        } else {
+          sameSchema += sub.plan
+        }
+      }
+    }
+
+    subqueryMapBuilder.toMap
+  }
+
+  private def getExecutedPlan(plan: LogicalPlan): SparkPlan = {
+    val queryExec = new QueryExecution(session, plan)
+    // Apply the same instance of this rule to sub-queries so that sub-queries 
all share the
+    // same `stageCache` for Exchange reuse.
+    val adaptivePlan = this.apply(queryExec.sparkPlan)
+    if (!adaptivePlan.isInstanceOf[AdaptiveSparkPlanExec]) {
+      throw SubqueryAdaptiveNotSupportedException(plan)
+    }
+    adaptivePlan
+  }
+}
+
+private case class SubqueryAdaptiveNotSupportedException(plan: LogicalPlan) 
extends Exception {}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
new file mode 100644
index 0000000..9914edd
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, 
Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * The LogicalPlan wrapper for a [[QueryStageExec]], or a snippet of physical 
plan containing
+ * a [[QueryStageExec]], in which all ancestor nodes of the [[QueryStageExec]] 
are linked to
+ * the same logical node.
+ *
+ * For example, a logical Aggregate can be transformed into FinalAgg - Shuffle 
- PartialAgg, in
+ * which the Shuffle will be wrapped into a [[QueryStageExec]], thus the 
[[LogicalQueryStage]]
+ * will have FinalAgg - QueryStageExec as its physical plan.
+ */
+// TODO we can potentially include only [[QueryStageExec]] in this class if we 
make the aggregation
+// planning aware of partitioning.
+case class LogicalQueryStage(
+    logicalPlan: LogicalPlan,
+    physicalPlan: SparkPlan) extends LeafNode {
+
+  override def output: Seq[Attribute] = logicalPlan.output
+  override val isStreaming: Boolean = logicalPlan.isStreaming
+  override val outputOrdering: Seq[SortOrder] = physicalPlan.outputOrdering
+
+  override def computeStats(): Statistics = {
+    // TODO this is not accurate when there is other physical nodes above 
QueryStageExec.
+    val physicalStats = physicalPlan.collectFirst {
+      case s: QueryStageExec => s
+    }.flatMap(_.computeStats())
+    if (physicalStats.isDefined) {
+      logDebug(s"Physical stats available as ${physicalStats.get} for plan: 
$physicalPlan")
+    } else {
+      logDebug(s"Physical stats not available for plan: $physicalPlan")
+    }
+    physicalStats.getOrElse(logicalPlan.stats)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
new file mode 100644
index 0000000..a0d07a6
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, BuildLeft, BuildRight}
+
+/**
+ * Strategy for plans containing [[LogicalQueryStage]] nodes:
+ * 1. Transforms [[LogicalQueryStage]] to its corresponding physical plan that 
is either being
+ *    executed or has already completed execution.
+ * 2. Transforms a [[Join]] which has one child relation already planned and executed as a
+ *    [[BroadcastQueryStageExec]]. This is to prevent reversing a broadcast stage into a shuffle
+ *    stage in case the larger join child relation finishes before the smaller one. Note that
+ *    this rule needs to be applied before the regular join strategies.
+ */
+object LogicalQueryStageStrategy extends Strategy with PredicateHelper {
+
+  private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match {
+    case LogicalQueryStage(_, physicalPlan)
+        if BroadcastQueryStageExec.isBroadcastQueryStageExec(physicalPlan) =>
+      true
+    case _ => false
+  }
+
+  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right, hint)
+        if isBroadcastStage(left) || isBroadcastStage(right) =>
+      val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight
+      Seq(BroadcastHashJoinExec(
+        leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), 
planLater(right)))
+
+    case j @ Join(left, right, joinType, condition, _)
+        if isBroadcastStage(left) || isBroadcastStage(right) =>
+      val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight
+      BroadcastNestedLoopJoinExec(
+        planLater(left), planLater(right), buildSide, joinType, condition) :: 
Nil
+
+    case q: LogicalQueryStage =>
+      q.physicalPlan :: Nil
+
+    case _ => Nil
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
new file mode 100644
index 0000000..4af7432
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.ListQuery
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ExecSubqueryExpression, SparkPlan}
+
+case class PlanAdaptiveSubqueries(
+    subqueryMap: Map[Long, ExecSubqueryExpression]) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformAllExpressions {
+      case expressions.ScalarSubquery(_, _, exprId) =>
+        subqueryMap(exprId.id)
+      case expressions.InSubquery(_, ListQuery(_, _, exprId, _)) =>
+        subqueryMap(exprId.id)
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
new file mode 100644
index 0000000..98cb7d0
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.concurrent.Future
+
+import org.apache.spark.{FutureAction, MapOutputStatistics}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange._
+
+
+/**
+ * A query stage is an independent subgraph of the query plan. A query stage materializes its
+ * output before proceeding with further operators of the query plan. The data statistics of
+ * the materialized output can be used to optimize subsequent query stages.
+ *
+ * There are 2 kinds of query stages:
+ *   1. Shuffle query stage. This stage materializes its output to shuffle 
files, and Spark launches
+ *      another job to execute the further operators.
+ *   2. Broadcast query stage. This stage materializes its output to an array 
in driver JVM. Spark
+ *      broadcasts the array before executing the further operators.
+ */
+abstract class QueryStageExec extends LeafExecNode {
+
+  /**
+   * An id of this query stage which is unique in the entire query plan.
+   */
+  val id: Int
+
+  /**
+   * The sub-tree of the query plan that belongs to this query stage.
+   */
+  val plan: SparkPlan
+
+  /**
+   * Materialize this query stage, to prepare for the execution, like submitting map stages,
+   * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
+   * stage is ready.
+   */
+  def doMaterialize(): Future[Any]
+
+  /**
+   * Cancel the stage materialization if in progress; otherwise do nothing.
+   */
+  def cancel(): Unit
+
+  /**
+   * Materialize this query stage, to prepare for the execution, like submitting map stages,
+   * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
+   * stage is ready.
+   */
+  final def materialize(): Future[Any] = executeQuery {
+    doMaterialize()
+  }
+
+  /**
+   * Compute the statistics of the query stage if executed, otherwise None.
+   */
+  def computeStats(): Option[Statistics] = resultOption.map { _ =>
+    // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`.
+    Statistics(sizeInBytes = plan.metrics("dataSize").value)
+  }
+
+  @transient
+  @volatile
+  private[adaptive] var resultOption: Option[Any] = None
+
+  override def output: Seq[Attribute] = plan.output
+  override def outputPartitioning: Partitioning = plan.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+  override def executeCollect(): Array[InternalRow] = plan.executeCollect()
+  override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n)
+  override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
+
+  override def doPrepare(): Unit = plan.prepare()
+  override def doExecute(): RDD[InternalRow] = plan.execute()
+  override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
+  override def doCanonicalize(): SparkPlan = plan.canonicalized
+
+  protected override def stringArgs: Iterator[Any] = Iterator.single(id)
+
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      append: String => Unit,
+      verbose: Boolean,
+      prefix: String = "",
+      addSuffix: Boolean = false,
+      maxFields: Int): Unit = {
+    super.generateTreeString(depth, lastChildren, append, verbose, prefix, addSuffix, maxFields)
+    plan.generateTreeString(
+      depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields)
+  }
+}
+
+/**
+ * A shuffle query stage whose child is a [[ShuffleExchangeExec]].
+ */
+case class ShuffleQueryStageExec(
+    override val id: Int,
+    override val plan: ShuffleExchangeExec) extends QueryStageExec {
+
+  @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
+    if (plan.inputRDD.getNumPartitions == 0) {
+      Future.successful(null)
+    } else {
+      sparkContext.submitMapStage(plan.shuffleDependency)
+    }
+  }
+
+  override def doMaterialize(): Future[Any] = {
+    mapOutputStatisticsFuture
+  }
+
+  override def cancel(): Unit = {
+    mapOutputStatisticsFuture match {
+      case action: FutureAction[MapOutputStatistics] if !mapOutputStatisticsFuture.isCompleted =>
+        action.cancel()
+      case _ =>
+    }
+  }
+}
+
+/**
+ * A broadcast query stage whose child is a [[BroadcastExchangeExec]].
+ */
+case class BroadcastQueryStageExec(
+    override val id: Int,
+    override val plan: BroadcastExchangeExec) extends QueryStageExec {
+
+  override def doMaterialize(): Future[Any] = {
+    plan.completionFuture
+  }
+
+  override def cancel(): Unit = {
+    if (!plan.relationFuture.isDone) {
+      sparkContext.cancelJobGroup(plan.runId.toString)
+      plan.relationFuture.cancel(true)
+    }
+  }
+}
+
+object BroadcastQueryStageExec {
+  /**
+   * Returns if the plan is a [[BroadcastQueryStageExec]] or a reused [[BroadcastQueryStageExec]].
+   */
+  def isBroadcastQueryStageExec(plan: SparkPlan): Boolean = plan match {
+    case r: ReusedQueryStageExec => isBroadcastQueryStageExec(r.plan)
+    case _: BroadcastQueryStageExec => true
+    case _ => false
+  }
+}
+
+/**
+ * A wrapper for reused query stage to have different output.
+ */
+case class ReusedQueryStageExec(
+    override val id: Int,
+    override val plan: QueryStageExec,
+    override val output: Seq[Attribute]) extends QueryStageExec {
+
+  override def doMaterialize(): Future[Any] = {
+    plan.materialize()
+  }
+
+  override def cancel(): Unit = {
+    plan.cancel()
+  }
+
+  // `ReusedQueryStageExec` can have distinct set of output attribute ids from its child, we need
+  // to update the attribute ids in `outputPartitioning` and `outputOrdering`.
+  private lazy val updateAttr: Expression => Expression = {
+    val originalAttrToNewAttr = AttributeMap(plan.output.zip(output))
+    e => e.transform {
+      case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr)
+    }
+  }
+
+  override def outputPartitioning: Partitioning = plan.outputPartitioning match {
+    case e: Expression => updateAttr(e).asInstanceOf[Partitioning]
+    case other => other
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    plan.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder])
+  }
+
+  override def computeStats(): Option[Statistics] = plan.computeStats()
+}
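
The contract above is deliberately small: a query stage wraps a `SparkPlan`, `materialize()` kicks off the work through `doMaterialize()`, and statistics become available only after the returned future completes. A minimal standalone sketch of that lifecycle follows; the names (`Stage`, `DemoStage`, `StageDemo`) are illustrative only and are not part of this patch, and a plain `Long` size stands in for `MapOutputStatistics` or a broadcast relation.

```
// Simplified sketch of the materialize-then-report-stats contract of QueryStageExec.
// Names and types are illustrative only; the real classes wrap a SparkPlan and return
// MapOutputStatistics or a broadcast relation instead of a bare size.
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

abstract class Stage {
  // Subclasses start the real work (submit a map stage, broadcast a relation, ...).
  def doMaterialize(): Future[Long]

  @volatile private var result: Option[Long] = None

  // Wrap doMaterialize() and remember the outcome once the future completes.
  final def materialize(): Future[Long] =
    doMaterialize().map { size => result = Some(size); size }

  // Statistics are only defined once the stage has materialized.
  def stats: Option[Long] = result
}

class DemoStage(size: Long) extends Stage {
  override def doMaterialize(): Future[Long] = Future(size)
}

object StageDemo extends App {
  val stage = new DemoStage(size = 1024L)
  Await.result(stage.materialize(), 10.seconds)
  // The completed stage's statistics can now drive re-planning, e.g. switching a
  // sort-merge join to a broadcast join when the materialized side is small enough.
  val broadcastable = stage.stats.exists(_ < 10L * 1024 * 1024)
  println(s"stats=${stage.stats}, broadcastable=$broadcastable")
}
```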
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 8017188..36f0d17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange
 import java.util.UUID
 import java.util.concurrent._
 
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
@@ -44,7 +44,7 @@ case class BroadcastExchangeExec(
     mode: BroadcastMode,
     child: SparkPlan) extends Exchange {
 
-  private val runId: UUID = UUID.randomUUID
+  private[sql] val runId: UUID = UUID.randomUUID
 
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -59,10 +59,20 @@ case class BroadcastExchangeExec(
   }
 
   @transient
+  private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+  /**
+   * For registering callbacks on `relationFuture`.
+   * Note that calling this field will not start the execution of broadcast job.
+   */
+  @transient
+  lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
+
+  @transient
   private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
-  private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+  private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     val task = new Callable[broadcast.Broadcast[Any]]() {
@@ -113,20 +123,28 @@ case class BroadcastExchangeExec(
               System.nanoTime() - beforeBroadcast)
 
            SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+            promise.success(broadcasted)
             broadcasted
           } catch {
            // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
            // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
            // will catch this exception and re-throw the wrapped fatal throwable.
             case oe: OutOfMemoryError =>
-              throw new SparkFatalException(
+              val ex = new SparkFatalException(
                 new OutOfMemoryError("Not enough memory to build and broadcast 
the table to all " +
                   "worker nodes. As a workaround, you can either disable 
broadcast by setting " +
                   s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or 
increase the spark " +
                   s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to 
a higher value.")
                   .initCause(oe.getCause))
+              promise.failure(ex)
+              throw ex
             case e if !NonFatal(e) =>
-              throw new SparkFatalException(e)
+              val ex = new SparkFatalException(e)
+              promise.failure(ex)
+              throw ex
+            case e: Throwable =>
+              promise.failure(e)
+              throw e
           }
         }
       }
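
The promise/completionFuture pair added above follows a standard Scala pattern: the heavy work runs in a lazily started future, completes a `Promise` on success or failure, and only the promise's future is handed out, so callers can register callbacks without triggering the broadcast job themselves. A minimal standalone sketch of that pattern; `LazyBroadcastLike` and `compute` are illustrative names, not part of this patch.

```
// Minimal sketch of a promise-backed completion future: attaching callbacks to
// completionFuture never starts the work; only touching relationFuture does.
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future, Promise}

class LazyBroadcastLike[T](compute: () => T) {
  private val promise = Promise[T]()

  // Safe to observe at any time; completes when (and only when) the work finishes.
  val completionFuture: Future[T] = promise.future

  private val executor = ExecutionContext.fromExecutorService(
    Executors.newSingleThreadExecutor { (r: Runnable) =>
      val t = new Thread(r); t.setDaemon(true); t
    })

  // First access starts the computation; the promise mirrors its outcome.
  lazy val relationFuture: Future[T] = Future {
    try {
      val result = compute()
      promise.success(result)
      result
    } catch {
      case e: Throwable =>
        promise.failure(e)
        throw e
    }
  }(executor)
}

object LazyBroadcastDemo extends App {
  val b = new LazyBroadcastLike(() => 42)
  b.completionFuture.foreach(v => println(s"completed with $v"))(ExecutionContext.global)
  b.relationFuture   // kicks off the work; the callback above fires on completion
  Thread.sleep(200)  // crude wait so the demo prints before the JVM exits
}
```

The split mirrors the one above: `completionFuture` is safe to hand out early, while the first touch of `relationFuture` is what actually schedules the job.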
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index e4ec76f..c99bf45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -209,7 +209,7 @@ class ExchangeCoordinator(
       var i = 0
       while (i < numExchanges) {
         val exchange = exchanges(i)
-        val shuffleDependency = exchange.prepareShuffleDependency()
+        val shuffleDependency = exchange.shuffleDependency
         shuffleDependencies += shuffleDependency
         if (shuffleDependency.rdd.partitions.length != 0) {
           // submitMapStage does not accept RDD with 0 partition.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 16398e3..31f75e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -87,15 +87,17 @@ case class ShuffleExchangeExec(
     }
   }
 
+  @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
+
   /**
-   * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+   * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
    * the returned ShuffleDependency will be the input of shuffle.
    */
-  private[exchange] def prepareShuffleDependency()
-    : ShuffleDependency[Int, InternalRow, InternalRow] = {
+  @transient
+  lazy val shuffleDependency : ShuffleDependency[Int, InternalRow, InternalRow] = {
     ShuffleExchangeExec.prepareShuffleDependency(
-      child.execute(),
+      inputRDD,
       child.output,
       newPartitioning,
       serializer,
@@ -135,7 +137,6 @@ case class ShuffleExchangeExec(
           assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
           shuffleRDD
         case _ =>
-          val shuffleDependency = prepareShuffleDependency()
           preparePostShuffleRDD(shuffleDependency)
       }
     }
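
The change above turns `prepareShuffleDependency()` into a `@transient lazy val` (with `child.execute()` hoisted into `inputRDD`), so the input RDD and the `ShuffleDependency` are built at most once and shared by every caller: the adaptive path that submits the map stage directly and the regular `doExecute()` path. A toy sketch of that lazy-val memoization, with purely illustrative names (`ExchangeLike`, `buildInput`):

```
// Toy illustration of replacing a prepareX() method with a lazy val so the
// expensive preparation runs once and every caller sees the same instance.
class ExchangeLike(buildInput: () => Seq[Int]) {
  // Computed on first access, then cached for the lifetime of the object.
  @transient lazy val inputRows: Seq[Int] = buildInput()

  @transient lazy val dependency: String = s"dependency over ${inputRows.size} rows"
}

object LazyValDemo extends App {
  var calls = 0
  val ex = new ExchangeLike(() => { calls += 1; Seq(1, 2, 3) })
  ex.dependency // first access prepares the input
  ex.dependency // second access reuses the cached value
  assert(calls == 1, "the input must be prepared exactly once")
  println(s"prepared $calls time(s): ${ex.dependency}")
}
```

With a `def`, each caller would have prepared its own dependency over a fresh `child.execute()`; the lazy val is what lets `ExchangeCoordinator` and `doExecute()` observe the same instance.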
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 960d47b..064e0a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -37,6 +37,9 @@ abstract class ExecSubqueryExpression extends PlanExpression[BaseSubqueryExec] {
    */
   def updateResult(): Unit
 
+  /** Updates the expression with a new plan. */
+  override def withNewPlan(plan: BaseSubqueryExec): ExecSubqueryExpression
+
   override def canonicalize(attrs: AttributeSeq): ExecSubqueryExpression = {
     withNewPlan(plan.canonicalized.asInstanceOf[BaseSubqueryExec])
       .asInstanceOf[ExecSubqueryExpression]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index e496de1..2c4a7ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -248,26 +248,26 @@ class SQLAppStatusListener(
     }
   }
 
+  private def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = {
+    nodes.map {
+      case cluster: SparkPlanGraphCluster =>
+        val storedCluster = new SparkPlanGraphClusterWrapper(
+          cluster.id,
+          cluster.name,
+          cluster.desc,
+          toStoredNodes(cluster.nodes),
+          cluster.metrics)
+        new SparkPlanGraphNodeWrapper(null, storedCluster)
+
+      case node =>
+        new SparkPlanGraphNodeWrapper(node, null)
+    }
+  }
+
   private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
     val SparkListenerSQLExecutionStart(executionId, description, details,
       physicalPlanDescription, sparkPlanInfo, time) = event
 
-    def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = {
-      nodes.map {
-        case cluster: SparkPlanGraphCluster =>
-          val storedCluster = new SparkPlanGraphClusterWrapper(
-            cluster.id,
-            cluster.name,
-            cluster.desc,
-            toStoredNodes(cluster.nodes),
-            cluster.metrics)
-          new SparkPlanGraphNodeWrapper(null, storedCluster)
-
-        case node =>
-          new SparkPlanGraphNodeWrapper(node, null)
-      }
-    }
-
     val planGraph = SparkPlanGraph(sparkPlanInfo)
     val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
       node.metrics.map { metric => (metric.accumulatorId, metric) }
@@ -288,6 +288,27 @@ class SQLAppStatusListener(
     update(exec)
   }
 
+  private def onAdaptiveExecutionUpdate(event: SparkListenerSQLAdaptiveExecutionUpdate): Unit = {
+    val SparkListenerSQLAdaptiveExecutionUpdate(
+      executionId, physicalPlanDescription, sparkPlanInfo) = event
+
+    val planGraph = SparkPlanGraph(sparkPlanInfo)
+    val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
+      node.metrics.map { metric => (metric.accumulatorId, metric) }
+    }.toMap.values.toList
+
+    val graphToStore = new SparkPlanGraphWrapper(
+      executionId,
+      toStoredNodes(planGraph.nodes),
+      planGraph.edges)
+    kvstore.write(graphToStore)
+
+    val exec = getOrCreateExecution(executionId)
+    exec.physicalPlanDescription = physicalPlanDescription
+    exec.metrics = sqlPlanMetrics
+    update(exec)
+  }
+
   private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = {
     val SparkListenerSQLExecutionEnd(executionId, time) = event
     Option(liveExecutions.get(executionId)).foreach { exec =>
@@ -320,6 +341,7 @@ class SQLAppStatusListener(
 
   override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
     case e: SparkListenerSQLExecutionStart => onExecutionStart(e)
+    case e: SparkListenerSQLAdaptiveExecutionUpdate => onAdaptiveExecutionUpdate(e)
     case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e)
     case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e)
     case _ => // Ignore
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 03d75c4..67d1f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -28,6 +28,13 @@ import org.apache.spark.scheduler._
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
 @DeveloperApi
+case class SparkListenerSQLAdaptiveExecutionUpdate(
+  executionId: Long,
+  physicalPlanDescription: String,
+  sparkPlanInfo: SparkPlanInfo)
+  extends SparkListenerEvent
+
+@DeveloperApi
 case class SparkListenerSQLExecutionStart(
     executionId: Long,
     description: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index b864ad1..2b7597e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -96,6 +96,15 @@ object SparkPlanGraph {
       case "InputAdapter" =>
         buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case "BroadcastQueryStage" | "ShuffleQueryStage" =>
+        if (exchanges.contains(planInfo.children.head)) {
+          // Point to the re-used exchange
+          val node = exchanges(planInfo.children.head)
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        } else {
+          buildSparkPlanGraphNode(
+            planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+        }
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
new file mode 100644
index 0000000..2cddf7c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class AdaptiveQueryExecSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  setupTestData()
+
+  private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+    val dfAdaptive = sql(query)
+    val planBefore = dfAdaptive.queryExecution.executedPlan
+    assert(planBefore.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=false)"))
+    val result = dfAdaptive.collect()
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "false") {
+      val df = sql(query)
+      QueryTest.sameRows(result.toSeq, df.collect().toSeq)
+    }
+    val planAfter = dfAdaptive.queryExecution.executedPlan
+    assert(planAfter.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=true)"))
+    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    val exchanges = adaptivePlan.collect {
+      case e: Exchange => e
+    }
+    assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+    (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+  }
+
+  private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
+    plan.collect {
+      case j: BroadcastHashJoinExec => Seq(j)
+      case s: QueryStageExec => findTopLevelBroadcastHashJoin(s.plan)
+    }.flatten
+  }
+
+  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
+    plan.collect {
+      case j: SortMergeJoinExec => Seq(j)
+      case s: QueryStageExec => findTopLevelSortMergeJoin(s.plan)
+    }.flatten
+  }
+
+  private def findReusedExchange(plan: SparkPlan): Seq[ReusedQueryStageExec] = {
+    plan.collect {
+      case e: ReusedQueryStageExec => Seq(e)
+      case a: AdaptiveSparkPlanExec => findReusedExchange(a.executedPlan)
+      case s: QueryStageExec => findReusedExchange(s.plan)
+      case p: SparkPlan => p.subqueries.flatMap(findReusedExchange)
+    }.flatten
+  }
+
+  private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = {
+    plan.collect {
+      case e: ReusedSubqueryExec => Seq(e)
+      case s: QueryStageExec => findReusedSubquery(s.plan)
+      case p: SparkPlan => p.subqueries.flatMap(findReusedSubquery)
+    }.flatten
+  }
+
+  test("Change merge join to broadcast join") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a where value = '1'")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("Scalar subquery") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a " +
+        "where value = (SELECT max(a) from testData3)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("Scalar subquery in later stages") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a " +
+        "where (value + a) = (SELECT max(a) from testData3)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("multiple joins") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON key = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+    }
+  }
+
+  test("multiple joins with aggregate") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN (
+          |    select a, sum(b) from testData3 group by a
+          |  ) t3 ON t2.n = t3.a
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON key = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+    }
+  }
+
+  test("multiple joins with aggregate 2") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN (
+          |    select a, max(b) b from testData2 group by a
+          |  ) t3 ON t2.n = t3.b
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON value = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 3)
+    }
+  }
+
+  test("Exchange reuse") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT value FROM testData join testData2 ON key = a " +
+        "join (SELECT value v from testData join testData3 ON key = a) on 
value = v")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+    }
+  }
+
+  test("Exchange reuse with subqueries") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value = (SELECT max(a) from testData join testData2 ON key = 
a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+    }
+  }
+
+  test("Exchange reuse across subqueries") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+        SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (SELECT max(a) from testData join testData2 ON key = 
a) " +
+        "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.nonEmpty)
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.isEmpty)
+    }
+  }
+
+  test("Subquery reuse") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (SELECT max(a) from testData join testData2 ON key = 
a) " +
+        "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.isEmpty)
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.nonEmpty)
+    }
+  }
+
+  test("Broadcast exchange reuse across subqueries") {
+    withSQLConf(
+        SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
+        SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (" +
+        "SELECT /*+ broadcast(testData2) */ max(key) from testData join 
testData2 ON key = a) " +
+        "and a <= (" +
+        "SELECT /*+ broadcast(testData2) */ max(value) from testData join 
testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.nonEmpty)
+      assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec])
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.isEmpty)
+    }
+  }
+
+  test("Union/Except/Intersect queries") {
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") {
+      runAdaptiveAndVerifyResult(
+        """
+          |SELECT * FROM testData
+          |EXCEPT
+          |SELECT * FROM testData2
+          |UNION ALL
+          |SELECT * FROM testData
+          |INTERSECT ALL
+          |SELECT * FROM testData2
+        """.stripMargin)
+    }
+  }
+
+  test("Subquery de-correlation in Union queries") {
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") {
+      withTempView("a", "b") {
+        Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
+        Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
+
+        runAdaptiveAndVerifyResult(
+          """
+            |SELECT id,num,source FROM (
+            |  SELECT id, num, 'a' as source FROM a
+            |  UNION ALL
+            |  SELECT id, num, 'b' as source FROM b
+            |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+          """.stripMargin)
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
