This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce53fdfa32f [SPARK-44410][PYTHON][CONNECT] Set active session in create, not just getOrCreate
ce53fdfa32f is described below

commit ce53fdfa32f9fdabdf14df695aba674bb4a377a8
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Sun Jul 16 16:31:13 2023 +0900

    [SPARK-44410][PYTHON][CONNECT] Set active session in create, not just getOrCreate
    
    ### What changes were proposed in this pull request?
    
    ML and other code paths rely on `_active_spark_session` to find the Spark session.
    
    Sessions created with the `getOrCreate` method set this variable, but sessions
    created with `create` don't.
    
    This change updates the `create` method to also set `_active_spark_session`.
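
    For illustration, a minimal sketch of the new behavior (the connection string is hypothetical; the flow mirrors the unit test added in this patch):

    ```python
    from pyspark.sql.connect.session import SparkSession

    # `create` now records the new session in `_active_spark_session` ...
    session = SparkSession.builder.remote("sc://example-host").create()

    # ... so a subsequent `getOrCreate` returns that same session object.
    same = SparkSession.builder.remote("sc://example-host").getOrCreate()
    assert session is same
    ```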
    
    ### Why are the changes needed?
    This breaks Spark Connect clients, such as pyspark.ml and pandas, by preventing
    them from finding the created session if it was created with `create`.
    
    ### Does this PR introduce _any_ user-facing change?
    Sessions created by `create` are now set as the current session. This is slightly
    different behavior than before, but it suits the interests of almost all clients.
    The only case it might break is if someone mixes `create` and `getOrCreate` while
    relying on the previous semantics.
    
    If undesired, this could be hidden behind a configuration flag, e.g.
    `create(set_active_session=False)`. In that case, clients who use `create` and
    want to use pyspark.ml/pandas would need to update their code to set it to True.
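
    A sketch of the edge case above (the `sc://` hosts are hypothetical; the behavior follows the test updated in test_connect_basic.py):

    ```python
    from pyspark.sql.connect.session import SparkSession

    first = SparkSession.builder.remote("sc://host-a").getOrCreate()
    other = SparkSession.builder.remote("sc://host-b").create()

    # Before this change, `getOrCreate` here kept returning `first`; now it
    # returns the session most recently made active, i.e. `other`.
    same = SparkSession.builder.remote("sc://host-b").getOrCreate()
    assert same is other
    ```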
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #41987 from cdkrot/spark_session_create_store_session.
    
    Authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/session.py                  | 11 +++++++++--
 python/pyspark/sql/tests/connect/test_connect_basic.py |  4 ++--
 python/pyspark/sql/tests/connect/test_session.py       |  7 +++++++
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 3f9d46a22f4..52eab1bf5f9 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -93,6 +93,8 @@ if TYPE_CHECKING:
 
 # `_active_spark_session` stores the active spark connect session created by
 # `SparkSession.builder.getOrCreate`. It is used by ML code.
+# If sessions are created with `SparkSession.builder.create`, it stores
+# the last created session.
 _active_spark_session = None
 
 
@@ -172,6 +174,8 @@ class SparkSession:
             )
 
         def create(self) -> "SparkSession":
+            global _active_spark_session
+
             has_channel_builder = self._channel_builder is not None
             has_spark_remote = "spark.remote" in self._options
 
@@ -188,11 +192,14 @@ class SparkSession:
 
             if has_channel_builder:
                 assert self._channel_builder is not None
-                return SparkSession(connection=self._channel_builder)
+                session = SparkSession(connection=self._channel_builder)
             else:
                 spark_remote = to_str(self._options.get("spark.remote"))
                 assert spark_remote is not None
-                return SparkSession(connection=spark_remote)
+                session = SparkSession(connection=spark_remote)
+
+            _active_spark_session = session
+            return session
 
         def getOrCreate(self) -> "SparkSession":
             global _active_spark_session
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c1235620990..5259ea6b5f5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3323,9 +3323,9 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         other = PySparkSession.builder.remote("sc://other.remote:114/").create()
         self.assertNotEquals(self.spark, other)
 
-        # Reuses an active session that was previously created.
+        # Gets currently active session.
         same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
-        self.assertEquals(self.spark, same)
+        self.assertEquals(other, same)
         same.stop()
 
         # Make sure the environment is clean.
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index bde22d80303..2f14eeddc1e 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -63,3 +63,10 @@ class SparkSessionTestCase(unittest.TestCase):
         self.assertFalse(session.is_stopped)
         session.stop()
         self.assertTrue(session.is_stopped)
+
+    def test_session_create_sets_active_session(self):
+        session = RemoteSparkSession.builder.remote("sc://abc").create()
+        session2 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+
+        self.assertIs(session, session2)
+        session.stop()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
