Repository: spark Updated Branches: refs/heads/master c8f7691c6 -> 6e0fc8b0f
[SPARK-25560][SQL] Allow FunctionInjection in SparkExtensions This allows an implementer of Spark Session Extensions to utilize a method "injectFunction" which will add a new function to the default Spark Session Catalogue. ## What changes were proposed in this pull request? Adds a new function to SparkSessionExtensions def injectFunction(functionDescription: FunctionDescription) where FunctionDescription is a new type: type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) The functions are loaded in BaseSessionStateBuilder when the function registry does not have a parent function registry to get loaded from. ## How was this patch tested? New unit tests are added for the extension in SparkSessionExtensionSuite Closes #22576 from RussellSpitzer/SPARK-25560. Authored-by: Russell Spitzer <russell.spit...@gmail.com> Signed-off-by: Herman van Hovell <hvanhov...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e0fc8b0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e0fc8b0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e0fc8b0 Branch: refs/heads/master Commit: 6e0fc8b0fc2798b6372d1101f7996f57bae8fea4 Parents: c8f7691 Author: Russell Spitzer <russell.spit...@gmail.com> Authored: Fri Oct 19 10:40:56 2018 +0200 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Fri Oct 19 10:40:56 2018 +0200 ---------------------------------------------------------------------- .../spark/sql/SparkSessionExtensions.scala | 22 ++++++++++++++++++ .../sql/internal/BaseSessionStateBuilder.scala | 3 ++- .../spark/sql/SparkSessionExtensionSuite.scala | 24 ++++++++++++++++++-- 3 files changed, 46 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 6b02ac2..a486434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -20,6 +20,10 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -68,6 +72,7 @@ class SparkSessionExtensions { type CheckRuleBuilder = SparkSession => LogicalPlan => Unit type StrategyBuilder = SparkSession => Strategy type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] @@ -171,4 +176,21 @@ class SparkSessionExtensions { def injectParser(builder: ParserBuilder): Unit = { parserBuilders += builder } + + private[this] val injectedFunctions = mutable.Buffer.empty[FunctionDescription] + + private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = { + for ((name, expressionInfo, function) <- injectedFunctions) { + functionRegistry.registerFunction(name, expressionInfo, function) + } + functionRegistry + } + + /** + * Injects a custom function into the [[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]] + * at runtime for all 
sessions. + */ + def injectFunction(functionDescription: FunctionDescription): Unit = { + injectedFunctions += functionDescription + } } http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 60bba5e..f67cc32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -95,7 +95,8 @@ abstract class BaseSessionStateBuilder( * This either gets cloned from a pre-existing version or cloned from the built-in registry. */ protected lazy val functionRegistry: FunctionRegistry = { - parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone() + parentState.map(_.functionRegistry.clone()) + .getOrElse(extensions.registerFunctions(FunctionRegistry.builtin.clone())) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 43db796..234711e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Expression +import 
org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} /** * Test cases for the [[SparkSessionExtensions]]. @@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } + test("inject function") { + val extensions = create { extensions => + extensions.injectFunction(MyExtensions.myFunction) + } + withSession(extensions) { session => + assert(session.sessionState.functionRegistry + .lookupFunction(MyExtensions.myFunction._1).isDefined) + } + } + test("use custom class for extensions") { val session = SparkSession.builder() .master("local[1]") @@ -98,6 +108,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite { try { assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + assert(session.sessionState.functionRegistry + .lookupFunction(MyExtensions.myFunction._1).isDefined) } finally { stop(session) } @@ -136,9 +148,17 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars delegate.parseDataType(sqlText) } +object MyExtensions { + + val myFunction = (FunctionIdentifier("myFunction"), + new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage" ), + (myArgs: Seq[Expression]) => Literal(5, IntegerType)) +} + class MyExtensions extends (SparkSessionExtensions => Unit) { def apply(e: SparkSessionExtensions): Unit = { e.injectPlannerStrategy(MySparkStrategy) e.injectResolutionRule(MyRule) + e.injectFunction(MyExtensions.myFunction) } } 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org