This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c79f471 [SPARK-23128][SQL] A new approach to do adaptive execution in Spark SQL c79f471 is described below commit c79f471d0475fd98ddeb1e6281551e42684837d2 Author: maryannxue <maryann...@apache.org> AuthorDate: Sat Jun 15 11:27:15 2019 +0200 [SPARK-23128][SQL] A new approach to do adaptive execution in Spark SQL ## What changes were proposed in this pull request? Implemented a new SparkPlan that executes the query adaptively. It splits the query plan into independent stages and executes them in order according to their dependencies. The query stage materializes its output at the end. When one stage completes, the data statistics of the materialized output will be used to optimize the remainder of the query. The adaptive mode is off by default; when it is turned on, users can see "AdaptiveSparkPlan" as the top node of a query or sub-query. The inner plan of "AdaptiveSparkPlan" is subject to change during query execution but becomes final once the execution is complete. Whether the inner plan is final is included in the EXPLAIN string. 
Below is an example of the EXPLAIN plan before and after execution: Query: ``` SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1' ``` Before execution: ``` == Physical Plan == AdaptiveSparkPlan(isFinalPlan=false) +- SortMergeJoin [key#13], [a#23], Inner :- Sort [key#13 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(key#13, 5) : +- Filter (isnotnull(value#14) AND (value#14 = 1)) : +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).value, true, false) AS value#14] : +- Scan[obj#12] +- Sort [a#23 ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(a#23, 5) +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24] +- Scan[obj#22] ``` After execution: ``` == Physical Plan == AdaptiveSparkPlan(isFinalPlan=true) +- *(1) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft :- BroadcastQueryStage 2 : +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))) : +- ShuffleQueryStage 0 : +- Exchange hashpartitioning(key#13, 5) : +- *(1) Filter (isnotnull(value#14) AND (value#14 = 1)) : +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).value, true, false) AS value#14] : +- Scan[obj#12] +- ShuffleQueryStage 1 +- Exchange hashpartitioning(a#23, 5) +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24] +- Scan[obj#22] ``` Credit also goes to carsonwang and cloud-fan ## How was this patch tested? Added new UT. Closes #24706 from maryannxue/aqe. Authored-by: maryannxue <maryann...@apache.org> Signed-off-by: herman <her...@databricks.com> --- .../org/apache/spark/sql/internal/SQLConf.scala | 11 +- .../spark/sql/execution/QueryExecution.scala | 12 +- .../org/apache/spark/sql/execution/SparkPlan.scala | 35 +- .../apache/spark/sql/execution/SparkPlanInfo.scala | 3 + .../apache/spark/sql/execution/SparkPlanner.scala | 2 + .../spark/sql/execution/SparkStrategies.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 96 +++-- .../execution/adaptive/AdaptiveSparkPlanExec.scala | 425 +++++++++++++++++++++ .../adaptive/InsertAdaptiveSparkPlan.scala | 131 +++++++ .../sql/execution/adaptive/LogicalQueryStage.scala | 56 +++ .../adaptive/LogicalQueryStageStrategy.scala | 63 +++ .../adaptive/PlanAdaptiveSubqueries.scala | 36 ++ .../sql/execution/adaptive/QueryStageExec.scala | 210 ++++++++++ .../execution/exchange/BroadcastExchangeExec.scala | 28 +- .../execution/exchange/ExchangeCoordinator.scala | 2 +- .../execution/exchange/ShuffleExchangeExec.scala | 11 +- .../org/apache/spark/sql/execution/subquery.scala | 3 + .../sql/execution/ui/SQLAppStatusListener.scala | 54 ++- .../spark/sql/execution/ui/SQLListener.scala | 7 + .../spark/sql/execution/ui/SparkPlanGraph.scala | 9 + .../adaptive/AdaptiveQueryExecSuite.scala | 316 +++++++++++++++ 21 files changed, 1430 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fee7172..1eb39f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -292,6 +292,12 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(64 * 1024 * 1024) + val RUNTIME_REOPTIMIZATION_ENABLED = + buildConf("spark.sql.runtime.reoptimization.enabled") + .doc("When true, enable runtime query re-optimization.") + .booleanConf + .createWithDefault(false) + val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled") .doc("When true, enable adaptive query execution.") .booleanConf @@ -1889,7 +1895,10 @@ class SQLConf extends Serializable with Logging { def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) - def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def runtimeReoptimizationEnabled: Boolean = getConf(RUNTIME_REOPTIMIZATION_ENABLED) + + def adaptiveExecutionEnabled: Boolean = + getConf(ADAPTIVE_EXECUTION_ENABLED) && !getConf(RUNTIME_REOPTIMIZATION_ENABLED) def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5d2710bd..6f0b489 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -74,9 +75,15 @@ class QueryExecution( lazy val sparkPlan: 
SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) + // Runtime re-optimization requires a unique instance of every node in the logical plan. + val logicalPlan = if (sparkSession.sessionState.conf.runtimeReoptimizationEnabled) { + optimizedPlan.clone() + } else { + optimizedPlan + } // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. - planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(logicalPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be @@ -107,6 +114,9 @@ class QueryExecution( /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( + // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op + // as the original plan is hidden behind `AdaptiveSparkPlanExec`. + InsertAdaptiveSparkPlan(sparkSession), PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ddcf61b..fbe8e50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -40,9 +40,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType object SparkPlan { - // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag - // when converting a logical plan to a physical plan. + /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. 
*/ val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan") + + /** The [[LogicalPlan]] inherited from its ancestor. */ + val LOGICAL_PLAN_INHERITED_TAG = TreeNodeTag[LogicalPlan]("logical_plan_inherited") } /** @@ -80,6 +82,35 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** + * @return The logical plan this plan is linked to. + */ + def logicalLink: Option[LogicalPlan] = + getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + + /** + * Set logical plan link recursively if unset. + */ + def setLogicalLink(logicalPlan: LogicalPlan): Unit = { + setLogicalLink(logicalPlan, false) + } + + private def setLogicalLink(logicalPlan: LogicalPlan, inherited: Boolean = false): Unit = { + // Stop at a descendant which is the root of a sub-tree transformed from another logical node. + if (inherited && getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isDefined) { + return + } + + val tag = if (inherited) { + SparkPlan.LOGICAL_PLAN_INHERITED_TAG + } else { + SparkPlan.LOGICAL_PLAN_TAG + } + setTagValue(tag, logicalPlan) + children.foreach(_.setLogicalLink(logicalPlan, true)) + } + + /** * @return All metrics containing metrics of this SparkPlan. 
*/ def metrics: Map[String, SQLMetric] = Map.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 3cd02b9..8c7752c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.sql.internal.SQLConf @@ -53,6 +54,8 @@ private[execution] object SparkPlanInfo { val children = plan match { case ReusedExchangeExec(_, child) => child :: Nil case ReusedSubqueryExec(child) => child :: Nil + case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil + case stage: QueryStageExec => stage.plan :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 2a4a1c8..dc7fb77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import 
org.apache.spark.sql.internal.SQLConf @@ -36,6 +37,7 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + LogicalQueryStageStrategy :: PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c403149..faf2fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.adaptive.LogicalQueryStage import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -58,6 +59,8 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode { protected override def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException() } + + override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {} } abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -69,7 +72,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ReturnAnswer(rootPlan) => rootPlan case _ => plan } - p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan) + p.setLogicalLink(logicalPlan) p } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 99dcca8..92e80dc 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.Writer import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable import scala.util.control.NonFatal @@ -551,56 +552,6 @@ object WholeStageCodegenExec { } } -object WholeStageCodegenId { - // codegenStageId: ID for codegen stages within a query plan. - // It does not affect equality, nor does it participate in destructuring pattern matching - // of WholeStageCodegenExec. - // - // This ID is used to help differentiate between codegen stages. It is included as a part - // of the explain output for physical plans, e.g. - // - // == Physical Plan == - // *(5) SortMergeJoin [x#3L], [y#9L], Inner - // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 - // : +- Exchange hashpartitioning(x#3L, 200) - // : +- *(1) Project [(id#0L % 2) AS x#3L] - // : +- *(1) Filter isnotnull((id#0L % 2)) - // : +- *(1) Range (0, 5, step=1, splits=8) - // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 - // +- Exchange hashpartitioning(y#9L, 200) - // +- *(3) Project [(id#6L % 2) AS y#9L] - // +- *(3) Filter isnotnull((id#6L % 2)) - // +- *(3) Range (0, 5, step=1, splits=8) - // - // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the - // same codegen stage. - // - // The codegen stage ID is also optionally included in the name of the generated classes as - // a suffix, so that it's easier to associate a generated class back to the physical operator. - // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName - // - // The ID is also included in various log messages. - // - // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". - // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. 
- // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. - // - // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object - // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec - // failed to generate/compile code. - - private val codegenStageCounter: ThreadLocal[Integer] = ThreadLocal.withInitial(() => 1) - - def resetPerQuery(): Unit = codegenStageCounter.set(1) - - def getNextStageId(): Int = { - val counter = codegenStageCounter - val id = counter.get() - counter.set(id + 1) - id - } -} - /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -824,8 +775,48 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) /** * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. + * + * The `codegenStageCounter` generates ID for codegen stages within a query plan. + * It does not affect equality, nor does it participate in destructuring pattern matching + * of WholeStageCodegenExec. + * + * This ID is used to help differentiate between codegen stages. It is included as a part + * of the explain output for physical plans, e.g. + * + * == Physical Plan == + * *(5) SortMergeJoin [x#3L], [y#9L], Inner + * :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + * : +- Exchange hashpartitioning(x#3L, 200) + * : +- *(1) Project [(id#0L % 2) AS x#3L] + * : +- *(1) Filter isnotnull((id#0L % 2)) + * : +- *(1) Range (0, 5, step=1, splits=8) + * +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + * +- Exchange hashpartitioning(y#9L, 200) + * +- *(3) Project [(id#6L % 2) AS y#9L] + * +- *(3) Filter isnotnull((id#6L % 2)) + * +- *(3) Range (0, 5, step=1, splits=8) + * + * where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + * same codegen stage. 
+ * + * The codegen stage ID is also optionally included in the name of the generated classes as + * a suffix, so that it's easier to associate a generated class back to the physical operator. + * This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + * + * The ID is also included in various log messages. + * + * Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + * WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + * See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + * + * 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + * is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + * failed to generate/compile code. */ -case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { +case class CollapseCodegenStages( + conf: SQLConf, + codegenStageCounter: AtomicInteger = new AtomicInteger(0)) + extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true @@ -869,14 +860,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) + WholeStageCodegenExec(insertInputAdapter(plan))(codegenStageCounter.incrementAndGet()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { - WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala new file mode 100644 index 0000000..606fbd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import java.util +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils + +/** + * A root node to execute the query plan adaptively. It splits the query plan into independent + * stages and executes them in order according to their dependencies. The query stage + * materializes its output at the end. When one stage completes, the data statistics of the + * materialized output will be used to optimize the remainder of the query. + * + * To create query stages, we traverse the query tree bottom up. When we hit an exchange node, + * and if all the child query stages of this exchange node are materialized, we create a new + * query stage for this exchange node. The new stage is then materialized asynchronously once it + * is created. + * + * When one query stage finishes materialization, the rest of the query is re-optimized and planned based + * on the latest statistics provided by all materialized stages. Then we traverse the query plan + * again and create more stages if possible. 
After all stages have been materialized, we execute + * the rest of the plan. + */ +case class AdaptiveSparkPlanExec( + initialPlan: SparkPlan, + @transient session: SparkSession, + @transient subqueryMap: Map[Long, ExecSubqueryExpression], + @transient stageCache: TrieMap[SparkPlan, QueryStageExec]) + extends LeafExecNode { + + @transient private val lock = new Object() + + // The logical plan optimizer for re-optimizing the current logical plan. + @transient private val optimizer = new RuleExecutor[LogicalPlan] { + // TODO add more optimization rules + override protected def batches: Seq[Batch] = Seq() + } + + // A list of physical plan rules to be applied before creation of query stages. The physical + // plan should reach a final status of query stages (i.e., no more addition or removal of + // Exchange nodes) after running these rules. + @transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( + PlanAdaptiveSubqueries(subqueryMap), + EnsureRequirements(conf) + ) + + // A list of physical optimizer rules to be applied to a new stage before its execution. These + // optimizations should be stage-independent. + @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( + CollapseCodegenStages(conf) + ) + + @volatile private var currentPhysicalPlan = initialPlan + + private var isFinalPlan = false + + private var currentStageId = 0 + + /** + * Return type for `createQueryStages` + * @param newPlan the new plan with created query stages. + * @param allChildStagesMaterialized whether all child stages have been materialized. + * @param newStages the newly created query stages, including new reused query stages. 
+ */ + private case class CreateStageResult( + newPlan: SparkPlan, + allChildStagesMaterialized: Boolean, + newStages: Seq[(Exchange, QueryStageExec)]) + + def executedPlan: SparkPlan = currentPhysicalPlan + + override def conf: SQLConf = session.sessionState.conf + + override def output: Seq[Attribute] = initialPlan.output + + override def doCanonicalize(): SparkPlan = initialPlan.canonicalized + + override def doExecute(): RDD[InternalRow] = lock.synchronized { + if (isFinalPlan) { + currentPhysicalPlan.execute() + } else { + val executionId = Option( + session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(_.toLong) + var currentLogicalPlan = currentPhysicalPlan.logicalLink.get + var result = createQueryStages(currentPhysicalPlan) + val events = new LinkedBlockingQueue[StageMaterializationEvent]() + val errors = new mutable.ArrayBuffer[SparkException]() + while (!result.allChildStagesMaterialized) { + currentPhysicalPlan = result.newPlan + currentLogicalPlan = updateLogicalPlan(currentLogicalPlan, result.newStages) + currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, currentLogicalPlan) + executionId.foreach(onUpdatePlan) + + // Start materialization of all new stages. + result.newStages.map(_._2).foreach { stage => + stage.materialize().onComplete { res => + if (res.isSuccess) { + events.offer(StageSuccess(stage, res.get)) + } else { + events.offer(StageFailure(stage, res.failed.get)) + } + }(AdaptiveSparkPlanExec.executionContext) + } + + // Wait on the next completed stage, which indicates new stats are available and probably + // new stages can be created. There might be other stages that finish at around the same + // time, so we process those stages too in order to reduce re-planning. 
+ val nextMsg = events.take() + val rem = new util.ArrayList[StageMaterializationEvent]() + events.drainTo(rem) + (Seq(nextMsg) ++ rem.asScala).foreach { + case StageSuccess(stage, res) => + stage.resultOption = Some(res) + case StageFailure(stage, ex) => + errors.append( + new SparkException(s"Failed to materialize query stage: ${stage.treeString}", ex)) + } + + // In case of errors, we cancel all running stages and throw exception. + if (errors.nonEmpty) { + cleanUpAndThrowException(errors) + } + + // Do re-planning and try creating new stages on the new physical plan. + val (newPhysicalPlan, newLogicalPlan) = reOptimize(currentLogicalPlan) + currentPhysicalPlan = newPhysicalPlan + currentLogicalPlan = newLogicalPlan + result = createQueryStages(currentPhysicalPlan) + } + + // Run the final plan when there's no more unfinished stages. + currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules) + currentPhysicalPlan.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, currentLogicalPlan) + isFinalPlan = true + logDebug(s"Final plan: $currentPhysicalPlan") + executionId.foreach(onUpdatePlan) + + currentPhysicalPlan.execute() + } + } + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleString(maxFields: Int): String = + s"AdaptiveSparkPlan(isFinalPlan=$isFinalPlan)" + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int): Unit = { + super.generateTreeString(depth, lastChildren, append, verbose, prefix, addSuffix, maxFields) + currentPhysicalPlan.generateTreeString( + depth + 1, lastChildren :+ true, append, verbose, "", addSuffix = false, maxFields) + } + + /** + * This method is called recursively to traverse the plan tree bottom-up and create a new query + * stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of + * its 
child stages have been materialized. + * + * With each call, it returns: + * 1) The new plan replaced with [[QueryStageExec]] nodes where new stages are created. + * 2) Whether the child query stages (if any) of the current node have all been materialized. + * 3) A list of the new query stages that have been created. + */ + private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match { + case e: Exchange => + // First have a quick check in the `stageCache` without having to traverse down the node. + stageCache.get(e.canonicalized) match { + case Some(existingStage) if conf.exchangeReuseEnabled => + val reusedStage = reuseQueryStage(existingStage, e.output) + // When reusing a stage, we treat it as a new stage regardless of whether the existing stage + // has been materialized or not. Thus we won't skip re-optimization for a reused stage. + CreateStageResult(newPlan = reusedStage, + allChildStagesMaterialized = false, newStages = Seq((e, reusedStage))) + + case _ => + val result = createQueryStages(e.child) + val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange] + // Create a query stage only when all the child query stages are ready. + if (result.allChildStagesMaterialized) { + var newStage = newQueryStage(newPlan) + if (conf.exchangeReuseEnabled) { + // Check the `stageCache` again for reuse. If a match is found, ditch the new stage + // and reuse the existing stage found in the `stageCache`, otherwise update the + // `stageCache` with the new stage. + val queryStage = stageCache.getOrElseUpdate(e.canonicalized, newStage) + if (queryStage.ne(newStage)) { + newStage = reuseQueryStage(queryStage, e.output) + } + } + + // We've created a new stage, which is obviously not ready yet.
+ CreateStageResult(newPlan = newStage, + allChildStagesMaterialized = false, newStages = Seq((e, newStage))) + } else { + CreateStageResult(newPlan = newPlan, + allChildStagesMaterialized = false, newStages = result.newStages) + } + } + + case q: QueryStageExec => + CreateStageResult(newPlan = q, + allChildStagesMaterialized = q.resultOption.isDefined, newStages = Seq.empty) + + case _ => + if (plan.children.isEmpty) { + CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty) + } else { + val results = plan.children.map(createQueryStages) + CreateStageResult( + newPlan = plan.withNewChildren(results.map(_.newPlan)), + allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized), + newStages = results.flatMap(_.newStages)) + } + } + + private def newQueryStage(e: Exchange): QueryStageExec = { + val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules) + val queryStage = e match { + case s: ShuffleExchangeExec => + ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan)) + case b: BroadcastExchangeExec => + BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan)) + } + currentStageId += 1 + queryStage + } + + private def reuseQueryStage(s: QueryStageExec, output: Seq[Attribute]): QueryStageExec = { + val queryStage = ReusedQueryStageExec(currentStageId, s, output) + currentStageId += 1 + queryStage + } + + /** + * Returns the updated logical plan after new query stages have been created and the physical + * plan has been updated with the newly created stages. + * 1. If the new query stage can be mapped to an integral logical sub-tree, replace the + * corresponding logical sub-tree with a leaf node [[LogicalQueryStage]] referencing the new + * query stage. 
For example: + * Join SMJ SMJ + * / \ / \ / \ + * r1 r2 => Xchg1 Xchg2 => Stage1 Stage2 + * | | + * r1 r2 + * The updated plan node will be: + * Join + * / \ + * LogicalQueryStage1(Stage1) LogicalQueryStage2(Stage2) + * + * 2. Otherwise (which means the new query stage can only be mapped to part of a logical + * sub-tree), replace the corresponding logical sub-tree with a leaf node + * [[LogicalQueryStage]] referencing to the top physical node into which this logical node is + * transformed during physical planning. For example: + * Agg HashAgg HashAgg + * | | | + * child => Xchg => Stage1 + * | + * HashAgg + * | + * child + * The updated plan node will be: + * LogicalQueryStage(HashAgg - Stage1) + */ + private def updateLogicalPlan( + logicalPlan: LogicalPlan, + newStages: Seq[(Exchange, QueryStageExec)]): LogicalPlan = { + var currentLogicalPlan = logicalPlan + newStages.foreach { + case (exchange, stage) => + // Get the corresponding logical node for `exchange`. If `exchange` has been transformed + // from a `Repartition`, it should have `logicalLink` available by itself; otherwise + // traverse down to find the first node that is not generated by `EnsureRequirements`. 
+ val logicalNodeOpt = exchange.logicalLink.orElse(exchange.collectFirst { + case p if p.logicalLink.isDefined => p.logicalLink.get + }) + assert(logicalNodeOpt.isDefined) + val logicalNode = logicalNodeOpt.get + val physicalNode = currentPhysicalPlan.collectFirst { + case p if p.eq(stage) || p.logicalLink.exists(logicalNode.eq) => p + } + assert(physicalNode.isDefined) + // Replace the corresponding logical node with LogicalQueryStage + val newLogicalNode = LogicalQueryStage(logicalNode, physicalNode.get) + val newLogicalPlan = currentLogicalPlan.transformDown { + case p if p.eq(logicalNode) => newLogicalNode + } + assert(newLogicalPlan != currentLogicalPlan, + s"logicalNode: $logicalNode; " + + s"logicalPlan: $currentLogicalPlan " + + s"physicalPlan: $currentPhysicalPlan" + + s"stage: $stage") + currentLogicalPlan = newLogicalPlan + } + currentLogicalPlan + } + + /** + * Re-optimize and run physical planning on the current logical plan based on the latest stats. + */ + private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = { + logicalPlan.invalidateStatsCache() + val optimized = optimizer.execute(logicalPlan) + SparkSession.setActiveSession(session) + val sparkPlan = session.sessionState.planner.plan(ReturnAnswer(optimized)).next() + val newPlan = applyPhysicalRules(sparkPlan, queryStagePreparationRules) + (newPlan, optimized) + } + + /** + * Notify the listeners of the physical plan change. + */ + private def onUpdatePlan(executionId: Long): Unit = { + session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( + executionId, + SQLExecution.getQueryExecution(executionId).toString, + SparkPlanInfo.fromSparkPlan(this))) + } + + /** + * Cancel all running stages with best effort and throw an Exception containing all stage + * materialization errors and stage cancellation errors. 
+ */ + private def cleanUpAndThrowException(errors: Seq[SparkException]): Unit = { + val runningStages = currentPhysicalPlan.collect { + case s: QueryStageExec => s + } + val cancelErrors = new mutable.ArrayBuffer[SparkException]() + try { + runningStages.foreach { s => + try { + s.cancel() + } catch { + case NonFatal(t) => + cancelErrors.append( + new SparkException(s"Failed to cancel query stage: ${s.treeString}", t)) + } + } + } finally { + val ex = new SparkException( + "Adaptive execution failed due to stage materialization failures.", errors.head) + errors.tail.foreach(ex.addSuppressed) + cancelErrors.foreach(ex.addSuppressed) + throw ex + } + } +} + +object AdaptiveSparkPlanExec { + private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16)) + + /** + * Creates the list of physical plan rules to be applied before creation of query stages. + */ + def createQueryStagePreparationRules( + conf: SQLConf, + subqueryMap: Map[Long, ExecSubqueryExpression]): Seq[Rule[SparkPlan]] = { + Seq( + PlanAdaptiveSubqueries(subqueryMap), + EnsureRequirements(conf)) + } + + /** + * Apply a list of physical operator rules on a [[SparkPlan]]. + */ + def applyPhysicalRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = { + rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } +} + +/** + * The event type for stage materialization. + */ +sealed trait StageMaterializationEvent + +/** + * The materialization of a query stage completed with success. + */ +case class StageSuccess(stage: QueryStageExec, result: Any) extends StageMaterializationEvent + +/** + * The materialization of a query stage hit an error and failed. 
+ */ +case class StageFailure(stage: QueryStageExec, error: Throwable) extends StageMaterializationEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala new file mode 100644 index 0000000..a1b0e29 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.concurrent.TrieMap +import scala.collection.mutable + +import org.apache.spark.sql.{execution, SparkSession} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * This rule wraps the query plan with an [[AdaptiveSparkPlanExec]], which executes the query plan + * and re-optimize the plan during execution based on runtime data statistics. + * + * Note that this rule is stateful and thus should not be reused across query executions. + */ +case class InsertAdaptiveSparkPlan(session: SparkSession) extends Rule[SparkPlan] { + + private val conf = session.sessionState.conf + + // Exchange-reuse is shared across the entire query, including sub-queries. + private val stageCache = new TrieMap[SparkPlan, QueryStageExec]() + + override def apply(plan: SparkPlan): SparkPlan = plan match { + case _: ExecutedCommandExec => plan + case _ if conf.runtimeReoptimizationEnabled && supportAdaptive(plan) => + try { + // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. Fall + // back to non-adaptive mode if adaptive execution is supported in any of the sub-queries. + val subqueryMap = buildSubqueryMap(plan) + // Run preparation rules. 
+ val preparations = AdaptiveSparkPlanExec.createQueryStagePreparationRules( + session.sessionState.conf, subqueryMap) + val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preparations) + logDebug(s"Adaptive execution enabled for plan: $plan") + AdaptiveSparkPlanExec(newPlan, session, subqueryMap, stageCache) + } catch { + case SubqueryAdaptiveNotSupportedException(subquery) => + logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is enabled " + + s"but is not supported for sub-query: $subquery.") + plan + } + case _ => + if (conf.runtimeReoptimizationEnabled) { + logWarning(s"${SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key} is enabled " + + s"but is not supported for query: $plan.") + } + plan + } + + private def supportAdaptive(plan: SparkPlan): Boolean = { + sanityCheck(plan) && + !plan.logicalLink.exists(_.isStreaming) && + plan.children.forall(supportAdaptive) + } + + private def sanityCheck(plan: SparkPlan): Boolean = + plan.logicalLink.isDefined + + /** + * Returns an expression-id-to-execution-plan map for all the sub-queries. + * For each sub-query, generate the adaptive execution plan for each sub-query by applying this + * rule, or reuse the execution plan from another sub-query of the same semantics if possible. + */ + private def buildSubqueryMap(plan: SparkPlan): Map[Long, ExecSubqueryExpression] = { + val subqueryMapBuilder = mutable.HashMap.empty[Long, ExecSubqueryExpression] + plan.foreach(_.expressions.foreach(_.foreach { + case expressions.ScalarSubquery(p, _, exprId) + if !subqueryMapBuilder.contains(exprId.id) => + val executedPlan = getExecutedPlan(p) + val scalarSubquery = execution.ScalarSubquery( + SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) + subqueryMapBuilder.put(exprId.id, scalarSubquery) + case _ => + })) + + // Reuse subqueries + if (session.sessionState.conf.subqueryReuseEnabled) { + // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. 
+ val reuseMap = mutable.HashMap[StructType, mutable.ArrayBuffer[BaseSubqueryExec]]() + subqueryMapBuilder.keySet.foreach { exprId => + val sub = subqueryMapBuilder(exprId) + val sameSchema = + reuseMap.getOrElseUpdate(sub.plan.schema, mutable.ArrayBuffer.empty) + val sameResult = sameSchema.find(_.sameResult(sub.plan)) + if (sameResult.isDefined) { + val newExpr = sub.withNewPlan(ReusedSubqueryExec(sameResult.get)) + subqueryMapBuilder.update(exprId, newExpr) + } else { + sameSchema += sub.plan + } + } + } + + subqueryMapBuilder.toMap + } + + private def getExecutedPlan(plan: LogicalPlan): SparkPlan = { + val queryExec = new QueryExecution(session, plan) + // Apply the same instance of this rule to sub-queries so that sub-queries all share the + // same `stageCache` for Exchange reuse. + val adaptivePlan = this.apply(queryExec.sparkPlan) + if (!adaptivePlan.isInstanceOf[AdaptiveSparkPlanExec]) { + throw SubqueryAdaptiveNotSupportedException(plan) + } + adaptivePlan + } +} + +private case class SubqueryAdaptiveNotSupportedException(plan: LogicalPlan) extends Exception {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala new file mode 100644 index 0000000..9914edd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStage.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan + +/** + * The LogicalPlan wrapper for a [[QueryStageExec]], or a snippet of physical plan containing + * a [[QueryStageExec]], in which all ancestor nodes of the [[QueryStageExec]] are linked to + * the same logical node. + * + * For example, a logical Aggregate can be transformed into FinalAgg - Shuffle - PartialAgg, in + * which the Shuffle will be wrapped into a [[QueryStageExec]], thus the [[LogicalQueryStage]] + * will have FinalAgg - QueryStageExec as its physical plan. + */ +// TODO we can potentially include only [[QueryStageExec]] in this class if we make the aggregation +// planning aware of partitioning. +case class LogicalQueryStage( + logicalPlan: LogicalPlan, + physicalPlan: SparkPlan) extends LeafNode { + + override def output: Seq[Attribute] = logicalPlan.output + override val isStreaming: Boolean = logicalPlan.isStreaming + override val outputOrdering: Seq[SortOrder] = physicalPlan.outputOrdering + + override def computeStats(): Statistics = { + // TODO this is not accurate when there is other physical nodes above QueryStageExec. 
+ val physicalStats = physicalPlan.collectFirst { + case s: QueryStageExec => s + }.flatMap(_.computeStats()) + if (physicalStats.isDefined) { + logDebug(s"Physical stats available as ${physicalStats.get} for plan: $physicalPlan") + } else { + logDebug(s"Physical stats not available for plan: $physicalPlan") + } + physicalStats.getOrElse(logicalPlan.stats) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala new file mode 100644 index 0000000..a0d07a6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight} + +/** + * Strategy for plans containing [[LogicalQueryStage]] nodes: + * 1. Transforms [[LogicalQueryStage]] to its corresponding physical plan that is either being + * executed or has already completed execution. + * 2. Transforms [[Join]] which has one child relation already planned and executed as a + * [[BroadcastQueryStageExec]]. This is to prevent reversing a broadcast stage into a shuffle + * stage in case of the larger join child relation finishes before the smaller relation. Note + * that this rule needs to applied before regular join strategies. 
+ */ +object LogicalQueryStageStrategy extends Strategy with PredicateHelper { + + private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, physicalPlan) + if BroadcastQueryStageExec.isBroadcastQueryStageExec(physicalPlan) => + true + case _ => false + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) + if isBroadcastStage(left) || isBroadcastStage(right) => + val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight + Seq(BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + + case j @ Join(left, right, joinType, condition, _) + if isBroadcastStage(left) || isBroadcastStage(right) => + val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight + BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + + case q: LogicalQueryStage => + q.physicalPlan :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala new file mode 100644 index 0000000..4af7432 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.ListQuery +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ExecSubqueryExpression, SparkPlan} + +case class PlanAdaptiveSubqueries( + subqueryMap: Map[Long, ExecSubqueryExpression]) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + plan.transformAllExpressions { + case expressions.ScalarSubquery(_, _, exprId) => + subqueryMap(exprId.id) + case expressions.InSubquery(_, ListQuery(_, _, exprId, _)) => + subqueryMap(exprId.id) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala new file mode 100644 index 0000000..98cb7d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.concurrent.Future + +import org.apache.spark.{FutureAction, MapOutputStatistics} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange._ + + +/** + * A query stage is an independent subgraph of the query plan. Query stage materializes its output + * before proceeding with further operators of the query plan. The data statistics of the + * materialized output can be used to optimize subsequent query stages. + * + * There are 2 kinds of query stages: + * 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches + * another job to execute the further operators. + * 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark + * broadcasts the array before executing the further operators. + */ +abstract class QueryStageExec extends LeafExecNode { + + /** + * An id of this query stage which is unique in the entire query plan. + */ + val id: Int + + /** + * The sub-tree of the query plan that belongs to this query stage. + */ + val plan: SparkPlan + + /** + * Materialize this query stage, to prepare for the execution, like submitting map stages, + * broadcasting data, etc. 
The caller side can use the returned [[Future]] to wait until this + * stage is ready. + */ + def doMaterialize(): Future[Any] + + /** + * Cancel the stage materialization if in progress; otherwise do nothing. + */ + def cancel(): Unit + + /** + * Materialize this query stage, to prepare for the execution, like submitting map stages, + * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this + * stage is ready. + */ + final def materialize(): Future[Any] = executeQuery { + doMaterialize() + } + + /** + * Compute the statistics of the query stage if executed, otherwise None. + */ + def computeStats(): Option[Statistics] = resultOption.map { _ => + // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`. + Statistics(sizeInBytes = plan.metrics("dataSize").value) + } + + @transient + @volatile + private[adaptive] var resultOption: Option[Any] = None + + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + override def executeCollect(): Array[InternalRow] = plan.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator() + + override def doPrepare(): Unit = plan.prepare() + override def doExecute(): RDD[InternalRow] = plan.execute() + override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() + override def doCanonicalize(): SparkPlan = plan.canonicalized + + protected override def stringArgs: Iterator[Any] = Iterator.single(id) + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int): Unit = { + super.generateTreeString(depth, lastChildren, append, verbose, prefix, addSuffix, 
maxFields) + plan.generateTreeString( + depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields) + } +} + +/** + * A shuffle query stage whose child is a [[ShuffleExchangeExec]]. + */ +case class ShuffleQueryStageExec( + override val id: Int, + override val plan: ShuffleExchangeExec) extends QueryStageExec { + + @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (plan.inputRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext.submitMapStage(plan.shuffleDependency) + } + } + + override def doMaterialize(): Future[Any] = { + mapOutputStatisticsFuture + } + + override def cancel(): Unit = { + mapOutputStatisticsFuture match { + case action: FutureAction[MapOutputStatistics] if !mapOutputStatisticsFuture.isCompleted => + action.cancel() + case _ => + } + } +} + +/** + * A broadcast query stage whose child is a [[BroadcastExchangeExec]]. + */ +case class BroadcastQueryStageExec( + override val id: Int, + override val plan: BroadcastExchangeExec) extends QueryStageExec { + + override def doMaterialize(): Future[Any] = { + plan.completionFuture + } + + override def cancel(): Unit = { + if (!plan.relationFuture.isDone) { + sparkContext.cancelJobGroup(plan.runId.toString) + plan.relationFuture.cancel(true) + } + } +} + +object BroadcastQueryStageExec { + /** + * Returns if the plan is a [[BroadcastQueryStageExec]] or a reused [[BroadcastQueryStageExec]]. + */ + def isBroadcastQueryStageExec(plan: SparkPlan): Boolean = plan match { + case r: ReusedQueryStageExec => isBroadcastQueryStageExec(r.plan) + case _: BroadcastQueryStageExec => true + case _ => false + } +} + +/** + * A wrapper for reused query stage to have different output. 
+ */ +case class ReusedQueryStageExec( + override val id: Int, + override val plan: QueryStageExec, + override val output: Seq[Attribute]) extends QueryStageExec { + + override def doMaterialize(): Future[Any] = { + plan.materialize() + } + + override def cancel(): Unit = { + plan.cancel() + } + + // `ReusedQueryStageExec` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(plan.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = plan.outputPartitioning match { + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + plan.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } + + override def computeStats(): Option[Statistics] = plan.computeStats() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 8017188..36f0d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange import java.util.UUID import java.util.concurrent._ -import scala.concurrent.ExecutionContext +import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal @@ -44,7 +44,7 @@ case class BroadcastExchangeExec( mode: BroadcastMode, child: SparkPlan) extends Exchange { - private val runId: UUID = UUID.randomUUID + private[sql] val runId: UUID = UUID.randomUUID 
override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), @@ -59,10 +59,20 @@ case class BroadcastExchangeExec( } @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + /** + * For registering callbacks on `relationFuture`. + * Note that calling this field will not start the execution of broadcast job. + */ + @transient + lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future + + @transient private val timeout: Long = SQLConf.get.broadcastTimeout @transient - private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) val task = new Callable[broadcast.Broadcast[Any]]() { @@ -113,20 +123,28 @@ case class BroadcastExchangeExec( System.nanoTime() - beforeBroadcast) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.success(broadcasted) broadcasted } catch { // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new SparkFatalException( + val ex = new SparkFatalException( new OutOfMemoryError("Not enough memory to build and broadcast the table to all " + "worker nodes. 
As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " + s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.") .initCause(oe.getCause)) + promise.failure(ex) + throw ex case e if !NonFatal(e) => - throw new SparkFatalException(e) + val ex = new SparkFatalException(e) + promise.failure(ex) + throw ex + case e: Throwable => + promise.failure(e) + throw e } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index e4ec76f..c99bf45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -209,7 +209,7 @@ class ExchangeCoordinator( var i = 0 while (i < numExchanges) { val exchange = exchanges(i) - val shuffleDependency = exchange.prepareShuffleDependency() + val shuffleDependency = exchange.shuffleDependency shuffleDependencies += shuffleDependency if (shuffleDependency.rdd.partitions.length != 0) { // submitMapStage does not accept RDD with 0 partition. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 16398e3..31f75e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -87,15 +87,18 @@ case class ShuffleExchangeExec(
     }
   }
 
+  /** Lazily computed output of the child plan, used as the input of [[shuffleDependency]]. */
+  @transient lazy val inputRDD: RDD[InternalRow] = child.execute()
+
   /**
-   * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+   * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
    * the returned ShuffleDependency will be the input of shuffle.
    */
-  private[exchange] def prepareShuffleDependency()
-    : ShuffleDependency[Int, InternalRow, InternalRow] = {
+  @transient
+  lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] = {
     ShuffleExchangeExec.prepareShuffleDependency(
-      child.execute(),
+      inputRDD,
       child.output,
       newPartitioning,
       serializer,
@@ -135,7 +138,6 @@ case class ShuffleExchangeExec(
           assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
           shuffleRDD
         case _ =>
-          val shuffleDependency = prepareShuffleDependency()
           preparePostShuffleRDD(shuffleDependency)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 960d47b..064e0a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -37,6 +37,9 @@ abstract class ExecSubqueryExpression extends PlanExpression[BaseSubqueryExec] {
    */
   def updateResult(): Unit
 
+  /** Updates the expression with a new plan. */
+  override def withNewPlan(plan: BaseSubqueryExec): ExecSubqueryExpression
+
   override def canonicalize(attrs: AttributeSeq): ExecSubqueryExpression = {
     withNewPlan(plan.canonicalized.asInstanceOf[BaseSubqueryExec])
       .asInstanceOf[ExecSubqueryExpression]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index e496de1..2c4a7ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -248,26 +248,27 @@ class SQLAppStatusListener(
     }
   }
 
+  /** Converts graph nodes, including nested cluster members, into their storable wrappers. */
+  private def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = {
+    nodes.map {
+      case cluster: SparkPlanGraphCluster =>
+        val storedCluster = new SparkPlanGraphClusterWrapper(
+          cluster.id,
+          cluster.name,
+          cluster.desc,
+          toStoredNodes(cluster.nodes),
+          cluster.metrics)
+        new SparkPlanGraphNodeWrapper(null, storedCluster)
+
+      case node =>
+        new SparkPlanGraphNodeWrapper(node, null)
+    }
+  }
+
   private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
     val SparkListenerSQLExecutionStart(executionId, description, details,
       physicalPlanDescription, sparkPlanInfo, time) = event
 
-    def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = {
-      nodes.map {
-        case cluster: SparkPlanGraphCluster =>
-          val storedCluster = new SparkPlanGraphClusterWrapper(
-            cluster.id,
-            cluster.name,
-            cluster.desc,
-            toStoredNodes(cluster.nodes),
-            cluster.metrics)
-          new SparkPlanGraphNodeWrapper(null, storedCluster)
-
-        case node =>
-          new SparkPlanGraphNodeWrapper(node, null)
-      }
-    }
-
     val planGraph = SparkPlanGraph(sparkPlanInfo)
     val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
       node.metrics.map { metric => (metric.accumulatorId, metric) }
@@ -288,6 +289,28 @@ class SQLAppStatusListener(
     update(exec)
   }
 
+  /** Stores the event's plan graph and updates the execution's plan description and metrics. */
+  private def onAdaptiveExecutionUpdate(event: SparkListenerSQLAdaptiveExecutionUpdate): Unit = {
+    val SparkListenerSQLAdaptiveExecutionUpdate(
+      executionId, physicalPlanDescription, sparkPlanInfo) = event
+
+    val planGraph = SparkPlanGraph(sparkPlanInfo)
+    val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
+      node.metrics.map { metric => (metric.accumulatorId, metric) }
+    }.toMap.values.toList
+
+    val graphToStore = new SparkPlanGraphWrapper(
+      executionId,
+      toStoredNodes(planGraph.nodes),
+      planGraph.edges)
+    kvstore.write(graphToStore)
+
+    val exec = getOrCreateExecution(executionId)
+    exec.physicalPlanDescription = physicalPlanDescription
+    exec.metrics = sqlPlanMetrics
+    update(exec)
+  }
+
   private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = {
     val SparkListenerSQLExecutionEnd(executionId, time) = event
     Option(liveExecutions.get(executionId)).foreach { exec =>
@@ -320,6 +343,7 @@ class SQLAppStatusListener(
 
   override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
     case e: SparkListenerSQLExecutionStart => onExecutionStart(e)
+    case e: SparkListenerSQLAdaptiveExecutionUpdate => onAdaptiveExecutionUpdate(e)
     case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e)
     case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e)
     case _ => // Ignore
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 03d75c4..67d1f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -28,6 +28,14 @@ import org.apache.spark.scheduler._
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
+/** Carries an updated physical plan (description and [[SparkPlanInfo]]) of a SQL execution. */
 @DeveloperApi
+case class SparkListenerSQLAdaptiveExecutionUpdate(
+  executionId: Long,
+  physicalPlanDescription: String,
+  sparkPlanInfo: SparkPlanInfo)
+  extends SparkListenerEvent
+
+@DeveloperApi
 case class SparkListenerSQLExecutionStart(
     executionId: Long,
     description: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index b864ad1..2b7597e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -96,6 +96,15 @@ object SparkPlanGraph {
       case "InputAdapter" =>
         buildSparkPlanGraphNode(
           planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case "BroadcastQueryStage" | "ShuffleQueryStage" =>
+        if (exchanges.contains(planInfo.children.head)) {
+          // Point to the re-used exchange
+          val node = exchanges(planInfo.children.head)
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        } else {
+          buildSparkPlanGraphNode(
+            planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+        }
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
         buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
new file mode 100644
index 0000000..2cddf7c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class AdaptiveQueryExecSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  setupTestData()
+
+  /**
+   * Runs `query` with adaptive execution enabled and checks that its answer matches the one
+   * produced with adaptive execution disabled. Also checks that the executed plan is marked
+   * non-final before execution and final afterwards, and that the final plan contains no
+   * [[Exchange]] node.
+   *
+   * @return the initial physical plan (`queryExecution.sparkPlan`) and the final adaptive plan.
+   */
+  private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+    val dfAdaptive = sql(query)
+    val planBefore = dfAdaptive.queryExecution.executedPlan
+    assert(planBefore.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=false)"))
+    val result = dfAdaptive.collect()
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "false") {
+      val df = sql(query)
+      // Fails the test if the adaptive and non-adaptive answers differ.
+      checkAnswer(df, result.toSeq)
+    }
+    val planAfter = dfAdaptive.queryExecution.executedPlan
+    assert(planAfter.toString.startsWith("AdaptiveSparkPlan(isFinalPlan=true)"))
+    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    val exchanges = adaptivePlan.collect {
+      case e: Exchange => e
+    }
+    assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+    (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+  }
+
+  /**
+   * Collects the [[BroadcastHashJoinExec]] nodes of `plan`, descending into query stages but
+   * not into subqueries.
+   */
+  private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
+    plan.collect {
+      case j: BroadcastHashJoinExec => Seq(j)
+      case s: QueryStageExec => findTopLevelBroadcastHashJoin(s.plan)
+    }.flatten
+  }
+
+  /**
+   * Collects the [[SortMergeJoinExec]] nodes of `plan`, descending into query stages but not
+   * into subqueries.
+   */
+  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
+    plan.collect {
+      case j: SortMergeJoinExec => Seq(j)
+      case s: QueryStageExec => findTopLevelSortMergeJoin(s.plan)
+    }.flatten
+  }
+
+  /**
+   * Collects the [[ReusedQueryStageExec]] nodes of `plan`, descending into query stages,
+   * adaptive plans and subqueries.
+   */
+  private def findReusedExchange(plan: SparkPlan): Seq[ReusedQueryStageExec] = {
+    plan.collect {
+      case e: ReusedQueryStageExec => Seq(e)
+      case a: AdaptiveSparkPlanExec => findReusedExchange(a.executedPlan)
+      case s: QueryStageExec => findReusedExchange(s.plan)
+      case p: SparkPlan => p.subqueries.flatMap(findReusedExchange)
+    }.flatten
+  }
+
+  /**
+   * Collects the [[ReusedSubqueryExec]] nodes of `plan`, descending into query stages,
+   * adaptive plans and subqueries.
+   */
+  private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = {
+    plan.collect {
+      case e: ReusedSubqueryExec => Seq(e)
+      case a: AdaptiveSparkPlanExec => findReusedSubquery(a.executedPlan)
+      case s: QueryStageExec => findReusedSubquery(s.plan)
+      case p: SparkPlan => p.subqueries.flatMap(findReusedSubquery)
+    }.flatten
+  }
+
+  test("Change merge join to broadcast join") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a where value = '1'")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("Scalar subquery") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a " +
+        "where value = (SELECT max(a) from testData3)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("Scalar subquery in later stages") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData join testData2 ON key = a " +
+        "where (value + a) = (SELECT max(a) from testData3)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("multiple joins") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON key = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+    }
+  }
+
+  test("multiple joins with aggregate") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN (
+          |    select a, sum(b) from testData3 group by a
+          |  ) t3 ON t2.n = t3.a
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON key = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+    }
+  }
+
+  test("multiple joins with aggregate 2") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |WITH t4 AS (
+          |  SELECT * FROM lowercaseData t2 JOIN (
+          |    select a, max(b) b from testData2 group by a
+          |  ) t3 ON t2.n = t3.b
+          |)
+          |SELECT * FROM testData
+          |JOIN testData2 t2 ON key = t2.a
+          |JOIN t4 ON value = t4.a
+          |WHERE value = 1
+        """.stripMargin)
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 3)
+    }
+  }
+
+  test("Exchange reuse") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT value FROM testData join testData2 ON key = a " +
+        "join (SELECT value v from testData join testData3 ON key = a) on value = v")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 2)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+    }
+  }
+
+  test("Exchange reuse with subqueries") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value = (SELECT max(a) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+    }
+  }
+
+  test("Exchange reuse across subqueries") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+      SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
+        "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.nonEmpty)
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.isEmpty)
+    }
+  }
+
+  test("Subquery reuse") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
+        "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.isEmpty)
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.nonEmpty)
+    }
+  }
+
+  test("Broadcast exchange reuse across subqueries") {
+    withSQLConf(
+      SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
+      SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT a FROM testData join testData2 ON key = a " +
+        "where value >= (" +
+        "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " +
+        "and a <= (" +
+        "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)")
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.nonEmpty)
+      assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec])
+      val sub = findReusedSubquery(adaptivePlan)
+      assert(sub.isEmpty)
+    }
+  }
+
+  test("Union/Except/Intersect queries") {
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") {
+      runAdaptiveAndVerifyResult(
+        """
+          |SELECT * FROM testData
+          |EXCEPT
+          |SELECT * FROM testData2
+          |UNION ALL
+          |SELECT * FROM testData
+          |INTERSECT ALL
+          |SELECT * FROM testData2
+        """.stripMargin)
+    }
+  }
+
+  test("Subquery de-correlation in Union queries") {
+    withSQLConf(SQLConf.RUNTIME_REOPTIMIZATION_ENABLED.key -> "true") {
+      withTempView("a", "b") {
+        Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
+        Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
+
+        runAdaptiveAndVerifyResult(
+          """
+            |SELECT id,num,source FROM (
+            |  SELECT id, num, 'a' as source FROM a
+            |  UNION ALL
+            |  SELECT id, num, 'b' as source FROM b
+            |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+          """.stripMargin)
+      }
+    }
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org