This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 60073f31831 [SPARK-45167][CONNECT][PYTHON][3.5] Python client must call `release_all`
60073f31831 is described below

commit 60073f318313ab2329ea1504ef7538641433852e
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Tue Sep 19 08:32:21 2023 +0900

    [SPARK-45167][CONNECT][PYTHON][3.5] Python client must call `release_all`
    
    ### What changes were proposed in this pull request?
    
    Cherry-pick of https://github.com/apache/spark/pull/42929
    
    Previously, the Python client did not call `release_all` after fetching
    all results, leaving the query dangling on the server until it was removed
    by the five-minute timeout.
    
    This patch also adds tests that verify the `release_all` and
    `release_until` calls are issued.
    
    In addition, it fixes a race condition in the tests: closing the
    SparkSession also closes the gRPC channel, which could leave asynchronous
    release calls still in flight and hanging.
    
    ### Why are the changes needed?
    Stability
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New unit tests (`SparkConnectClientReattachTestCase`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #42973 from grundprinzip/SPARK-45167-3.5.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/core.py          |   1 +
 python/pyspark/sql/connect/client/reattach.py      |  37 +++-
 .../sql/tests/connect/client/test_client.py        | 195 ++++++++++++++++++++-
 3 files changed, 226 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 7b3299d123b..7b1aafbefeb 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1005,6 +1005,7 @@ class SparkConnectClient(object):
         """
         Close the channel.
         """
+        ExecutePlanResponseReattachableIterator.shutdown()
         self._channel.close()
         self._closed = True
 
diff --git a/python/pyspark/sql/connect/client/reattach.py 
b/python/pyspark/sql/connect/client/reattach.py
index 7e1e722d5fd..e58864b965b 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -21,7 +21,9 @@ check_dependencies(__name__)
 import warnings
 import uuid
 from collections.abc import Generator
-from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, 
cast
+from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, 
cast, Type, ClassVar
+from multiprocessing import RLock
+from multiprocessing.synchronize import RLock as RLockBase
 from multiprocessing.pool import ThreadPool
 import os
 
@@ -53,7 +55,30 @@ class ExecutePlanResponseReattachableIterator(Generator):
     ReleaseExecute RPCs that instruct the server to release responses that it 
already processed.
     """
 
-    _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8)
+    # Lock to manage the pool
+    _lock: ClassVar[RLockBase] = RLock()
+    _release_thread_pool: Optional[ThreadPool] = ThreadPool(os.cpu_count() if 
os.cpu_count() else 8)
+
+    @classmethod
+    def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None:
+        """
+        When the channel is closed, this method will be called before, to make 
sure all
+        outstanding calls are closed.
+        """
+        with cls._lock:
+            if cls._release_thread_pool is not None:
+                cls._release_thread_pool.close()
+                cls._release_thread_pool.join()
+                cls._release_thread_pool = None
+
+    @classmethod
+    def _initialize_pool_if_necessary(cls: 
Type["ExecutePlanResponseReattachableIterator"]) -> None:
+        """
+        If the processing pool for the release calls is None, initialize the 
pool exactly once.
+        """
+        with cls._lock:
+            if cls._release_thread_pool is None:
+                cls._release_thread_pool = ThreadPool(os.cpu_count() if 
os.cpu_count() else 8)
 
     def __init__(
         self,
@@ -62,6 +87,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
         retry_policy: Dict[str, Any],
         metadata: Iterable[Tuple[str, str]],
     ):
+        ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
         self._request = request
         self._retry_policy = retry_policy
         if request.operation_id:
@@ -111,7 +137,6 @@ class ExecutePlanResponseReattachableIterator(Generator):
 
         self._last_returned_response_id = ret.response_id
         if ret.HasField("result_complete"):
-            self._result_complete = True
             self._release_all()
         else:
             self._release_until(self._last_returned_response_id)
@@ -190,7 +215,8 @@ class ExecutePlanResponseReattachableIterator(Generator):
             except Exception as e:
                 warnings.warn(f"ReleaseExecute failed with exception: {e}.")
 
-        
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
+        if ExecutePlanResponseReattachableIterator._release_thread_pool is not 
None:
+            
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
 
     def _release_all(self) -> None:
         """
@@ -218,7 +244,8 @@ class ExecutePlanResponseReattachableIterator(Generator):
             except Exception as e:
                 warnings.warn(f"ReleaseExecute failed with exception: {e}.")
 
-        
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
+        if ExecutePlanResponseReattachableIterator._release_thread_pool is not 
None:
+            
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
         self._result_complete = True
 
     def _call_iter(self, iter_fun: Callable) -> Any:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 98f68767b8b..cf43fb16df7 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -17,14 +17,20 @@
 
 import unittest
 import uuid
-from typing import Optional
+from collections.abc import Generator
+from typing import Optional, Any
+
+import grpc
 
 from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
 import pyspark.sql.connect.proto as proto
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 from pyspark.sql.connect.client.core import Retrying
-from pyspark.sql.connect.client.reattach import RetryException
+from pyspark.sql.connect.client.reattach import (
+    RetryException,
+    ExecutePlanResponseReattachableIterator,
+)
 
 if should_test_connect:
     import pandas as pd
@@ -120,6 +126,191 @@ class SparkConnectClientTestCase(unittest.TestCase):
         self.assertEqual(client._session_id, chan.session_id)
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class SparkConnectClientReattachTestCase(unittest.TestCase):
+    def setUp(self) -> None:
+        self.request = proto.ExecutePlanRequest()
+        self.policy = {
+            "max_retries": 3,
+            "backoff_multiplier": 4.0,
+            "initial_backoff": 10,
+            "max_backoff": 10,
+            "jitter": 10,
+            "min_jitter_threshold": 10,
+        }
+        self.response = proto.ExecutePlanResponse(
+            response_id="1",
+        )
+        self.finished = proto.ExecutePlanResponse(
+            result_complete=proto.ExecutePlanResponse.ResultComplete(),
+            response_id="2",
+        )
+
+    def _stub_with(self, execute=None, attach=None):
+        return MockSparkConnectStub(
+            execute_ops=ResponseGenerator(execute) if execute is not None else 
None,
+            attach_ops=ResponseGenerator(attach) if attach is not None else 
None,
+        )
+
+    def assertEventually(self, callable, timeout_ms=1000):
+        """Helper method that will continuously evaluate the callable to not 
raise an
+        exception."""
+        import time
+
+        limit = time.monotonic_ns() + timeout_ms * 1000 * 1000
+        while time.monotonic_ns() < limit:
+            try:
+                callable()
+                break
+            except Exception:
+                time.sleep(0.1)
+        callable()
+
+    def test_basic_flow(self):
+        stub = self._stub_with([self.response, self.finished])
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, 
self.policy, [])
+        for b in ite:
+            pass
+
+        def check_all():
+            self.assertEqual(0, stub.attach_calls)
+            self.assertEqual(1, stub.release_until_calls)
+            self.assertEqual(1, stub.release_calls)
+            self.assertEqual(1, stub.execute_calls)
+
+        self.assertEventually(check_all, timeout_ms=1000)
+
+    def test_fail_during_execute(self):
+        def fatal():
+            raise TestException("Fatal")
+
+        stub = self._stub_with([self.response, fatal])
+        with self.assertRaises(TestException):
+            ite = ExecutePlanResponseReattachableIterator(self.request, stub, 
self.policy, [])
+            for b in ite:
+                pass
+
+        def check():
+            self.assertEqual(0, stub.attach_calls)
+            self.assertEqual(1, stub.release_calls)
+            self.assertEqual(1, stub.release_until_calls)
+            self.assertEqual(1, stub.execute_calls)
+
+        self.assertEventually(check, timeout_ms=1000)
+
+    def test_fail_and_retry_during_execute(self):
+        def non_fatal():
+            raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, self.response, 
self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, 
self.policy, [])
+        for b in ite:
+            pass
+
+        def check():
+            self.assertEqual(1, stub.attach_calls)
+            self.assertEqual(1, stub.release_calls)
+            self.assertEqual(3, stub.release_until_calls)
+            self.assertEqual(1, stub.execute_calls)
+
+        self.assertEventually(check, timeout_ms=1000)
+
+    def test_fail_and_retry_during_reattach(self):
+        count = 0
+
+        def non_fatal():
+            nonlocal count
+            if count < 2:
+                count += 1
+                raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+            else:
+                return proto.ExecutePlanResponse()
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, non_fatal, 
self.response, self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, 
self.policy, [])
+        for b in ite:
+            pass
+
+        def check():
+            self.assertEqual(2, stub.attach_calls)
+            self.assertEqual(3, stub.release_until_calls)
+            self.assertEqual(1, stub.release_calls)
+            self.assertEqual(1, stub.execute_calls)
+
+        self.assertEventually(check, timeout_ms=1000)
+
+
+class TestException(grpc.RpcError, grpc.Call):
+    """Exception mock to test retryable exceptions."""
+
+    def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+        self.msg = msg
+        self._code = code
+
+    def code(self):
+        return self._code
+
+    def __str__(self):
+        return self.msg
+
+    def trailing_metadata(self):
+        return ()
+
+
+class ResponseGenerator(Generator):
+    """This class is used to generate values that are returned by the streaming
+    iterator of the GRPC stub."""
+
+    def __init__(self, funs):
+        self._funs = funs
+        self._iterator = iter(self._funs)
+
+    def send(self, value: Any) -> proto.ExecutePlanResponse:
+        val = next(self._iterator)
+        if callable(val):
+            return val()
+        else:
+            return val
+
+    def throw(self, type: Any = None, value: Any = None, traceback: Any = 
None) -> Any:
+        super().throw(type, value, traceback)
+
+    def close(self) -> None:
+        return super().close()
+
+
+class MockSparkConnectStub:
+    """Simple mock class for the GRPC stub used by the re-attachable 
execution."""
+
+    def __init__(self, execute_ops=None, attach_ops=None):
+        self._execute_ops = execute_ops
+        self._attach_ops = attach_ops
+        # Call counters
+        self.execute_calls = 0
+        self.release_calls = 0
+        self.release_until_calls = 0
+        self.attach_calls = 0
+
+    def ExecutePlan(self, *args, **kwargs):
+        self.execute_calls += 1
+        return self._execute_ops
+
+    def ReattachExecute(self, *args, **kwargs):
+        self.attach_calls += 1
+        return self._attach_ops
+
+    def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, 
**kwargs):
+        if req.HasField("release_all"):
+            self.release_calls += 1
+        elif req.HasField("release_until"):
+            print("increment")
+            self.release_until_calls += 1
+
+
 class MockService:
     # Simplest mock of the SparkConnectService.
     # If this needs more complex logic, it needs to be replaced with Python 
mocking.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to