Repository: spark
Updated Branches:
  refs/heads/master 161a3f2ae -> 3d0911bbe
[SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession

## What changes were proposed in this pull request?

In the current PySpark code, the `jsparkSession` created from Python is not added to the JVM's defaultSession, so this `SparkSession` object cannot be fetched from the Java side, and Scala code such as the listener below fails when loaded in a PySpark application.

```scala
class TestSparkSession extends SparkListener with Logging {
  override def onOtherEvent(event: SparkListenerEvent): Unit = {
    event match {
      case CreateTableEvent(db, table) =>
        val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
        assert(session.isDefined)
        val tableInfo = session.get.sharedState.externalCatalog.getTable(db, table)
        logInfo(s"Table info ${tableInfo}")
      case e =>
        logInfo(s"event $e")
    }
  }
}
```

This patch therefore registers the freshly created `jsparkSession` as the JVM's `defaultSession`.

## How was this patch tested?

Manual verification.

Author: jerryshao <ss...@hortonworks.com>
Author: hyukjinkwon <gurwls...@gmail.com>
Author: Saisai Shao <sai.sai.s...@gmail.com>

Closes #20404 from jerryshao/SPARK-23228.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d0911bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d0911bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d0911bb

Branch: refs/heads/master
Commit: 3d0911bbe47f76c341c090edad3737e88a67e3d7
Parents: 161a3f2
Author: jerryshao <ss...@hortonworks.com>
Authored: Wed Jan 31 20:04:51 2018 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Wed Jan 31 20:04:51 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/session.py | 10 +++++++++-
 python/pyspark/sql/tests.py   | 28 +++++++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d0911bb/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 6c84023..1ed0429 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -213,7 +213,12 @@ class SparkSession(object):
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         if jsparkSession is None:
-            jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+            if self._jvm.SparkSession.getDefaultSession().isDefined() \
+                    and not self._jvm.SparkSession.getDefaultSession().get() \
+                        .sparkContext().isStopped():
+                jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
+            else:
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
@@ -225,6 +230,7 @@ class SparkSession(object):
         if SparkSession._instantiatedSession is None \
                 or SparkSession._instantiatedSession._sc._jsc is None:
             SparkSession._instantiatedSession = self
+            self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
 
     def _repr_html_(self):
         return """
@@ -759,6 +765,8 @@ class SparkSession(object):
         """Stop the underlying :class:`SparkContext`.
         """
         self._sc.stop()
+        # We should clean the default session up. See SPARK-23228.
+        self._jvm.SparkSession.clearDefaultSession()
         SparkSession._instantiatedSession = None
 
     @since(2.0)


http://git-wip-us.apache.org/repos/asf/spark/blob/3d0911bb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index dc80870..dc26b96 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -69,7 +69,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
 from pyspark.sql.types import _merge_type
-from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
+from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
@@ -2925,6 +2925,32 @@ class SQLTests2(ReusedSQLTestCase):
             sc.stop()
 
 
+class SparkSessionTests(PySparkTestCase):
+
+    # This test is separate because it's closely related with session's start and stop.
+    # See SPARK-23228.
+    def test_set_jvm_default_session(self):
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+        finally:
+            spark.stop()
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
+
+    def test_jvm_default_session_already_set(self):
+        # Here, we assume there is the default session already set in JVM.
+        jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
+        self.sc._jvm.SparkSession.setDefaultSession(jsession)
+
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+            # The session should be the same with the exiting one.
+            self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
+        finally:
+            spark.stop()
+
+
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:
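For illustration, below is a minimal sketch of how the new behaviour can be observed from a PySpark application running on a build that includes this patch. The master URL and app name are arbitrary choices for the example, and `spark._jvm` is the internal py4j gateway that the patch and its tests use to reach the JVM-side `SparkSession` companion object; this mirrors `test_set_jvm_default_session` above rather than adding any new API.

```python
from pyspark.sql import SparkSession

# Create a session from Python; with this patch it is also registered as the
# JVM-side default session (setDefaultSession is called during construction).
spark = SparkSession.builder \
    .master("local[1]") \
    .appName("spark-23228-demo") \
    .getOrCreate()

# The JVM default session is now defined, so JVM-side code such as the
# TestSparkSession listener shown in the description can fetch it via
# SparkSession.getDefaultSession.
assert spark._jvm.SparkSession.getDefaultSession().isDefined()

# Stopping the Python session clears the JVM default session as well.
spark.stop()
assert spark._jvm.SparkSession.getDefaultSession().isEmpty()
```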