This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a29c9653f3d4 [SPARK-46090][SQL] Support plan fragment level SQL configs in AQE
a29c9653f3d4 is described below

commit a29c9653f3d48d97875ae446d82896bdf0de61ca
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri May 24 14:31:52 2024 +0800

    [SPARK-46090][SQL] Support plan fragment level SQL configs in AQE
    
    ### What changes were proposed in this pull request?
    
    This PR introduces `case class AdaptiveRuleContext(isSubquery: Boolean,
    isFinalStage: Boolean)`, which is made available to adaptive SQL extension
    rules through a thread-local, so that developers can modify the SQL configs
    of the next plan fragment using `AdaptiveRuleContext.get()`.
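    
    For illustration, here is a minimal sketch of a rule that reads the context
    and tunes the next plan fragment. The rule name and the chosen config value
    below are examples only, not part of this change:
    
        import org.apache.spark.sql.catalyst.rules.Rule
        import org.apache.spark.sql.execution.SparkPlan
        import org.apache.spark.sql.execution.adaptive.AdaptiveRuleContext
    
        // Hypothetical rule, e.g. registered via injectQueryPostPlannerStrategyRule.
        object TunePartitionsRule extends Rule[SparkPlan] {
          override def apply(plan: SparkPlan): SparkPlan = {
            // AdaptiveRuleContext.get() is defined only when the rule runs inside AQE.
            AdaptiveRuleContext.get().foreach { ctx =>
              if (!ctx.isFinalStage) {
                // Only affects the next plan fragment; cleared before execution.
                ctx.setConfig("spark.sql.shuffle.partitions", "2")
              }
            }
            plan
          }
        }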
    
    The plan fragment configs are propagated across multiple phases, e.g., a
    config set in `queryPostPlannerStrategyRules` can be read in
    `queryStagePrepRules`, `queryStageOptimizerRules` and `columnarRules`. The
    configs are cleared before execution, so in the next round the configs
    start out empty.
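    
    As a sketch of that propagation (the provider and rule names below are
    illustrative, and `TunePartitionsRule` refers to the example above), a
    config set in the post-planner phase can be observed in a later phase
    through the rule's `conf`, since the context configs are applied via
    `withSQLConf` around each phase:
    
        import org.apache.spark.sql.SparkSessionExtensions
        import org.apache.spark.sql.catalyst.rules.Rule
        import org.apache.spark.sql.execution.SparkPlan
    
        // Hypothetical extensions provider wiring the two rules together.
        class MyRuleContextExtensions extends (SparkSessionExtensions => Unit) {
          override def apply(ext: SparkSessionExtensions): Unit = {
            // Sets "spark.sql.shuffle.partitions" for the next plan fragment.
            ext.injectQueryPostPlannerStrategyRule(_ => TunePartitionsRule)
            // Runs in a later phase and still observes the propagated value.
            ext.injectQueryStageOptimizerRule(_ => ReadPropagatedConfRule)
          }
        }
    
        object ReadPropagatedConfRule extends Rule[SparkPlan] {
          override def apply(plan: SparkPlan): SparkPlan = {
            // The propagated config is visible here as a regular SQL conf.
            logInfo(s"shuffle partitions for this fragment: ${conf.numShufflePartitions}")
            plan
          }
        }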
    
    ### Why are the changes needed?
    
    To support modifying plan-fragment-level SQL configs through AQE rules.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it only affects developers.
    
    ### How was this patch tested?
    
    Added new tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #44013 from ulysses-you/rule-context.
    
    Lead-authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Co-authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: youxiduo <youxi...@corp.netease.com>
---
 .../execution/adaptive/AdaptiveRuleContext.scala   |  89 +++++++++++
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  42 ++++-
 .../adaptive/AdaptiveRuleContextSuite.scala        | 176 +++++++++++++++++++++
 3 files changed, 299 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
new file mode 100644
index 000000000000..fce20b79e113
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.sql.catalyst.SQLConfHelper
+
+/**
+ * Provides the functionality to modify the configs of the next plan fragment in AQE rules.
+ * The configs will be cleared before executing the next plan fragment.
+ * To get an instance, use: {{{ AdaptiveRuleContext.get() }}}
+ *
+ * @param isSubquery whether the input query plan is a subquery
+ * @param isFinalStage whether the next stage is the final stage
+ */
+@Experimental
+@DeveloperApi
+case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) {
+
+  /**
+   * Sets SQL configs for the next plan fragment. The configs will affect all of the rules in AQE,
+   * i.e., the runtime optimizer, planner, queryStagePreparationRules, queryStageOptimizerRules,
+   * columnarRules.
+   * These configs will be cleared before getting the next plan fragment.
+   */
+  private val nextPlanFragmentConf = new mutable.HashMap[String, String]()
+
+  private[sql] def withFinalStage(isFinalStage: Boolean): AdaptiveRuleContext = {
+    if (this.isFinalStage == isFinalStage) {
+      this
+    } else {
+      val newRuleContext = copy(isFinalStage = isFinalStage)
+      newRuleContext.setConfigs(this.configs())
+      newRuleContext
+    }
+  }
+
+  def setConfig(key: String, value: String): Unit = {
+    nextPlanFragmentConf.put(key, value)
+  }
+
+  def setConfigs(kvs: Map[String, String]): Unit = {
+    kvs.foreach(kv => nextPlanFragmentConf.put(kv._1, kv._2))
+  }
+
+  private[sql] def configs(): Map[String, String] = nextPlanFragmentConf.toMap
+
+  private[sql] def clearConfigs(): Unit = nextPlanFragmentConf.clear()
+}
+
+object AdaptiveRuleContext extends SQLConfHelper {
+  private val ruleContextThreadLocal = new ThreadLocal[AdaptiveRuleContext]
+
+  /**
+   * If a rule is applied inside AQE then the returned value is always defined, else return None.
+   */
+  def get(): Option[AdaptiveRuleContext] = Option(ruleContextThreadLocal.get())
+
+  private[sql] def withRuleContext[T](ruleContext: AdaptiveRuleContext)(block: => T): T = {
+    assert(ruleContext != null)
+    val origin = ruleContextThreadLocal.get()
+    ruleContextThreadLocal.set(ruleContext)
+    try {
+      val conf = ruleContext.configs()
+      withSQLConf(conf.toSeq: _*) {
+        block
+      }
+    } finally {
+      ruleContextThreadLocal.set(origin)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f30ffaf51566..f21960aeedd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -85,6 +85,25 @@ case class AdaptiveSparkPlanExec(
       case _ => logDebug(_)
     }
 
+  @transient private var ruleContext = new AdaptiveRuleContext(
+    isSubquery = isSubquery,
+    isFinalStage = false)
+
+  private def withRuleContext[T](f: => T): T =
+    AdaptiveRuleContext.withRuleContext(ruleContext) { f }
+
+  private def applyPhysicalRulesWithRuleContext(
+      plan: => SparkPlan,
+      rules: Seq[Rule[SparkPlan]],
+      loggerAndBatchName: Option[(PlanChangeLogger[SparkPlan], String)] = None): SparkPlan = {
+    // Apply the previous rules, if any, before applying the next batch of rules,
+    // so that we can propagate the configs.
+    val newPlan = plan
+    withRuleContext {
+      applyPhysicalRules(newPlan, rules, loggerAndBatchName)
+    }
+  }
+
   @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
   // The logical plan optimizer for re-optimizing the current logical plan.
@@ -161,7 +180,9 @@ case class AdaptiveSparkPlanExec(
     collapseCodegenStagesRule
   )
 
-  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+  private def optimizeQueryStage(
+      plan: SparkPlan,
+      isFinalStage: Boolean): SparkPlan = withRuleContext {
     val rules = if (isFinalStage &&
         !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
       queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
@@ -197,7 +218,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       plan,
       context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
       Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
@@ -205,7 +226,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   @transient val initialPlan = context.session.withActive {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       applyQueryPostPlannerStrategyRules(inputPlan),
       queryStagePreparationRules,
       Some((planChangeLogger, "AQE Preparations")))
@@ -282,6 +303,7 @@ case class AdaptiveSparkPlanExec(
       val errors = new mutable.ArrayBuffer[Throwable]()
       var stagesToReplace = Seq.empty[QueryStageExec]
       while (!result.allChildStagesMaterialized) {
+        ruleContext.clearConfigs()
         currentPhysicalPlan = result.newPlan
         if (result.newStages.nonEmpty) {
           stagesToReplace = result.newStages ++ stagesToReplace
@@ -373,11 +395,13 @@ case class AdaptiveSparkPlanExec(
         result = createQueryStages(currentPhysicalPlan)
       }
 
+      ruleContext = ruleContext.withFinalStage(isFinalStage = true)
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(
+      currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
         optimizeQueryStage(result.newPlan, isFinalStage = true),
         postStageCreationRules(supportsColumnar),
         Some((planChangeLogger, "AQE Post Stage Creation")))
+      ruleContext.clearConfigs()
       _isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -595,7 +619,7 @@ case class AdaptiveSparkPlanExec(
     val queryStage = plan match {
       case e: Exchange =>
         val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
-        val newPlan = applyPhysicalRules(
+        val newPlan = applyPhysicalRulesWithRuleContext(
           optimized,
           postStageCreationRules(outputsColumnar = plan.supportsColumnar),
           Some((planChangeLogger, "AQE Post Stage Creation")))
@@ -722,9 +746,11 @@ case class AdaptiveSparkPlanExec(
   private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
     try {
       logicalPlan.invalidateStatsCache()
-      val optimized = optimizer.execute(logicalPlan)
-      val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
-      val newPlan = applyPhysicalRules(
+      val optimized = withRuleContext { optimizer.execute(logicalPlan) }
+      val sparkPlan = withRuleContext {
+        context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+      }
+      val newPlan = applyPhysicalRulesWithRuleContext(
         applyQueryPostPlannerStrategyRules(sparkPlan),
         preprocessingRules ++ queryStagePreparationRules,
         Some((planChangeLogger, "AQE Replanning")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
new file mode 100644
index 000000000000..04c9e6c946b4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensionsProvider}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ColumnarRule, RangeExec, SparkPlan, SparkStrategy}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+
+class AdaptiveRuleContextSuite extends SparkFunSuite with AdaptiveSparkPlanHelper {
+
+  private def stop(spark: SparkSession): Unit = {
+    spark.stop()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
+
+  private def withSession(
+      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
+    try f(spark) finally {
+      stop(spark)
+    }
+  }
+
+  test("test adaptive rule context") {
+    withSession(
+      Seq(_.injectRuntimeOptimizerRule(_ => MyRuleContextForRuntimeOptimization),
+        _.injectPlannerStrategy(_ => MyRuleContextForPlannerStrategy),
+        _.injectQueryPostPlannerStrategyRule(_ => MyRuleContextForPostPlannerStrategyRule),
+        _.injectQueryStagePrepRule(_ => MyRuleContextForPreQueryStageRule),
+        _.injectQueryStageOptimizerRule(_ => MyRuleContextForQueryStageRule),
+        _.injectColumnar(_ => MyRuleContextForColumnarRule))) { spark =>
+      val df = spark.range(1, 10, 1, 3).selectExpr("id % 3 as c").groupBy("c").count()
+      df.collect()
+      assert(collectFirst(df.queryExecution.executedPlan) {
+        case s: ShuffleExchangeExec if s.numPartitions == 2 => s
+      }.isDefined)
+    }
+  }
+
+  test("test adaptive rule context with subquery") {
+    withSession(
+      Seq(_.injectQueryStagePrepRule(_ => MyRuleContextForQueryStageWithSubquery))) { spark =>
+      spark.sql("select (select count(*) from range(10)), id from range(10)").collect()
+    }
+  }
+}
+
+object MyRuleContext {
+  def checkAndGetRuleContext(): AdaptiveRuleContext = {
+    val ruleContextOpt = AdaptiveRuleContext.get()
+    assert(ruleContextOpt.isDefined)
+    ruleContextOpt.get
+  }
+
+  def checkRuleContextForQueryStage(plan: SparkPlan): SparkPlan = {
+    val ruleContext = checkAndGetRuleContext()
+    assert(!ruleContext.isSubquery)
+    val stage = plan.find(_.isInstanceOf[ShuffleQueryStageExec])
+    if (stage.isDefined && stage.get.asInstanceOf[ShuffleQueryStageExec].isMaterialized) {
+      assert(ruleContext.isFinalStage)
+      assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    } else {
+      assert(!ruleContext.isFinalStage)
+      assert(ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    }
+    plan
+  }
+}
+
+object MyRuleContextForRuntimeOptimization extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    MyRuleContext.checkAndGetRuleContext()
+    plan
+  }
+}
+
+object MyRuleContextForPlannerStrategy extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      case _: LogicalQueryStage =>
+        val ruleContext = MyRuleContext.checkAndGetRuleContext()
+        assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+        Nil
+      case _ => Nil
+    }
+  }
+}
+
+object MyRuleContextForPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.find(_.isInstanceOf[RangeExec]).isDefined) {
+      ruleContext.setConfig("spark.sql.shuffle.partitions", "2")
+    }
+    plan
+  }
+}
+
+object MyRuleContextForPreQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    assert(!ruleContext.isFinalStage)
+    plan
+  }
+}
+
+object MyRuleContextForQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    MyRuleContext.checkRuleContextForQueryStage(plan)
+  }
+}
+
+object MyRuleContextForColumnarRule extends ColumnarRule {
+  override def preColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // skip if we are not inside AQE
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+
+  override def postColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // skip if we are not inside AQE
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+}
+
+object MyRuleContextForQueryStageWithSubquery extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.exists(_.isInstanceOf[HashAggregateExec])) {
+      assert(ruleContext.isSubquery)
+      if (plan.exists(_.isInstanceOf[RangeExec])) {
+        assert(!ruleContext.isFinalStage)
+      } else {
+        assert(ruleContext.isFinalStage)
+      }
+    } else {
+      assert(!ruleContext.isSubquery)
+    }
+    plan
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
