Repository: spark
Updated Branches:
  refs/heads/master c8f7691c6 -> 6e0fc8b0f


[SPARK-25560][SQL] Allow FunctionInjection in SparkExtensions

This allows an implementer of Spark Session Extensions to utilize a
method "injectFunction" which will add a new function to the default
Spark Session Catalogue.

## What changes were proposed in this pull request?

Adds a new function to SparkSessionExtensions

    def injectFunction(functionDescription: FunctionDescription)

Where a function description is a new type

  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)

The functions are loaded in BaseSessionStateBuilder when the function
registry does not have a parent function registry to load from.

## How was this patch tested?

New unit tests are added for the extension in SparkSessionExtensionSuite

Closes #22576 from RussellSpitzer/SPARK-25560.

Authored-by: Russell Spitzer <russell.spit...@gmail.com>
Signed-off-by: Herman van Hovell <hvanhov...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e0fc8b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e0fc8b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e0fc8b0

Branch: refs/heads/master
Commit: 6e0fc8b0fc2798b6372d1101f7996f57bae8fea4
Parents: c8f7691
Author: Russell Spitzer <russell.spit...@gmail.com>
Authored: Fri Oct 19 10:40:56 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Fri Oct 19 10:40:56 2018 +0200

----------------------------------------------------------------------
 .../spark/sql/SparkSessionExtensions.scala      | 22 ++++++++++++++++++
 .../sql/internal/BaseSessionStateBuilder.scala  |  3 ++-
 .../spark/sql/SparkSessionExtensionSuite.scala  | 24 ++++++++++++++++++--
 3 files changed, 46 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 6b02ac2..a486434 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -20,6 +20,10 @@ package org.apache.spark.sql
 import scala.collection.mutable
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, 
InterfaceStability}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -68,6 +72,7 @@ class SparkSessionExtensions {
   type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, 
FunctionBuilder)
 
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
@@ -171,4 +176,21 @@ class SparkSessionExtensions {
   def injectParser(builder: ParserBuilder): Unit = {
     parserBuilders += builder
   }
+
+  private[this] val injectedFunctions = 
mutable.Buffer.empty[FunctionDescription]
+
+  private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = {
+    for ((name, expressionInfo, function) <- injectedFunctions) {
+      functionRegistry.registerFunction(name, expressionInfo, function)
+    }
+    functionRegistry
+  }
+
+  /**
+  * Injects a custom function into the 
[[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]]
+  * at runtime for all sessions.
+  */
+  def injectFunction(functionDescription: FunctionDescription): Unit = {
+    injectedFunctions += functionDescription
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 60bba5e..f67cc32 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -95,7 +95,8 @@ abstract class BaseSessionStateBuilder(
    * This either gets cloned from a pre-existing version or cloned from the 
built-in registry.
    */
   protected lazy val functionRegistry: FunctionRegistry = {
-    
parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone()
+    parentState.map(_.functionRegistry.clone())
+      
.getOrElse(extensions.registerFunctions(FunctionRegistry.builtin.clone()))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 43db796..234711e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, 
Literal}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, 
ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
 
 /**
  * Test cases for the [[SparkSessionExtensions]].
@@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject function") {
+    val extensions = create { extensions =>
+      extensions.injectFunction(MyExtensions.myFunction)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    }
+  }
+
   test("use custom class for extensions") {
     val session = SparkSession.builder()
       .master("local[1]")
@@ -98,6 +108,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
@@ -136,9 +148,17 @@ case class MyParser(spark: SparkSession, delegate: 
ParserInterface) extends Pars
     delegate.parseDataType(sqlText)
 }
 
+object MyExtensions {
+
+  val myFunction = (FunctionIdentifier("myFunction"),
+    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended 
usage" ),
+    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+}
+
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectFunction(MyExtensions.myFunction)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to