Repository: spark
Updated Branches:
  refs/heads/master f1891ff1e -> d367bdcf5


[SPARK-25255][PYTHON] Add getActiveSession to SparkSession in PySpark

## What changes were proposed in this pull request?

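Add `SparkSession.getActiveSession()` to PySpark (`python/pyspark/sql/session.py`). It returns the active SparkSession for the current thread, or `None` if there is none.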

## How was this patch tested?

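Added a doctest for `getActiveSession`, plus new unit tests (`SparkSessionTests2`, `SparkSessionTests3`) in `python/pyspark/sql/tests.py`.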

Closes #22295 from huaxingao/spark25255.

Authored-by: Huaxin Gao <huax...@us.ibm.com>
Signed-off-by: Holden Karau <hol...@pigscanfly.ca>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d367bdcf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d367bdcf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d367bdcf

Branch: refs/heads/master
Commit: d367bdcf521f564d2d7066257200be26b27ea926
Parents: f1891ff
Author: Huaxin Gao <huax...@us.ibm.com>
Authored: Fri Oct 26 09:40:13 2018 -0700
Committer: Holden Karau <hol...@pigscanfly.ca>
Committed: Fri Oct 26 09:40:13 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/session.py |  30 ++++++++
 python/pyspark/sql/tests.py   | 151 +++++++++++++++++++++++++++++++++++++
 2 files changed, 181 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d367bdcf/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 079af8c..6f4b327 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -192,6 +192,7 @@ class SparkSession(object):
     """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances"""
 
     _instantiatedSession = None
+    _activeSession = None  # session last made active from Python
 
     @ignore_unicode_prefix
     def __init__(self, sparkContext, jsparkSession=None):
@@ -233,7 +234,9 @@ class SparkSession(object):
         if SparkSession._instantiatedSession is None \
                 or SparkSession._instantiatedSession._sc._jsc is None:
             SparkSession._instantiatedSession = self
+            SparkSession._activeSession = self  # Python-side mirror of the JVM active session
             self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
+            self._jvm.SparkSession.setActiveSession(self._jsparkSession)
 
     def _repr_html_(self):
         return """
@@ -255,6 +258,29 @@ class SparkSession(object):
         """
         return self.__class__(self._sc, self._jsparkSession.newSession())
 
+    @classmethod
+    @since(3.0)
+    def getActiveSession(cls):
+        """
+        Returns the active SparkSession for the current thread, or ``None`` if there is none.
+
+        >>> s = SparkSession.getActiveSession()
+        >>> rdd = s.sparkContext.parallelize([('Alice', 1)])
+        >>> df = s.createDataFrame(rdd, ['name', 'age'])
+        >>> df.select("age").collect()
+        [Row(age=1)]
+        """
+        from pyspark import SparkContext
+        sc = SparkContext._active_spark_context
+        if sc is None:
+            return None
+        jactive = sc._jvm.SparkSession.getActiveSession()  # query the JVM only once
+        if not jactive.isDefined():
+            return None
+        # NOTE(review): __init__ only updates _activeSession if no live session exists yet.
+        SparkSession(sc, jactive.get())
+        return SparkSession._activeSession
+
     @property
     @since(2.0)
     def sparkContext(self):
@@ -671,6 +697,8 @@ class SparkSession(object):
             ...
         Py4JJavaError: ...
         """
+        SparkSession._activeSession = self  # NOTE: runs even if the TypeError below fires
+        self._jvm.SparkSession.setActiveSession(self._jsparkSession)
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
@@ -826,7 +854,9 @@ class SparkSession(object):
         self._sc.stop()
         # We should clean the default session up. See SPARK-23228.
         self._jvm.SparkSession.clearDefaultSession()
+        self._jvm.SparkSession.clearActiveSession()  # likewise for the active session
         SparkSession._instantiatedSession = None
+        SparkSession._activeSession = None
 
     @since(2.0)
     def __enter__(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/d367bdcf/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 82dc5a6..ad04270 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3985,6 +3985,157 @@ class SparkSessionTests(PySparkTestCase):
             spark.stop()
 
 
+class SparkSessionTests2(unittest.TestCase):
+    """Tests that each create and stop their own SparkSession rather than sharing one."""
+
+    def test_active_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            activeSession = SparkSession.getActiveSession()
+            df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
+            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
+        finally:
+            spark.stop()
+
+    def test_get_active_session_when_no_active_session(self):
+        self.assertIsNone(SparkSession.getActiveSession())
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            self.assertEqual(SparkSession.getActiveSession(), spark)
+        finally:
+            spark.stop()
+        self.assertIsNone(SparkSession.getActiveSession())
+
+    def test_SparkSession(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .config("some-config", "v2") \
+            .getOrCreate()
+        try:
+            self.assertEqual(spark.conf.get("some-config"), "v2")
+            self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
+            self.assertEqual(spark.version, spark.sparkContext.version)
+            spark.sql("CREATE DATABASE test_db")
+            spark.catalog.setCurrentDatabase("test_db")
+            self.assertEqual(spark.catalog.currentDatabase(), "test_db")
+            spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
+            self.assertEqual(spark.table("table1").columns, ['name', 'age'])
+            self.assertEqual(spark.range(3).count(), 3)
+        finally:
+            # Remove test_db (and table1) so later runs can create them again.
+            spark.sql("DROP DATABASE IF EXISTS test_db CASCADE")
+            spark.stop()
+
+    def test_global_default_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            self.assertEqual(SparkSession.builder.getOrCreate(), spark)
+        finally:
+            spark.stop()
+
+    def test_default_and_active_session(self):
+        spark = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            activeSession = spark._jvm.SparkSession.getActiveSession()
+            defaultSession = spark._jvm.SparkSession.getDefaultSession()
+            self.assertEqual(activeSession, defaultSession)
+        finally:
+            spark.stop()
+
+    def test_config_option_propagated_to_existing_session(self):
+        session1 = SparkSession.builder \
+            .master("local") \
+            .config("spark-config1", "a") \
+            .getOrCreate()
+        try:
+            self.assertEqual(session1.conf.get("spark-config1"), "a")
+            session2 = SparkSession.builder \
+                .config("spark-config1", "b") \
+                .getOrCreate()
+            self.assertEqual(session1, session2)
+            self.assertEqual(session1.conf.get("spark-config1"), "b")
+        finally:
+            session1.stop()
+
+    def test_new_session(self):
+        session = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        newSession = session.newSession()
+        try:
+            self.assertNotEqual(session, newSession)
+        finally:
+            session.stop()
+            newSession.stop()
+
+    def test_create_new_session_if_old_session_stopped(self):
+        session = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        session.stop()
+        newSession = SparkSession.builder \
+            .master("local") \
+            .getOrCreate()
+        try:
+            self.assertNotEqual(session, newSession)
+        finally:
+            newSession.stop()
+
+    def test_active_session_with_None_and_not_None_context(self):
+        from pyspark.context import SparkContext
+        from pyspark.conf import SparkConf
+        sc = None
+        session = None
+        try:
+            sc = SparkContext._active_spark_context
+            self.assertEqual(sc, None)
+            self.assertIsNone(SparkSession.getActiveSession())
+            sparkConf = SparkConf()
+            sc = SparkContext.getOrCreate(sparkConf)
+            activeSession = sc._jvm.SparkSession.getActiveSession()
+            self.assertFalse(activeSession.isDefined())
+            session = SparkSession(sc)
+            activeSession = sc._jvm.SparkSession.getActiveSession()
+            self.assertTrue(activeSession.isDefined())
+            self.assertIsNotNone(SparkSession.getActiveSession())
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
+
+class SparkSessionTests3(ReusedSQLTestCase):
+    """Active-session tests that run against the shared ``self.spark`` fixture."""
+
+    def test_get_active_session_after_create_dataframe(self):
+        """createDataFrame() makes the session it is called on the active session."""
+        # session2 comes from newSession(), which reuses self.spark's SparkContext.
+        # It is deliberately not stopped: SparkSession.stop() stops that context,
+        # which is shared with self.spark and the rest of this test class.
+        activeSession1 = SparkSession.getActiveSession()
+        session1 = self.spark
+        self.assertEqual(session1, activeSession1)
+        session2 = self.spark.newSession()
+        activeSession2 = SparkSession.getActiveSession()
+        self.assertEqual(session1, activeSession2)
+        self.assertNotEqual(session2, activeSession2)
+        session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
+        activeSession3 = SparkSession.getActiveSession()
+        self.assertEqual(session2, activeSession3)
+        session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
+        activeSession4 = SparkSession.getActiveSession()
+        self.assertEqual(session1, activeSession4)
+
+
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to