Github user wangyum commented on a diff in the pull request: https://github.com/apache/spark/pull/20303#discussion_r164070918 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala --- @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration + +import org.apache.spark.MapOutputStatistics +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.util.ThreadUtils + +/** + * In adaptive execution mode, an execution plan is divided into multiple QueryStages. Each + * QueryStage is a sub-tree that runs in a single stage. + */ +abstract class QueryStage extends UnaryExecNode { + + var child: SparkPlan + + // Ignore this wrapper for canonicalizing. 
+ override def doCanonicalize(): SparkPlan = child.canonicalized + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + /** + * Execute childStages and wait until all stages are completed. Use a thread pool to avoid + * blocking on one child stage. + */ + def executeChildStages(): Unit = { + // Handle broadcast stages + val broadcastQueryStages: Seq[BroadcastQueryStage] = child.collect { + case bqs: BroadcastQueryStageInput => bqs.childStage + } + val broadcastFutures = broadcastQueryStages.map { queryStage => + Future { queryStage.prepareBroadcast() }(QueryStage.executionContext) + } + + // Submit shuffle stages + val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val shuffleQueryStages: Seq[ShuffleQueryStage] = child.collect { + case sqs: ShuffleQueryStageInput => sqs.childStage + } + val shuffleStageFutures = shuffleQueryStages.map { queryStage => + Future { + SQLExecution.withExecutionId(sqlContext.sparkContext, executionId) { + queryStage.execute() + } + }(QueryStage.executionContext) + } + + ThreadUtils.awaitResult( + Future.sequence(broadcastFutures)(implicitly, QueryStage.executionContext), Duration.Inf) + ThreadUtils.awaitResult( + Future.sequence(shuffleStageFutures)(implicitly, QueryStage.executionContext), Duration.Inf) + } + + /** + * Before executing the plan in this query stage, we execute all child stages, optimize the plan + * in this stage and determine the reducer number based on the child stages' statistics. Finally + * we do a codegen for this query stage and update the UI with the new plan. + */ + def prepareExecuteStage(): Unit = { + // 1. Execute childStages + executeChildStages() + // It is possible to optimize this stage's plan here based on the child stages' statistics. + + // 2. 
Determine reducer number + val queryStageInputs: Seq[ShuffleQueryStageInput] = child.collect { + case input: ShuffleQueryStageInput => input + } + val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics) + .filter(_ != null).toArray + if (childMapOutputStatistics.length > 0) { + val exchangeCoordinator = new ExchangeCoordinator( + conf.targetPostShuffleInputSize, + conf.minNumPostShufflePartitions) + + val partitionStartIndices = + exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics) + child = child.transform { + case ShuffleQueryStageInput(childStage, output, _) => + ShuffleQueryStageInput(childStage, output, Some(partitionStartIndices)) + } + } + + // 3. Codegen and update the UI + child = CollapseCodegenStages(sqlContext.conf).apply(child) --- End diff --

How about changing this line to the following, so that a child that is already a `WholeStageCodegenExec` is kept as is and only other plans are passed through `CollapseCodegenStages`?

```scala
child = child match {
  case s: WholeStageCodegenExec => s
  case other => CollapseCodegenStages(sqlContext.conf).apply(other)
}
```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org