This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 883596a4bab [SPARK-38697][SQL] Extend SparkSessionExtensions to inject rules into AQE Optimizer 883596a4bab is described below commit 883596a4bab36ddf0e1a5af0ba98325ca8582550 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Fri Apr 15 16:02:00 2022 +0800 [SPARK-38697][SQL] Extend SparkSessionExtensions to inject rules into AQE Optimizer ### What changes were proposed in this pull request? Add an `injectRuntimeOptimizerRule` public method in `SparkSessionExtensions` ### Why are the changes needed? Provide an entry point for users to work with their logical plan via the runtime optimizer in the adaptive query execution framework. We should follow the existing Spark session extension mechanism to allow users to inject the rule. This way, developers can improve the logical plan by leveraging accurate statistics from shuffle. ### Does this PR introduce _any_ user-facing change? yes, a new entry point for Spark session extensions ### How was this patch tested? Added a test Closes #36011 from ulysses-you/aqe-optimizer. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/SparkSessionExtensions.scala | 22 ++++++++++++++ .../sql/execution/adaptive/AQEOptimizer.scala | 10 ++++--- .../execution/adaptive/AdaptiveRulesHolder.scala | 30 +++++++++++++++++++ .../execution/adaptive/AdaptiveSparkPlanExec.scala | 5 ++-- .../sql/internal/BaseSessionStateBuilder.scala | 11 ++++--- .../apache/spark/sql/internal/SessionState.scala | 4 +-- .../spark/sql/SparkSessionExtensionSuite.scala | 35 ++++++++++++++++++++-- 7 files changed, 102 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a4ec48142cf..a8ccc39ac47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} * <li>(External) Catalog listeners.</li> * <li>Columnar Rules.</li> * <li>Adaptive Query Stage Preparation Rules.</li> + * <li>Adaptive Query Execution Runtime Optimizer Rules.</li> * </ul> * * The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for @@ -113,6 +114,7 @@ class SparkSessionExtensions { private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] + private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder] /** * Build the override rules for columnar execution. @@ -128,6 +130,13 @@ class SparkSessionExtensions { queryStagePrepRuleBuilders.map(_.apply(session)).toSeq } + /** + * Build the override rules for the optimizer of adaptive query execution. 
+ */ + private[sql] def buildRuntimeOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + runtimeOptimizerRules.map(_.apply(session)).toSeq + } + /** * Inject a rule that can override the columnar execution of an executor. */ @@ -143,6 +152,19 @@ class SparkSessionExtensions { queryStagePrepRuleBuilders += builder } + /** + * Inject a runtime `Rule` builder into the [[SparkSession]]. + * The injected rules will be executed after built-in + * [[org.apache.spark.sql.execution.adaptive.AQEOptimizer]] rules are applied. + * A runtime optimizer rule is used to improve the quality of a logical plan during execution + * which can leverage accurate statistics from shuffle. + * + * Note that, it does not work if adaptive query execution is disabled. + */ + def injectRuntimeOptimizerRule(builder: RuleBuilder): Unit = { + runtimeOptimizerRules += builder + } + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index 5533bb1cd79..93fde72993e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper} -import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -28,7 +28,9 @@ import org.apache.spark.util.Utils 
/** * The optimizer for re-optimizing the logical plan used by AdaptiveSparkPlanExec. */ -class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { +class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]]) + extends RuleExecutor[LogicalPlan] { + private def fixedPoint = FixedPoint( conf.optimizerMaxIterations, @@ -41,8 +43,8 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { UpdateAttributeNullability), Batch("Dynamic Join Selection", Once, DynamicJoinSelection), Batch("Eliminate Limits", fixedPoint, EliminateLimits), - Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) - ) + Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+ + Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*) final override protected def batches: Seq[Batch] = { val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala new file mode 100644 index 00000000000..11cae824568 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +/** + * A holder to wrap the SQL extension rules of adaptive query execution + */ +class AdaptiveRulesHolder( + val queryStagePrepRules: Seq[Rule[SparkPlan]], + val runtimeOptimizerRules: Seq[Rule[LogicalPlan]]) { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 25380bc1d89..808959363ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,7 +84,8 @@ case class AdaptiveSparkPlanExec( @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() // The logical plan optimizer for re-optimizing the current logical plan. - @transient private val optimizer = new AQEOptimizer(conf) + @transient private val optimizer = new AQEOptimizer(conf, + session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules) // `EnsureRequirements` may remove user-specified repartition and assume the query plan won't // change its output partitioning. This assumption is not true in AQE. 
Here we check the @@ -121,7 +122,7 @@ case class AdaptiveSparkPlanExec( RemoveRedundantSorts, DisableUnnecessaryBucketedScan, OptimizeSkewedJoin(ensureRequirements) - ) ++ context.session.sessionState.queryStagePrepRules + ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. These diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 0655b946cc8..f3cbb789a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck @@ -308,8 +309,10 @@ abstract class BaseSessionStateBuilder( extensions.buildColumnarRules(session) } - protected def queryStagePrepRules: Seq[Rule[SparkPlan]] = { - extensions.buildQueryStagePrepRules(session) + protected def adaptiveRulesHolder: AdaptiveRulesHolder = { + new AdaptiveRulesHolder( + extensions.buildQueryStagePrepRules(session), + 
extensions.buildRuntimeOptimizerRules(session)) } /** @@ -366,7 +369,7 @@ abstract class BaseSessionStateBuilder( createQueryExecution, createClone, columnarRules, - queryStagePrepRules) + adaptiveRulesHolder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index cdf764a7317..1d5e61aab26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{DependencyUtils, Utils} @@ -79,7 +79,7 @@ private[sql] class SessionState( createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution, createClone: (SparkSession, SessionState) => SessionState, val columnarRules: Seq[ColumnarRule], - val queryStagePrepRules: Seq[Rule[SparkPlan]]) { + val adaptiveRulesHolder: AdaptiveRulesHolder) { // The following fields are lazy to avoid creating the Hive client when creating SessionState. 
lazy val catalog: SessionCatalog = catalogBuilder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 17124cc2e4c..1aef458a352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -27,7 +27,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -45,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * Test cases for the [[SparkSessionExtensions]]. 
*/ -class SparkSessionExtensionSuite extends SparkFunSuite { +class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper { private def create( builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder) @@ -171,7 +172,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } withSession(extensions) { session => session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true) - assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule())) + assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules + .contains(MyQueryStagePrepRule())) assert(session.sessionState.columnarRules.contains( MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule()))) import session.sqlContext.implicits._ @@ -406,6 +408,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite { session.sql("SELECT * FROM v") } } + + test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE Optimizer") { + def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = { + assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) + df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + } + val extensions = create { extensions => + extensions.injectRuntimeOptimizerRule(_ => AddLimit) + } + withSession(extensions) { session => + assert(session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules.contains(AddLimit)) + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = session.range(2).repartition() + assert(!executedPlan(df).isInstanceOf[CollectLimitExec]) + df.collect() + assert(executedPlan(df).isInstanceOf[CollectLimitExec]) + } + } + } } case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { @@ -1024,3 +1046,10 @@ class YourExtensions extends SparkSessionExtensionsProvider { v1.injectFunction(getAppName) } } + +object AddLimit extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case 
Limit(_, _) => plan + case _ => Limit(Literal(1), plan) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org