This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d04d0f043d2 [SPARK-47986][CONNECT][PYTHON] Unable to create a new session when the default session is closed by the server
7d04d0f043d2 is described below

commit 7d04d0f043d2af6b518c6567443a6a5bed7ae541
Author: Niranjan Jayakar <n...@databricks.com>
AuthorDate: Fri Apr 26 15:24:02 2024 +0800

    [SPARK-47986][CONNECT][PYTHON] Unable to create a new session when the default session is closed by the server
    
    ### What changes were proposed in this pull request?
    
    When the server closes a session, usually after a cluster restart,
    the client is unaware of this until it receives an error.
    
    At this point, the client is unable to create a new session to the
    same connect endpoint, since the stale session is still recorded
    as the active and default session.
    
    With this change, when the server communicates that the session
    has changed via a gRPC error, the session and the respective client
    are marked as stale. A new default connection can be created
    via the session builder.
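
    The fallback logic in getOrCreate() can be sketched in isolation as
    below. This is a minimal standalone model of the change, not the real
    pyspark.sql.connect classes: Session and Builder here are hypothetical
    stand-ins, and is_stopped mirrors the client's _closed flag that this
    patch sets when the server reports INVALID_HANDLE.SESSION_CHANGED.

```python
class Session:
    """Hypothetical stand-in for the Connect SparkSession."""

    def __init__(self) -> None:
        self._closed = False

    @property
    def is_stopped(self) -> bool:
        # In the real client, _closed is set when the server reports
        # INVALID_HANDLE.SESSION_CHANGED via a gRPC error.
        return self._closed


class Builder:
    """Hypothetical stand-in for SparkSession.builder."""

    _default_session = None

    def getOrCreate(self) -> Session:
        session = Builder._default_session
        # Reuse the cached default session only if it is still usable;
        # a stale (stopped) session is replaced, as in this change.
        if session is None or session.is_stopped:
            session = Session()
            Builder._default_session = session
        return session


s1 = Builder().getOrCreate()
s1._closed = True  # simulate the server closing the session
s2 = Builder().getOrCreate()
assert s2 is not s1 and not s2.is_stopped
```

    Before this change, the second getOrCreate() would have returned the
    stale cached session, leaving the client stuck until the process was
    restarted.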
    
    ### Why are the changes needed?
    
    See the section above.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests are attached.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46221 from nija-at/session-expires.
    
    Authored-by: Niranjan Jayakar <n...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/client/core.py                |  3 +++
 python/pyspark/sql/connect/session.py                    |  4 ++--
 python/pyspark/sql/tests/connect/test_connect_session.py | 14 ++++++++++++++
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 0bdfb4bb7910..badd9a33397e 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1763,6 +1763,9 @@ class SparkConnectClient(object):
                     info = error_details_pb2.ErrorInfo()
                     d.Unpack(info)
 
+                    if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
+                        self._closed = True
+
                     raise convert_exception(
                         info,
                         status.message,
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 5e677efe6ca6..eb7a546ca18d 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -237,9 +237,9 @@ class SparkSession:
         def getOrCreate(self) -> "SparkSession":
             with SparkSession._lock:
                 session = SparkSession.getActiveSession()
-                if session is None:
+                if session is None or session.is_stopped:
                     session = SparkSession._default_session
-                    if session is None:
+                    if session is None or session.is_stopped:
                         session = self.create()
                 self._apply_options(session)
                 return session
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1caf3525cfbb..c5ce697a9561 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -242,6 +242,20 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
 
+    def test_reset_when_server_session_changes(self):
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        # run a simple query so the session id is synchronized.
+        session.range(3).collect()
+
+        # trigger a mismatch between client session id and server session id.
+        session._client._session_id = str(uuid.uuid4())
+        with self.assertRaises(SparkConnectException):
+            session.range(3).collect()
+
+        # assert that getOrCreate() generates a new session
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        session.range(3).collect()
+
 
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
     def setUp(self) -> None:

