Repository: spark
Updated Branches:
  refs/heads/branch-2.2 f971ce5dd -> f0de60079


[SPARK-18127] Add hooks and extension points to Spark

## What changes were proposed in this pull request?

This patch adds support for customizing a Spark session by injecting user-defined extensions. This allows a user to add custom analyzer rules/checks, optimizer rules, planning strategies, or even a customized parser.
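
For illustration, a minimal sketch of what this enables via the new `withExtensions` builder method (the `MyOptimizerRule` class below is hypothetical and simply returns the plan unchanged):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical no-op optimizer rule, used only to illustrate the API.
    case class MyOptimizerRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions { extensions =>
        extensions.injectOptimizerRule(MyOptimizerRule)
      }
      .getOrCreate()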

## How was this patch tested?

Unit Tests in SparkSessionExtensionSuite

Author: Sameer Agarwal <samee...@cs.berkeley.edu>

Closes #17724 from sameeragarwal/session-extensions.

(cherry picked from commit caf392025ce21d701b503112060fa016d5eabe04)
Signed-off-by: Xiao Li <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f0de6007
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f0de6007
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f0de6007

Branch: refs/heads/branch-2.2
Commit: f0de600797ff4883927d0c70732675fd8629e239
Parents: f971ce5
Author: Sameer Agarwal <samee...@cs.berkeley.edu>
Authored: Tue Apr 25 17:05:20 2017 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Tue Apr 25 17:05:41 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/parser/ParseDriver.scala |   9 +-
 .../sql/catalyst/parser/ParserInterface.scala   |  35 +++-
 .../spark/sql/internal/StaticSQLConf.scala      |   6 +
 .../org/apache/spark/sql/SparkSession.scala     |  45 ++++-
 .../spark/sql/SparkSessionExtensions.scala      | 171 +++++++++++++++++++
 .../sql/internal/BaseSessionStateBuilder.scala  |  33 +++-
 .../spark/sql/SparkSessionExtensionSuite.scala  | 144 ++++++++++++++++
 7 files changed, 418 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 80ab75c..dcccbd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 abstract class AbstractSqlParser extends ParserInterface with Logging {
 
   /** Creates/Resolves DataType for a given SQL string. */
-  def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
-    // TODO add this to the parser interface.
+  override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
     astBuilder.visitSingleDataType(parser.singleDataType())
   }
 
@@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
   }
 
   /** Creates FunctionIdentifier for a given SQL string. */
-  def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
-    astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+    parse(sqlText) { parser =>
+      astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index db3598b..75240d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -17,30 +17,51 @@
 
 package org.apache.spark.sql.catalyst.parser
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Interface for a parser.
  */
+@DeveloperApi
 trait ParserInterface {
-  /** Creates LogicalPlan for a given SQL string. */
+  /**
+   * Parse a string to a [[LogicalPlan]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
   def parsePlan(sqlText: String): LogicalPlan
 
-  /** Creates Expression for a given SQL string. */
+  /**
+   * Parse a string to an [[Expression]].
+   */
+  @throws[ParseException]("Text cannot be parsed to an Expression")
   def parseExpression(sqlText: String): Expression
 
-  /** Creates TableIdentifier for a given SQL string. */
+  /**
+   * Parse a string to a [[TableIdentifier]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a TableIdentifier")
   def parseTableIdentifier(sqlText: String): TableIdentifier
 
-  /** Creates FunctionIdentifier for a given SQL string. */
+  /**
+   * Parse a string to a [[FunctionIdentifier]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
   def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
 
   /**
-   * Creates StructType for a given SQL string, which is a comma separated list of field
-   * definitions which will preserve the correct Hive metadata.
+   * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+   * of field definitions which will preserve the correct Hive metadata.
    */
+  @throws[ParseException]("Text cannot be parsed to a schema")
   def parseTableSchema(sqlText: String): StructType
+
+  /**
+   * Parse a string to a [[DataType]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a DataType")
+  def parseDataType(sqlText: String): DataType
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index af1a9ce..c6c0a60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -81,4 +81,10 @@ object StaticSQLConf {
         "SQL configuration and the current database.")
       .booleanConf
       .createWithDefault(false)
+
+  val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
+    .doc("Name of the class used to configure Spark Session extensions. The 
class should " +
+      "implement Function1[SparkSessionExtension, Unit], and must have a 
no-args constructor.")
+    .stringConf
+    .createOptional
 }
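
As a rough sketch of how this static conf is intended to be used (the class and package names below are hypothetical, not part of this patch): the configurator is a plain class with a no-args constructor that implements `SparkSessionExtensions => Unit`, and its fully qualified name is supplied via `spark.sql.extensions`.

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical no-op analyzer rule, for illustration only.
    case class NoopRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    // Hypothetical configurator; it needs a no-args constructor so the builder
    // can instantiate it reflectively from the value of spark.sql.extensions.
    class NoopExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectResolutionRule(NoopRule)
      }
    }

The class would then be registered by name, for example with `--conf spark.sql.extensions=com.example.NoopExtensions` on spark-submit, or with `.config("spark.sql.extensions", classOf[NoopExtensions].getName)` as the test suite below does.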

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 95f3463..a519492 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
+import org.apache.spark.sql.internal._
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
@@ -77,11 +77,12 @@ import org.apache.spark.util.Utils
 class SparkSession private(
     @transient val sparkContext: SparkContext,
     @transient private val existingSharedState: Option[SharedState],
-    @transient private val parentSessionState: Option[SessionState])
+    @transient private val parentSessionState: Option[SessionState],
+    @transient private[sql] val extensions: SparkSessionExtensions)
   extends Serializable with Closeable with Logging { self =>
 
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None)
+    this(sc, None, None, new SparkSessionExtensions)
   }
 
   sparkContext.assertNotStopped()
@@ -219,7 +220,7 @@ class SparkSession private(
    * @since 2.0.0
    */
   def newSession(): SparkSession = {
-    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
   }
 
   /**
@@ -235,7 +236,7 @@ class SparkSession private(
    * implementation is Hive, this will initialize the metastore, which may take some time.
    */
   private[sql] def cloneSession(): SparkSession = {
-    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
     result.sessionState // force copy of SessionState
     result
   }
@@ -754,6 +755,8 @@ object SparkSession {
 
     private[this] val options = new scala.collection.mutable.HashMap[String, String]
 
+    private[this] val extensions = new SparkSessionExtensions
+
     private[this] var userSuppliedContext: Option[SparkContext] = None
 
     private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -848,6 +851,17 @@ object SparkSession {
     }
 
     /**
+     * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
+     * Optimizer rules, Planning Strategies or a customized parser.
+     *
+     * @since 2.2.0
+     */
+    def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
+      f(extensions)
+      this
+    }
+
+    /**
      * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
      * one based on the options set in this builder.
      *
@@ -903,7 +917,26 @@ object SparkSession {
           }
           sc
         }
-        session = new SparkSession(sparkContext)
+
+        // Initialize extensions if the user has defined a configurator class.
+        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+        if (extensionConfOption.isDefined) {
+          val extensionConfClassName = extensionConfOption.get
+          try {
+            val extensionConfClass = Utils.classForName(extensionConfClassName)
+            val extensionConf = extensionConfClass.newInstance()
+              .asInstanceOf[SparkSessionExtensions => Unit]
+            extensionConf(extensions)
+          } catch {
+            // Ignore the error if we cannot find the class or when the class has the wrong type.
+            case e @ (_: ClassCastException |
+                      _: ClassNotFoundException |
+                      _: NoClassDefFoundError) =>
+              logWarning(s"Cannot use $extensionConfClassName to configure 
session extensions.", e)
+          }
+        }
+
+        session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
         defaultSession.set(session)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
new file mode 100644
index 0000000..f99c108
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * :: Experimental ::
+ * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability
+ * regarding binary compatibility and source compatibility of methods here.
+ *
+ * This currently provides the following extension points:
+ * - Analyzer Rules.
+ * - Check Analysis Rules
+ * - Optimizer Rules.
+ * - Planning Strategies.
+ * - Customized Parser.
+ * - (External) Catalog listeners.
+ *
+ * The extensions can be used by calling withExtensions on the [[SparkSession.Builder]], for
+ * example:
+ * {{{
+ *   SparkSession.builder()
+ *     .master("...")
+ *     .conf("...", true)
+ *     .withExtensions { extensions =>
+ *       extensions.injectResolutionRule { session =>
+ *         ...
+ *       }
+ *       extensions.injectParser { (session, parser) =>
+ *         ...
+ *       }
+ *     }
+ *     .getOrCreate()
+ * }}}
+ *
+ * Note that none of the injected builders should assume that the [[SparkSession]] is fully
+ * initialized and should not touch the session's internals (e.g. the SessionState).
+ */
+@DeveloperApi
+@Experimental
+@InterfaceStability.Unstable
+class SparkSessionExtensions {
+  type RuleBuilder = SparkSession => Rule[LogicalPlan]
+  type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
+  type StrategyBuilder = SparkSession => Strategy
+  type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+
+  private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+  /**
+   * Build the analyzer resolution `Rule`s using the given [[SparkSession]].
+   */
+  private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    resolutionRuleBuilders.map(_.apply(session))
+  }
+
+  /**
+   * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
+   * rules will be executed as part of the resolution phase of analysis.
+   */
+  def injectResolutionRule(builder: RuleBuilder): Unit = {
+    resolutionRuleBuilders += builder
+  }
+
+  private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+  /**
+   * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
+   */
+  private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    postHocResolutionRuleBuilders.map(_.apply(session))
+  }
+
+  /**
+   * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
+   * rules will be executed after resolution.
+   */
+  def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
+    postHocResolutionRuleBuilders += builder
+  }
+
+  private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]
+
+  /**
+   * Build the check analysis `Rule`s using the given [[SparkSession]].
+   */
+  private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
+    checkRuleBuilders.map(_.apply(session))
+  }
+
+  /**
+   * Inject a check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
+   * be executed after the analysis phase. A check analysis rule is used to detect problems with a
+   * LogicalPlan and should throw an exception when a problem is found.
+   */
+  def injectCheckRule(builder: CheckRuleBuilder): Unit = {
+    checkRuleBuilders += builder
+  }
+
+  private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]
+
+  private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    optimizerRules.map(_.apply(session))
+  }
+
+  /**
+   * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
+   * executed during the operator optimization batch. An optimizer rule is used to improve the
+   * quality of an analyzed logical plan; these rules should never modify the result of the
+   * LogicalPlan.
+   */
+  def injectOptimizerRule(builder: RuleBuilder): Unit = {
+    optimizerRules += builder
+  }
+
+  private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
+
+  private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
+    plannerStrategyBuilders.map(_.apply(session))
+  }
+
+  /**
+   * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
+   * be used to convert a `LogicalPlan` into an executable
+   * [[org.apache.spark.sql.execution.SparkPlan]].
+   */
+  def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
+    plannerStrategyBuilders += builder
+  }
+
+  private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]
+
+  private[sql] def buildParser(
+      session: SparkSession,
+      initial: ParserInterface): ParserInterface = {
+    parserBuilders.foldLeft(initial) { (parser, builder) =>
+      builder(session, parser)
+    }
+  }
+
+  /**
+   * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
+   * and an initial parser. The latter allows for a user to create a partial parser and to delegate
+   * to the underlying parser for completeness. If a user injects more parsers, then the parsers
+   * are stacked on top of each other.
+   */
+  def injectParser(builder: ParserBuilder): Unit = {
+    parserBuilders += builder
+  }
+}
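
To make the check-rule contract above concrete, here is a hedged sketch (the `RequireResolved` rule and the choice of exception are illustrative assumptions, not part of this patch): a check rule receives the analyzed plan and throws when it finds a problem.

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Hypothetical check-analysis rule: it throws when handed a plan that is
    // not fully resolved, and does nothing otherwise.
    case class RequireResolved(spark: SparkSession) extends (LogicalPlan => Unit) {
      override def apply(plan: LogicalPlan): Unit = {
        if (!plan.resolved) {
          throw new IllegalStateException(s"Plan is not fully resolved:\n$plan")
        }
      }
    }

    // Wiring it up through the extension point defined above:
    val configure: SparkSessionExtensions => Unit = _.injectCheckRule(RequireResolved)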

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index df7c367..2a801d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,8 +18,8 @@ package org.apache.spark.sql.internal
 
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
@@ -64,6 +64,11 @@ abstract class BaseSessionStateBuilder(
   protected def newBuilder: NewBuilder
 
   /**
+   * Session extensions defined in the [[SparkSession]].
+   */
+  protected def extensions: SparkSessionExtensions = session.extensions
+
+  /**
    * Extract entries from `SparkConf` and put them in the `SQLConf`
    */
   protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
@@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note: this depends on the `conf` field.
    */
-  protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+  protected lazy val sqlParser: ParserInterface = {
+    extensions.buildParser(session, new SparkSqlParser(conf))
+  }
 
   /**
    * ResourceLoader that is used to load function resources and jars.
@@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `analyzer` function.
    */
-  protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+  protected def customResolutionRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildResolutionRules(session)
+  }
 
   /**
    * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
@@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `analyzer` function.
    */
-  protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+  protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildPostHocResolutionRules(session)
+  }
 
   /**
    * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating
@@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `analyzer` function.
    */
-  protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil
+  protected def customCheckRules: Seq[LogicalPlan => Unit] = {
+    extensions.buildCheckRules(session)
+  }
 
   /**
    * Logical query plan optimizer.
@@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `optimizer` function.
    */
-  protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+  protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildOptimizerRules(session)
+  }
 
   /**
    * Planner that converts optimized logical plans to physical plans.
@@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `planner` function.
    */
-  protected def customPlanningStrategies: Seq[Strategy] = Nil
+  protected def customPlanningStrategies: Seq[Strategy] = {
+    extensions.buildPlannerStrategies(session)
+  }
 
   /**
    * Create a query execution object.

http://git-wip-us.apache.org/repos/asf/spark/blob/f0de6007/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
new file mode 100644
index 0000000..43db796
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Test cases for the [[SparkSessionExtensions]].
+ */
+class SparkSessionExtensionSuite extends SparkFunSuite {
+  type ExtensionsBuilder = SparkSessionExtensions => Unit
+  private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+
+  private def stop(spark: SparkSession): Unit = {
+    spark.stop()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
+
+  private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
+    val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+    try f(spark) finally {
+      stop(spark)
+    }
+  }
+
+  test("inject analyzer rule") {
+    withSession(_.injectResolutionRule(MyRule)) { session =>
+      assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+    }
+  }
+
+  test("inject check analysis rule") {
+    withSession(_.injectCheckRule(MyCheckRule)) { session =>
+      assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+    }
+  }
+
+  test("inject optimizer rule") {
+    withSession(_.injectOptimizerRule(MyRule)) { session =>
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+    }
+  }
+
+  test("inject spark planner strategy") {
+    withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+    }
+  }
+
+  test("inject parser") {
+    val extension = create { extensions =>
+      extensions.injectParser((_, _) => CatalystSqlParser)
+    }
+    withSession(extension) { session =>
+      assert(session.sessionState.sqlParser == CatalystSqlParser)
+    }
+  }
+
+  test("inject stacked parsers") {
+    val extension = create { extensions =>
+      extensions.injectParser((_, _) => CatalystSqlParser)
+      extensions.injectParser(MyParser)
+      extensions.injectParser(MyParser)
+    }
+    withSession(extension) { session =>
+      val parser = MyParser(session, MyParser(session, CatalystSqlParser))
+      assert(session.sessionState.sqlParser == parser)
+    }
+  }
+
+  test("use custom class for extensions") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName)
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+      assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+    } finally {
+      stop(session)
+    }
+  }
+}
+
+case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) {
+  override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy(spark: SparkSession) extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
+case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
+  override def parsePlan(sqlText: String): LogicalPlan =
+    delegate.parsePlan(sqlText)
+
+  override def parseExpression(sqlText: String): Expression =
+    delegate.parseExpression(sqlText)
+
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    delegate.parseTableIdentifier(sqlText)
+
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    delegate.parseFunctionIdentifier(sqlText)
+
+  override def parseTableSchema(sqlText: String): StructType =
+    delegate.parseTableSchema(sqlText)
+
+  override def parseDataType(sqlText: String): DataType =
+    delegate.parseDataType(sqlText)
+}
+
+class MyExtensions extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectPlannerStrategy(MySparkStrategy)
+    e.injectResolutionRule(MyRule)
+  }
+}

