Repository: spark Updated Branches: refs/heads/master 7d425b190 -> c3eaee776
[SPARK-25003][PYSPARK] Use SessionExtensions in Pyspark Master ## What changes were proposed in this pull request? Previously, PySpark used the private constructor for SparkSession when building that object. This resulted in a SparkSession that was created without checking the spark.sql.extensions parameter for additional session extensions. To fix this, we instead use the SparkSession.builder() path, as SparkR does; this loads the extensions and allows their use in PySpark. ## How was this patch tested? An integration test was added that mimics the Scala test for the same feature. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21990 from RussellSpitzer/SPARK-25003-master. Authored-by: Russell Spitzer <russell.spit...@gmail.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c3eaee77 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c3eaee77 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c3eaee77 Branch: refs/heads/master Commit: c3eaee776509b0a23d0ba7a575575516bab4aa4e Parents: 7d425b1 Author: Russell Spitzer <russell.spit...@gmail.com> Authored: Thu Oct 18 12:29:09 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Thu Oct 18 12:29:09 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 42 +++++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 56 +++++++++++++------- 2 files changed, 80 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c3eaee77/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 85712df..8065d82 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3837,6 +3837,48 @@ 
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): "The callback from the query execution listener should be called after 'toPandas'") +class SparkExtensionsTest(unittest.TestCase): + # These tests are separate because it uses 'spark.sql.extensions' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "SparkSessionExtensionSuite.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.SparkSessionExtensionSuite' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.extensions' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.extensions", + "org.apache.spark.sql.MyExtensions") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def test_use_custom_class_for_extensions(self): + self.assertTrue( + self.spark._jsparkSession.sessionState().planner().strategies().contains( + self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)), + "MySparkStrategy not found in active planner strategies") + self.assertTrue( + self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains( + self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)), + "MyRule not found in extended resolution rules") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. 
http://git-wip-us.apache.org/repos/asf/spark/blob/c3eaee77/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2b847fb..71f967a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -84,8 +84,17 @@ class SparkSession private( // The call site where this SparkSession was constructed. private val creationSite: CallSite = Utils.getCallSite() + /** + * Constructor used in Pyspark. Contains explicit application of Spark Session Extensions + * which otherwise only occurs during getOrCreate. We cannot add this to the default constructor + * since that would cause every new session to reinvoke Spark Session Extensions on the currently + * running extensions. + */ private[sql] def this(sc: SparkContext) { - this(sc, None, None, new SparkSessionExtensions) + this(sc, None, None, + SparkSession.applyExtensions( + sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS), + new SparkSessionExtensions)) } sparkContext.assertNotStopped() @@ -936,23 +945,9 @@ object SparkSession extends Logging { // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions. } - // Initialize extensions if the user has defined a configurator class. - val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) - if (extensionConfOption.isDefined) { - val extensionConfClassName = extensionConfOption.get - try { - val extensionConfClass = Utils.classForName(extensionConfClassName) - val extensionConf = extensionConfClass.newInstance() - .asInstanceOf[SparkSessionExtensions => Unit] - extensionConf(extensions) - } catch { - // Ignore the error if we cannot find the class or when the class has the wrong type. 
- case e @ (_: ClassCastException | - _: ClassNotFoundException | - _: NoClassDefFoundError) => - logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) - } - } + applyExtensions( + sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS), + extensions) session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } @@ -1137,4 +1132,29 @@ object SparkSession extends Logging { SparkSession.clearDefaultSession() } } + + /** + * Initialize extensions for given extension classname. This class will be applied to the + * extensions passed into this function. + */ + private def applyExtensions( + extensionOption: Option[String], + extensions: SparkSessionExtensions): SparkSessionExtensions = { + if (extensionOption.isDefined) { + val extensionConfClassName = extensionOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e@(_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + extensions + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org