This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f31d05d75e1 [SPARK-48056][CONNECT][PYTHON] Re-execute plan if a SESSION_NOT_FOUND error is raised and no partial response was received
2f31d05d75e1 is described below

commit 2f31d05d75e12029aeb39225a7c43ede66f5fb00
Author: Niranjan Jayakar <n...@databricks.com>
AuthorDate: Thu May 2 21:20:34 2024 +0900

    [SPARK-48056][CONNECT][PYTHON] Re-execute plan if a SESSION_NOT_FOUND error is raised and no partial response was received
    
    ### What changes were proposed in this pull request?
    
    Similar to OPERATION_NOT_FOUND, re-execute the original
    Spark Connect plan when a SESSION_NOT_FOUND error is received
    from the Spark Connect service and no partial responses were
    previously received.
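
    For illustration, a minimal sketch of the retry decision
    (should_reexecute is a hypothetical helper name; the actual check
    lives inside ExecutePlanResponseReattachableIterator, as shown in
    the diff below):

        RETRYABLE_HANDLE_ERRORS = (
            "INVALID_HANDLE.OPERATION_NOT_FOUND",
            "INVALID_HANDLE.SESSION_NOT_FOUND",
        )

        def should_reexecute(status_message, last_returned_response_id):
            # Re-execute the plan only when the handle was not found and
            # the client has not yet consumed any partial responses;
            # otherwise the error is surfaced to the caller.
            return (
                any(err in status_message for err in RETRYABLE_HANDLE_ERRORS)
                and last_returned_response_id is None
            )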
    
    ### Why are the changes needed?
    
    This error has been observed during cluster cold starts, when a
    request arrives before the Connect service is fully initialized.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Previously, Connect-based PySpark APIs would fail with the error
    code "INVALID_HANDLE.SESSION_NOT_FOUND" on the very first request
    to the service.
    With this change, the client now automatically retries.
    
    ### How was this patch tested?
    
    Attached unit tests.
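
    The tests raise a mock gRPC error whose trailing metadata carries a
    serialized google.rpc.Status under the "grpc-status-details-bin"
    key, which the client decodes via rpc_status.from_call. A minimal
    sketch of building such a payload (mirroring the TestException mock
    in the diff below; 14 is the numeric value of
    grpc.StatusCode.UNAVAILABLE):

        from google.rpc import status_pb2

        # Rich error status carried in the trailer of the mocked RPC error.
        status = status_pb2.Status(
            code=14,  # UNAVAILABLE
            message="INVALID_HANDLE.SESSION_NOT_FOUND",
        )
        trailer = {"grpc-status-details-bin": status.SerializeToString()}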
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46297 from nija-at/session-not-found.
    
    Authored-by: Niranjan Jayakar <n...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/reattach.py      |  5 +-
 .../sql/tests/connect/client/test_client.py        | 85 +++++++++++++++++++++-
 2 files changed, 85 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py
index 4468582ca80e..cc50e5892631 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -254,7 +254,10 @@ class ExecutePlanResponseReattachableIterator(Generator):
             return iter_fun()
         except grpc.RpcError as e:
             status = rpc_status.from_call(cast(grpc.Call, e))
-            if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message:
+            if status is not None and (
+                "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message
+                or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message
+            ):
                 if self._last_returned_response_id is not None:
                     raise PySparkRuntimeError(
                         error_class="RESPONSE_ALREADY_RECEIVED",
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index b96fc44d50a7..4f54a0a67d8a 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -18,13 +18,14 @@
 import unittest
 import uuid
 from collections.abc import Generator
-from typing import Optional, Any
+from typing import Optional, Any, Union
 
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import eventually
 
 if should_test_connect:
     import grpc
+    from google.rpc import status_pb2
     import pandas as pd
     import pyarrow as pa
     from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
@@ -33,7 +34,7 @@ if should_test_connect:
         DefaultPolicy,
     )
     from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
-    from pyspark.errors import RetriesExceeded
+    from pyspark.errors import PySparkRuntimeError, RetriesExceeded
     import pyspark.sql.connect.proto as proto
 
     class TestPolicy(DefaultPolicy):
@@ -50,9 +51,17 @@ if should_test_connect:
     class TestException(grpc.RpcError, grpc.Call):
         """Exception mock to test retryable exceptions."""
 
-        def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+        def __init__(
+            self,
+            msg,
+            code=grpc.StatusCode.INTERNAL,
+            trailing_status: Union[status_pb2.Status, None] = None,
+        ):
             self.msg = msg
             self._code = code
+            self._trailer: dict[str, Any] = {}
+            if trailing_status is not None:
+                self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString()
 
         def code(self):
             return self._code
@@ -60,8 +69,11 @@ if should_test_connect:
         def __str__(self):
             return self.msg
 
+        def details(self):
+            return self.msg
+
         def trailing_metadata(self):
-            return ()
+            return None if not self._trailer else self._trailer.items()
 
     class ResponseGenerator(Generator):
         """This class is used to generate values that are returned by the 
streaming
@@ -340,6 +352,71 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
 
         eventually(timeout=1, catch_assertions=True)(check)()
 
+    def test_not_found_recovers(self):
+        """SPARK-48056: Assert that the client recovers from session or 
operation not
+        found error if no partial responses were previously received.
+        """
+
+        def not_found_recovers(error_code: str):
+            def not_found():
+                raise TestException(
+                    error_code,
+                    grpc.StatusCode.UNAVAILABLE,
+                    trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
+                )
+
+            stub = self._stub_with([not_found, self.finished])
+            ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
+
+            for _ in ite:
+                pass
+
+            def checks():
+                self.assertEquals(2, stub.execute_calls)
+                self.assertEquals(0, stub.attach_calls)
+                self.assertEquals(0, stub.release_calls)
+                self.assertEquals(0, stub.release_until_calls)
+
+            eventually(timeout=1, catch_assertions=True)(checks)()
+
+        parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
+        for b in parameters:
+            not_found_recovers(b)
+
+    def test_not_found_fails(self):
+        """SPARK-48056: Assert that the client fails from session or operation 
not found error
+        if a partial response was previously received.
+        """
+
+        def not_found_fails(error_code: str):
+            def not_found():
+                raise TestException(
+                    error_code,
+                    grpc.StatusCode.UNAVAILABLE,
+                    trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
+                )
+
+            stub = self._stub_with([self.response], [not_found])
+
+            with self.assertRaises(PySparkRuntimeError) as e:
+                ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
+                for _ in ite:
+                    pass
+
+            self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage())
+
+            def checks():
+                self.assertEquals(1, stub.execute_calls)
+                self.assertEquals(1, stub.attach_calls)
+                self.assertEquals(0, stub.release_calls)
+                self.assertEquals(0, stub.release_until_calls)
+
+            eventually(timeout=1, catch_assertions=True)(checks)()
+
+        parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
+        for b in parameters:
+            not_found_fails(b)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401

