Repository: spark Updated Branches: refs/heads/master f1891ff1e -> d367bdcf5
[SPARK-25255][PYTHON] Add getActiveSession to SparkSession in PySpark ## What changes were proposed in this pull request? add getActiveSession in session.py ## How was this patch tested? add doctest Closes #22295 from huaxingao/spark25255. Authored-by: Huaxin Gao <huax...@us.ibm.com> Signed-off-by: Holden Karau <hol...@pigscanfly.ca> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d367bdcf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d367bdcf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d367bdcf Branch: refs/heads/master Commit: d367bdcf521f564d2d7066257200be26b27ea926 Parents: f1891ff Author: Huaxin Gao <huax...@us.ibm.com> Authored: Fri Oct 26 09:40:13 2018 -0700 Committer: Holden Karau <hol...@pigscanfly.ca> Committed: Fri Oct 26 09:40:13 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/session.py | 30 ++++++++ python/pyspark/sql/tests.py | 151 +++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d367bdcf/python/pyspark/sql/session.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 079af8c..6f4b327 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -192,6 +192,7 @@ class SparkSession(object): """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" _instantiatedSession = None + _activeSession = None @ignore_unicode_prefix def __init__(self, sparkContext, jsparkSession=None): @@ -233,7 +234,9 @@ class SparkSession(object): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + SparkSession._activeSession = self self._jvm.SparkSession.setDefaultSession(self._jsparkSession) + self._jvm.SparkSession.setActiveSession(self._jsparkSession) def _repr_html_(self): return """ @@ -255,6 +258,29 @@ class SparkSession(object): """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @classmethod + @since(3.0) + def getActiveSession(cls): + """ + Returns the active SparkSession for the current thread, returned by the builder. + >>> s = SparkSession.getActiveSession() + >>> l = [('Alice', 1)] + >>> rdd = s.sparkContext.parallelize(l) + >>> df = s.createDataFrame(rdd, ['name', 'age']) + >>> df.select("age").collect() + [Row(age=1)] + """ + from pyspark import SparkContext + sc = SparkContext._active_spark_context + if sc is None: + return None + else: + if sc._jvm.SparkSession.getActiveSession().isDefined(): + SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) + return SparkSession._activeSession + else: + return None + @property @since(2.0) def sparkContext(self): @@ -671,6 +697,8 @@ class SparkSession(object): ... Py4JJavaError: ... """ + SparkSession._activeSession = self + self._jvm.SparkSession.setActiveSession(self._jsparkSession) if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") @@ -826,7 +854,9 @@ class SparkSession(object): self._sc.stop() # We should clean the default session up. See SPARK-23228. self._jvm.SparkSession.clearDefaultSession() + self._jvm.SparkSession.clearActiveSession() SparkSession._instantiatedSession = None + SparkSession._activeSession = None @since(2.0) def __enter__(self): http://git-wip-us.apache.org/repos/asf/spark/blob/d367bdcf/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 82dc5a6..ad04270 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3985,6 +3985,157 @@ class SparkSessionTests(PySparkTestCase): spark.stop() +class SparkSessionTests2(unittest.TestCase): + + def test_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + activeSession = SparkSession.getActiveSession() + df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) + self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) + finally: + spark.stop() + + def test_get_active_session_when_no_active_session(self): + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + active = SparkSession.getActiveSession() + self.assertEqual(active, spark) + spark.stop() + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + + def test_SparkSession(self): + spark = SparkSession.builder \ + .master("local") \ + .config("some-config", "v2") \ + .getOrCreate() + try: + self.assertEqual(spark.conf.get("some-config"), "v2") + self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") + self.assertEqual(spark.version, spark.sparkContext.version) + spark.sql("CREATE DATABASE test_db") + spark.catalog.setCurrentDatabase("test_db") + self.assertEqual(spark.catalog.currentDatabase(), "test_db") + spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") + self.assertEqual(spark.table("table1").columns, ['name', 'age']) + self.assertEqual(spark.range(3).count(), 3) + finally: + spark.stop() + + def test_global_default_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertEqual(SparkSession.builder.getOrCreate(), spark) + finally: + spark.stop() + + def test_default_and_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + activeSession = spark._jvm.SparkSession.getActiveSession() + defaultSession = spark._jvm.SparkSession.getDefaultSession() + try: + self.assertEqual(activeSession, defaultSession) + finally: + spark.stop() + + def test_config_option_propagated_to_existing_session(self): + session1 = SparkSession.builder \ + .master("local") \ + .config("spark-config1", "a") \ + .getOrCreate() + self.assertEqual(session1.conf.get("spark-config1"), "a") + session2 = SparkSession.builder \ + .config("spark-config1", "b") \ + .getOrCreate() + try: + self.assertEqual(session1, session2) + self.assertEqual(session1.conf.get("spark-config1"), "b") + finally: + session1.stop() + + def test_new_session(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + newSession = session.newSession() + try: + self.assertNotEqual(session, newSession) + finally: + session.stop() + newSession.stop() + + def test_create_new_session_if_old_session_stopped(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + session.stop() + newSession = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertNotEqual(session, newSession) + finally: + newSession.stop() + + def test_active_session_with_None_and_not_None_context(self): + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + sc = None + session = None + try: + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + +class SparkSessionTests3(ReusedSQLTestCase): + + def test_get_active_session_after_create_dataframe(self): + session2 = None + try: + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + finally: + if session2 is not None: + session2.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org