This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a29c9653f3d4 [SPARK-46090][SQL] Support plan fragment level SQL configs in AQE
a29c9653f3d4 is described below

commit a29c9653f3d48d97875ae446d82896bdf0de61ca
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri May 24 14:31:52 2024 +0800

    [SPARK-46090][SQL] Support plan fragment level SQL configs in AQE

    ### What changes were proposed in this pull request?

    This PR introduces `case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean)`,
    which can be used inside adaptive SQL extension rules through a thread local, so that developers
    can modify the configs of the next plan fragment using `AdaptiveRuleContext.get()`.

    The plan fragment configs are propagated across multiple phases, e.g., a config set in
    `queryPostPlannerStrategyRules` can be retrieved in `queryStagePrepRules`,
    `queryStageOptimizerRules` and `columnarRules`. The configs are cleared before execution,
    so they start out empty in the next round.

    ### Why are the changes needed?

    To support modifying plan fragment level SQL configs through AQE rules.

    ### Does this PR introduce _any_ user-facing change?

    No, this only affects developers.

    ### How was this patch tested?

    Added new tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44013 from ulysses-you/rule-context.

    Lead-authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Co-authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: youxiduo <youxi...@corp.netease.com>
---
 .../execution/adaptive/AdaptiveRuleContext.scala   |  89 ++++++++++
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  42 ++++-
 .../adaptive/AdaptiveRuleContextSuite.scala        | 176 +++++++++++++++++++++
 3 files changed, 299 insertions(+), 8 deletions(-)
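To make the developer-facing flow concrete, here is a minimal sketch (not part of the patch) of a rule that uses the new API. `MyPostPlannerRule` and `MyExtensions` are hypothetical names; `AdaptiveRuleContext.get()`, `setConfig`, and `injectQueryPostPlannerStrategyRule` come from this patch and the existing `SparkSessionExtensions` API:

```scala
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveRuleContext

// Hypothetical rule: set a config for the next plan fragment. The context is
// only defined when the rule runs inside AQE, hence the foreach on the Option.
object MyPostPlannerRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    AdaptiveRuleContext.get().foreach { ctx =>
      if (!ctx.isFinalStage) {
        ctx.setConfig("spark.sql.shuffle.partitions", "2")
      }
    }
    plan
  }
}

// Hypothetical wiring: a config set above is visible to rules that run in
// later phases of the same plan fragment, e.g. query stage optimizer rules.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectQueryPostPlannerStrategyRule(_ => MyPostPlannerRule)
  }
}
```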
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
new file mode 100644
index 000000000000..fce20b79e113
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.sql.catalyst.SQLConfHelper
+
+/**
+ * Provides the functionality to modify the configs of the next plan fragment in AQE rules.
+ * The configs will be cleared before the next plan fragment is executed.
+ *
+ * To get an instance, use: {{{ AdaptiveRuleContext.get() }}}
+ *
+ * @param isSubquery whether the input query plan is a subquery
+ * @param isFinalStage whether the next stage is the final stage
+ */
+@Experimental
+@DeveloperApi
+case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) {
+
+  /**
+   * Set SQL configs for the next plan fragment. The configs will affect all of the rules in AQE,
+   * i.e., the runtime optimizer, planner, queryStagePreparationRules, queryStageOptimizerRules,
+   * columnarRules.
+   * These configs will be cleared before fetching the next plan fragment.
+   */
+  private val nextPlanFragmentConf = new mutable.HashMap[String, String]()
+
+  private[sql] def withFinalStage(isFinalStage: Boolean): AdaptiveRuleContext = {
+    if (this.isFinalStage == isFinalStage) {
+      this
+    } else {
+      val newRuleContext = copy(isFinalStage = isFinalStage)
+      newRuleContext.setConfigs(this.configs())
+      newRuleContext
+    }
+  }
+
+  def setConfig(key: String, value: String): Unit = {
+    nextPlanFragmentConf.put(key, value)
+  }
+
+  def setConfigs(kvs: Map[String, String]): Unit = {
+    kvs.foreach(kv => nextPlanFragmentConf.put(kv._1, kv._2))
+  }
+
+  private[sql] def configs(): Map[String, String] = nextPlanFragmentConf.toMap
+
+  private[sql] def clearConfigs(): Unit = nextPlanFragmentConf.clear()
+}
+
+object AdaptiveRuleContext extends SQLConfHelper {
+  private val ruleContextThreadLocal = new ThreadLocal[AdaptiveRuleContext]
+
+  /**
+   * If a rule is applied inside AQE, the returned value is always defined; otherwise None.
+   */
+  def get(): Option[AdaptiveRuleContext] = Option(ruleContextThreadLocal.get())
+
+  private[sql] def withRuleContext[T](ruleContext: AdaptiveRuleContext)(block: => T): T = {
+    assert(ruleContext != null)
+    val origin = ruleContextThreadLocal.get()
+    ruleContextThreadLocal.set(ruleContext)
+    try {
+      val conf = ruleContext.configs()
+      withSQLConf(conf.toSeq: _*) {
+        block
+      }
+    } finally {
+      ruleContextThreadLocal.set(origin)
+    }
+  }
+}
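`withRuleContext` above is a standard save-and-restore thread-local scope. A self-contained sketch of the same idiom, independent of Spark and with all names hypothetical:

```scala
// Minimal stand-alone illustration of the save/restore idiom used by
// AdaptiveRuleContext.withRuleContext: the previous value is restored even if
// the block throws.
object ContextScope {
  private val tl = new ThreadLocal[String]

  def current: Option[String] = Option(tl.get())

  def withValue[T](value: String)(block: => T): T = {
    val saved = tl.get()             // may be null when no enclosing scope exists
    tl.set(value)
    try block finally tl.set(saved)  // restore the enclosing scope's value
  }
}
```

Restoring the previous value instead of clearing it keeps re-entrant use on the same thread, such as subquery planning nested inside main-query planning, well behaved.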
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f30ffaf51566..f21960aeedd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -85,6 +85,25 @@ case class AdaptiveSparkPlanExec(
     case _ => logDebug(_)
   }
 
+  @transient private var ruleContext = new AdaptiveRuleContext(
+    isSubquery = isSubquery,
+    isFinalStage = false)
+
+  private def withRuleContext[T](f: => T): T =
+    AdaptiveRuleContext.withRuleContext(ruleContext) { f }
+
+  private def applyPhysicalRulesWithRuleContext(
+      plan: => SparkPlan,
+      rules: Seq[Rule[SparkPlan]],
+      loggerAndBatchName: Option[(PlanChangeLogger[SparkPlan], String)] = None): SparkPlan = {
+    // Force the previous batch of rules, if any, before applying the next batch,
+    // so that the configs set by earlier rules can be propagated.
+    val newPlan = plan
+    withRuleContext {
+      applyPhysicalRules(newPlan, rules, loggerAndBatchName)
+    }
+  }
+
   @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
   // The logical plan optimizer for re-optimizing the current logical plan.
@@ -161,7 +180,9 @@ case class AdaptiveSparkPlanExec(
     collapseCodegenStagesRule
   )
 
-  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+  private def optimizeQueryStage(
+      plan: SparkPlan,
+      isFinalStage: Boolean): SparkPlan = withRuleContext {
     val rules = if (isFinalStage &&
         !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
       queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
@@ -197,7 +218,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       plan,
       context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
       Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
@@ -205,7 +226,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   @transient val initialPlan = context.session.withActive {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       applyQueryPostPlannerStrategyRules(inputPlan),
       queryStagePreparationRules,
       Some((planChangeLogger, "AQE Preparations")))
@@ -282,6 +303,7 @@ case class AdaptiveSparkPlanExec(
       val errors = new mutable.ArrayBuffer[Throwable]()
       var stagesToReplace = Seq.empty[QueryStageExec]
       while (!result.allChildStagesMaterialized) {
+        ruleContext.clearConfigs()
         currentPhysicalPlan = result.newPlan
         if (result.newStages.nonEmpty) {
           stagesToReplace = result.newStages ++ stagesToReplace
@@ -373,11 +395,13 @@ case class AdaptiveSparkPlanExec(
         result = createQueryStages(currentPhysicalPlan)
       }
 
+      ruleContext = ruleContext.withFinalStage(isFinalStage = true)
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(
+      currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
         optimizeQueryStage(result.newPlan, isFinalStage = true),
         postStageCreationRules(supportsColumnar),
         Some((planChangeLogger, "AQE Post Stage Creation")))
+      ruleContext.clearConfigs()
       _isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -595,7 +619,7 @@ case class AdaptiveSparkPlanExec(
     val queryStage = plan match {
       case e: Exchange =>
         val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
-        val newPlan = applyPhysicalRules(
+        val newPlan = applyPhysicalRulesWithRuleContext(
           optimized,
           postStageCreationRules(outputsColumnar = plan.supportsColumnar),
           Some((planChangeLogger, "AQE Post Stage Creation")))
@@ -722,9 +746,11 @@ case class AdaptiveSparkPlanExec(
   private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
     try {
       logicalPlan.invalidateStatsCache()
-      val optimized = optimizer.execute(logicalPlan)
-      val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
-      val newPlan = applyPhysicalRules(
+      val optimized = withRuleContext { optimizer.execute(logicalPlan) }
+      val sparkPlan = withRuleContext {
+        context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+      }
+      val newPlan = applyPhysicalRulesWithRuleContext(
         applyQueryPostPlannerStrategyRules(sparkPlan),
         preprocessingRules ++ queryStagePreparationRules,
         Some((planChangeLogger, "AQE Replanning")))
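A subtlety in `applyPhysicalRulesWithRuleContext` above: `plan` is a by-name parameter, so `val newPlan = plan` forces the previous phase first, letting any configs it set on the rule context be installed for the next batch. A stand-alone sketch of that evaluation order, with all names hypothetical:

```scala
// Demonstrates the by-name ordering trick: the argument (the previous phase)
// is evaluated before the context for the next batch is entered.
object ByNameOrdering extends App {
  val log = scala.collection.mutable.Buffer.empty[String]

  def applyWithContext(plan: => String): String = {
    val newPlan = plan                   // previous phase runs here
    log += "enter context for next batch"
    newPlan
  }

  applyWithContext { log += "previous phase rules run"; "plan" }
  assert(log == Seq("previous phase rules run", "enter context for next batch"))
}
```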
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
new file mode 100644
index 000000000000..04c9e6c946b4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensionsProvider}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ColumnarRule, RangeExec, SparkPlan, SparkStrategy}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+
+class AdaptiveRuleContextSuite extends SparkFunSuite with AdaptiveSparkPlanHelper {
+
+  private def stop(spark: SparkSession): Unit = {
+    spark.stop()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
+
+  private def withSession(
+      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
+    try f(spark) finally {
+      stop(spark)
+    }
+  }
+
+  test("test adaptive rule context") {
+    withSession(
+      Seq(_.injectRuntimeOptimizerRule(_ => MyRuleContextForRuntimeOptimization),
+        _.injectPlannerStrategy(_ => MyRuleContextForPlannerStrategy),
+        _.injectQueryPostPlannerStrategyRule(_ => MyRuleContextForPostPlannerStrategyRule),
+        _.injectQueryStagePrepRule(_ => MyRuleContextForPreQueryStageRule),
+        _.injectQueryStageOptimizerRule(_ => MyRuleContextForQueryStageRule),
+        _.injectColumnar(_ => MyRuleContextForColumnarRule))) { spark =>
+      val df = spark.range(1, 10, 1, 3).selectExpr("id % 3 as c").groupBy("c").count()
+      df.collect()
+      assert(collectFirst(df.queryExecution.executedPlan) {
+        case s: ShuffleExchangeExec if s.numPartitions == 2 => s
+      }.isDefined)
+    }
+  }
+
+  test("test adaptive rule context with subquery") {
+    withSession(
+      Seq(_.injectQueryStagePrepRule(_ => MyRuleContextForQueryStageWithSubquery))) { spark =>
+      spark.sql("select (select count(*) from range(10)), id from range(10)").collect()
+    }
+  }
+}
+
+object MyRuleContext {
+  def checkAndGetRuleContext(): AdaptiveRuleContext = {
+    val ruleContextOpt = AdaptiveRuleContext.get()
+    assert(ruleContextOpt.isDefined)
+    ruleContextOpt.get
+  }
+
+  def checkRuleContextForQueryStage(plan: SparkPlan): SparkPlan = {
+    val ruleContext = checkAndGetRuleContext()
+    assert(!ruleContext.isSubquery)
+    val stage = plan.find(_.isInstanceOf[ShuffleQueryStageExec])
+    if (stage.isDefined && stage.get.asInstanceOf[ShuffleQueryStageExec].isMaterialized) {
+      assert(ruleContext.isFinalStage)
+      assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    } else {
+      assert(!ruleContext.isFinalStage)
+      assert(ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    }
+    plan
+  }
+}
+
+object MyRuleContextForRuntimeOptimization extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    MyRuleContext.checkAndGetRuleContext()
+    plan
+  }
+}
+
+object MyRuleContextForPlannerStrategy extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      case _: LogicalQueryStage =>
+        val ruleContext = MyRuleContext.checkAndGetRuleContext()
+        assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+        Nil
+      case _ => Nil
+    }
+  }
+}
+
+object MyRuleContextForPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.find(_.isInstanceOf[RangeExec]).isDefined) {
+      ruleContext.setConfig("spark.sql.shuffle.partitions", "2")
+    }
+    plan
+  }
+}
+
+object MyRuleContextForPreQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    assert(!ruleContext.isFinalStage)
+    plan
+  }
+}
+
+object MyRuleContextForQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    MyRuleContext.checkRuleContextForQueryStage(plan)
+  }
+}
+
+object MyRuleContextForColumnarRule extends ColumnarRule {
+  override def preColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // skip if we are not inside AQE
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+
+  override def postColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // skip if we are not inside AQE
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+}
+
+object MyRuleContextForQueryStageWithSubquery extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.exists(_.isInstanceOf[HashAggregateExec])) {
+      assert(ruleContext.isSubquery)
+      if (plan.exists(_.isInstanceOf[RangeExec])) {
+        assert(!ruleContext.isFinalStage)
+      } else {
+        assert(ruleContext.isFinalStage)
+      }
+    } else {
+      assert(!ruleContext.isSubquery)
+    }
+    plan
+  }
+}
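As a usage note: besides `SparkSession.builder().withExtensions(...)` as in the suite above, extensions can also be registered declaratively via the `spark.sql.extensions` configuration. `com.example.MyExtensions` below is a placeholder class name for something like the hypothetical `MyExtensions` sketched earlier:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical wiring via configuration: spark.sql.extensions takes a
// comma-separated list of classes implementing SparkSessionExtensions => Unit.
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.extensions", "com.example.MyExtensions")
  .getOrCreate()
```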
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org