Repository: spark
Updated Branches:
  refs/heads/master 7d425b190 -> c3eaee776


[SPARK-25003][PYSPARK] Use SessionExtensions in Pyspark

Master

## What changes were proposed in this pull request?

Previously, PySpark used the private SparkSession constructor when building that
object. This resulted in a SparkSession that never checked the
spark.sql.extensions parameter for additional session extensions. To fix this,
we instead use the Session.builder() path as SparkR does; this loads the
extensions and allows their use in PySpark.
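
With this change, a custom extensions class configured through `spark.sql.extensions` is visible from PySpark. A minimal sketch (assuming the `org.apache.spark.sql.MyExtensions` helper from the Scala `SparkSessionExtensionSuite` test classes is on the classpath, as in the new test below):

```python
from pyspark.sql import SparkSession

# Sketch only: 'org.apache.spark.sql.MyExtensions' is the helper defined in the
# Scala SparkSessionExtensionSuite test classes and must be on the classpath.
spark = SparkSession.builder \
    .master("local[4]") \
    .config("spark.sql.extensions", "org.apache.spark.sql.MyExtensions") \
    .getOrCreate()

# With the extensions applied, the injected planner strategy is visible
# through the underlying JVM session.
strategies = spark._jsparkSession.sessionState().planner().strategies()
print(strategies.contains(
    spark._jvm.org.apache.spark.sql.MySparkStrategy(spark._jsparkSession)))
```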

## How was this patch tested?

An integration test was added which mimics the Scala test for the same feature.


Closes #21990 from RussellSpitzer/SPARK-25003-master.

Authored-by: Russell Spitzer <russell.spit...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c3eaee77
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c3eaee77
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c3eaee77

Branch: refs/heads/master
Commit: c3eaee776509b0a23d0ba7a575575516bab4aa4e
Parents: 7d425b1
Author: Russell Spitzer <russell.spit...@gmail.com>
Authored: Thu Oct 18 12:29:09 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Thu Oct 18 12:29:09 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 42 +++++++++++++++
 .../org/apache/spark/sql/SparkSession.scala     | 56 +++++++++++++-------
 2 files changed, 80 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c3eaee77/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 85712df..8065d82 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3837,6 +3837,48 @@ class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
                 "The callback from the query execution listener should be called after 'toPandas'")
 
 
+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because they use 'spark.sql.extensions', which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
 class SparkSessionTests(PySparkTestCase):
 
     # This test is separate because it's closely related with session's start and stop.

http://git-wip-us.apache.org/repos/asf/spark/blob/c3eaee77/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 2b847fb..71f967a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -84,8 +84,17 @@ class SparkSession private(
   // The call site where this SparkSession was constructed.
   private val creationSite: CallSite = Utils.getCallSite()
 
+  /**
+   * Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
+   * which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
+   * since that would cause every new session to reinvoke Spark Session Extensions on the currently
+   * running extensions.
+   */
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None, new SparkSessionExtensions)
+    this(sc, None, None,
+      SparkSession.applyExtensions(
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        new SparkSessionExtensions))
   }
 
   sparkContext.assertNotStopped()
@@ -936,23 +945,9 @@ object SparkSession extends Logging {
           // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
         }
 
-        // Initialize extensions if the user has defined a configurator class.
-        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-        if (extensionConfOption.isDefined) {
-          val extensionConfClassName = extensionConfOption.get
-          try {
-            val extensionConfClass = Utils.classForName(extensionConfClassName)
-            val extensionConf = extensionConfClass.newInstance()
-              .asInstanceOf[SparkSessionExtensions => Unit]
-            extensionConf(extensions)
-          } catch {
-            // Ignore the error if we cannot find the class or when the class has the wrong type.
-            case e @ (_: ClassCastException |
-                      _: ClassNotFoundException |
-                      _: NoClassDefFoundError) =>
-              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-          }
-        }
+        applyExtensions(
+          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+          extensions)
 
         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1137,4 +1132,29 @@ object SparkSession extends Logging {
       SparkSession.clearDefaultSession()
     }
   }
+
+  /**
+   * Initialize extensions for the given extension classname. The class will be applied to
+   * the extensions passed into this function.
+   */
+  private def applyExtensions(
+      extensionOption: Option[String],
+      extensions: SparkSessionExtensions): SparkSessionExtensions = {
+    if (extensionOption.isDefined) {
+      val extensionConfClassName = extensionOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e @ (_: ClassCastException |
+                  _: ClassNotFoundException |
+                  _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+    extensions
+  }
 }

