This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 883596a4bab [SPARK-38697][SQL] Extend SparkSessionExtensions to inject rules into AQE Optimizer
883596a4bab is described below

commit 883596a4bab36ddf0e1a5af0ba98325ca8582550
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Fri Apr 15 16:02:00 2022 +0800

    [SPARK-38697][SQL] Extend SparkSessionExtensions to inject rules into AQE Optimizer
    
    ### What changes were proposed in this pull request?
    
    Add a public `injectRuntimeOptimizerRule` method to `SparkSessionExtensions`.
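
    For illustration only (not part of this patch), an extension could register such a rule through the session builder; `MyRuntimeRule` below is a hypothetical user-defined `Rule[LogicalPlan]`:

    ```scala
    import org.apache.spark.sql.SparkSession

    // Sketch of how the new entry point would be used. MyRuntimeRule is assumed
    // to be a user-defined case class MyRuntimeRule(session: SparkSession)
    // extending Rule[LogicalPlan]; it is not part of this change.
    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions { extensions =>
        extensions.injectRuntimeOptimizerRule(session => MyRuntimeRule(session))
      }
      .getOrCreate()
    ```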
    
    ### Why are the changes needed?
    
    Provide an entry point for users to apply their own rules to the logical plan in the runtime optimizer of the adaptive query execution framework.
    
    We should follow the existing Spark session extension mechanism to let users inject such rules.
    
    This allows developers to improve the logical plan by leveraging accurate statistics from shuffle.
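
    For illustration only (not something added by this patch), a hypothetical rule of this kind could react to the row counts that become available once shuffle stages have materialized; `PruneRedundantLimit` is an assumed example, not an existing Spark rule:

    ```scala
    import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
    import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical runtime optimizer rule: drop a LIMIT once runtime statistics
    // show the child already produces no more rows than the limit would keep.
    object PruneRedundantLimit extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
        case Limit(IntegerLiteral(limit), child)
            if child.stats.rowCount.exists(_ <= limit) =>
          child
      }
    }
    ```

    A rule like this would then be registered with `injectRuntimeOptimizerRule`, as sketched above.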
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds a new entry point to the Spark session extensions.
    
    ### How was this patch tested?
    
    Added a test to `SparkSessionExtensionSuite`.
    
    Closes #36011 from ulysses-you/aqe-optimizer.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/SparkSessionExtensions.scala  | 22 ++++++++++++++
 .../sql/execution/adaptive/AQEOptimizer.scala      | 10 ++++---
 .../execution/adaptive/AdaptiveRulesHolder.scala   | 30 +++++++++++++++++++
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  5 ++--
 .../sql/internal/BaseSessionStateBuilder.scala     | 11 ++++---
 .../apache/spark/sql/internal/SessionState.scala   |  4 +--
 .../spark/sql/SparkSessionExtensionSuite.scala     | 35 ++++++++++++++++++++--
 7 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a4ec48142cf..a8ccc39ac47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  * <li>(External) Catalog listeners.</li>
  * <li>Columnar Rules.</li>
  * <li>Adaptive Query Stage Preparation Rules.</li>
+ * <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
  * </ul>
  *
  * The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
@@ -113,6 +114,7 @@ class SparkSessionExtensions {
 
   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
   private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
+  private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
 
   /**
    * Build the override rules for columnar execution.
@@ -128,6 +130,13 @@ class SparkSessionExtensions {
     queryStagePrepRuleBuilders.map(_.apply(session)).toSeq
   }
 
+  /**
+   * Build the override rules for the optimizer of adaptive query execution.
+   */
+  private[sql] def buildRuntimeOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    runtimeOptimizerRules.map(_.apply(session)).toSeq
+  }
+
   /**
    * Inject a rule that can override the columnar execution of an executor.
    */
@@ -143,6 +152,19 @@ class SparkSessionExtensions {
     queryStagePrepRuleBuilders += builder
   }
 
+  /**
+   * Inject a runtime `Rule` builder into the [[SparkSession]].
+   * The injected rules will be executed after built-in
+   * [[org.apache.spark.sql.execution.adaptive.AQEOptimizer]] rules are applied.
+   * A runtime optimizer rule is used to improve the quality of a logical plan during execution,
+   * and it can leverage accurate statistics from shuffle.
+   *
+   * Note that it does not work if adaptive query execution is disabled.
+   */
+  def injectRuntimeOptimizerRule(builder: RuleBuilder): Unit = {
+    runtimeOptimizerRules += builder
+  }
+
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index 5533bb1cd79..93fde72993e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -28,7 +28,9 @@ import org.apache.spark.util.Utils
 /**
  * The optimizer for re-optimizing the logical plan used by AdaptiveSparkPlanExec.
  */
-class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
+class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]])
+  extends RuleExecutor[LogicalPlan] {
+
   private def fixedPoint =
     FixedPoint(
       conf.optimizerMaxIterations,
@@ -41,8 +43,8 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
       UpdateAttributeNullability),
     Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
     Batch("Eliminate Limits", fixedPoint, EliminateLimits),
-    Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)
-  )
+    Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+
+    Batch("User Provided Runtime Optimizers", fixedPoint, 
extendedRuntimeOptimizerRules: _*)
 
   final override protected def batches: Seq[Batch] = {
     val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
new file mode 100644
index 00000000000..11cae824568
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A holder to wrap the SQL extension rules of adaptive query execution.
+ */
+class AdaptiveRulesHolder(
+    val queryStagePrepRules: Seq[Rule[SparkPlan]],
+    val runtimeOptimizerRules: Seq[Rule[LogicalPlan]]) {
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 25380bc1d89..808959363ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -84,7 +84,8 @@ case class AdaptiveSparkPlanExec(
   @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
   // The logical plan optimizer for re-optimizing the current logical plan.
-  @transient private val optimizer = new AQEOptimizer(conf)
+  @transient private val optimizer = new AQEOptimizer(conf,
+    session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules)
 
  // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
  // change its output partitioning. This assumption is not true in AQE. Here we check the
@@ -121,7 +122,7 @@ case class AdaptiveSparkPlanExec(
       RemoveRedundantSorts,
       DisableUnnecessaryBucketedScan,
       OptimizeSkewedJoin(ensureRequirements)
-    ) ++ context.session.sessionState.queryStagePrepRules
+    ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules
   }
 
  // A list of physical optimizer rules to be applied to a new stage before its execution. These
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 0655b946cc8..f3cbb789a94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
 import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF}
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
@@ -308,8 +309,10 @@ abstract class BaseSessionStateBuilder(
     extensions.buildColumnarRules(session)
   }
 
-  protected def queryStagePrepRules: Seq[Rule[SparkPlan]] = {
-    extensions.buildQueryStagePrepRules(session)
+  protected def adaptiveRulesHolder: AdaptiveRulesHolder = {
+    new AdaptiveRulesHolder(
+      extensions.buildQueryStagePrepRules(session),
+      extensions.buildRuntimeOptimizerRules(session))
   }
 
   /**
@@ -366,7 +369,7 @@ abstract class BaseSessionStateBuilder(
       createQueryExecution,
       createClone,
       columnarRules,
-      queryStagePrepRules)
+      adaptiveRulesHolder)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index cdf764a7317..1d5e61aab26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -79,7 +79,7 @@ private[sql] class SessionState(
     createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
     createClone: (SparkSession, SessionState) => SessionState,
     val columnarRules: Seq[ColumnarRule],
-    val queryStagePrepRules: Seq[Rule[SparkPlan]]) {
+    val adaptiveRulesHolder: AdaptiveRulesHolder) {
 
  // The following fields are lazy to avoid creating the Hive client when creating SessionState.
   lazy val catalog: SessionCatalog = catalogBuilder()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 17124cc2e4c..1aef458a352 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -27,7 +27,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -45,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * Test cases for the [[SparkSessionExtensions]].
  */
-class SparkSessionExtensionSuite extends SparkFunSuite {
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
   private def create(
      builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder)
 
@@ -171,7 +172,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
     withSession(extensions) { session =>
       session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
-      assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule()))
+      assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules
+        .contains(MyQueryStagePrepRule()))
       assert(session.sessionState.columnarRules.contains(
         MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
       import session.sqlContext.implicits._
@@ -406,6 +408,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
       session.sql("SELECT * FROM v")
     }
   }
+
+  test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE 
Optimizer") {
+    def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = {
+      assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
+      df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    }
+    val extensions = create { extensions =>
+      extensions.injectRuntimeOptimizerRule(_ => AddLimit)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules.contains(AddLimit))
+
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+        val df = session.range(2).repartition()
+        assert(!executedPlan(df).isInstanceOf[CollectLimitExec])
+        df.collect()
+        assert(executedPlan(df).isInstanceOf[CollectLimitExec])
+      }
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1024,3 +1046,10 @@ class YourExtensions extends SparkSessionExtensionsProvider {
     v1.injectFunction(getAppName)
   }
 }
+
+object AddLimit extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case Limit(_, _) => plan
+    case _ => Limit(Literal(1), plan)
+  }
+}

