This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 60073f31831 [SPARK-45167][CONNECT][PYTHON][3.5] Python client must call `release_all` 60073f31831 is described below commit 60073f318313ab2329ea1504ef7538641433852e Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Tue Sep 19 08:32:21 2023 +0900 [SPARK-45167][CONNECT][PYTHON][3.5] Python client must call `release_all` ### What changes were proposed in this pull request? Cherry-pick of https://github.com/apache/spark/pull/42929 Previously, the Python client did not call `release_all` after fetching all results, leaving the query dangling. The query would then only be removed after the five-minute timeout. This patch adds proper tests for the `release_all` and `release_until` calls. In addition, it fixes a test race condition in which closing the SparkSession would in turn close the gRPC channel while asynchronous release calls might still be pending on it. ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? new UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #42973 from grundprinzip/SPARK-45167-3.5. Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/client/core.py | 1 + python/pyspark/sql/connect/client/reattach.py | 37 +++- .../sql/tests/connect/client/test_client.py | 195 ++++++++++++++++++++- 3 files changed, 226 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7b3299d123b..7b1aafbefeb 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1005,6 +1005,7 @@ class SparkConnectClient(object): """ Close the channel.
""" + ExecutePlanResponseReattachableIterator.shutdown() self._channel.close() self._closed = True diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 7e1e722d5fd..e58864b965b 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -21,7 +21,9 @@ check_dependencies(__name__) import warnings import uuid from collections.abc import Generator -from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast +from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar +from multiprocessing import RLock +from multiprocessing.synchronize import RLock as RLockBase from multiprocessing.pool import ThreadPool import os @@ -53,7 +55,30 @@ class ExecutePlanResponseReattachableIterator(Generator): ReleaseExecute RPCs that instruct the server to release responses that it already processed. """ - _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) + # Lock to manage the pool + _lock: ClassVar[RLockBase] = RLock() + _release_thread_pool: Optional[ThreadPool] = ThreadPool(os.cpu_count() if os.cpu_count() else 8) + + @classmethod + def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: + """ + When the channel is closed, this method will be called before, to make sure all + outstanding calls are closed. + """ + with cls._lock: + if cls._release_thread_pool is not None: + cls._release_thread_pool.close() + cls._release_thread_pool.join() + cls._release_thread_pool = None + + @classmethod + def _initialize_pool_if_necessary(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: + """ + If the processing pool for the release calls is None, initialize the pool exactly once. 
+ """ + with cls._lock: + if cls._release_thread_pool is None: + cls._release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) def __init__( self, @@ -62,6 +87,7 @@ class ExecutePlanResponseReattachableIterator(Generator): retry_policy: Dict[str, Any], metadata: Iterable[Tuple[str, str]], ): + ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary() self._request = request self._retry_policy = retry_policy if request.operation_id: @@ -111,7 +137,6 @@ class ExecutePlanResponseReattachableIterator(Generator): self._last_returned_response_id = ret.response_id if ret.HasField("result_complete"): - self._result_complete = True self._release_all() else: self._release_until(self._last_returned_response_id) @@ -190,7 +215,8 @@ class ExecutePlanResponseReattachableIterator(Generator): except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + if ExecutePlanResponseReattachableIterator._release_thread_pool is not None: + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) def _release_all(self) -> None: """ @@ -218,7 +244,8 @@ class ExecutePlanResponseReattachableIterator(Generator): except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + if ExecutePlanResponseReattachableIterator._release_thread_pool is not None: + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 98f68767b8b..cf43fb16df7 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -17,14 +17,20 @@ import unittest import 
uuid -from typing import Optional +from collections.abc import Generator +from typing import Optional, Any + +import grpc from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder import pyspark.sql.connect.proto as proto from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.sql.connect.client.core import Retrying -from pyspark.sql.connect.client.reattach import RetryException +from pyspark.sql.connect.client.reattach import ( + RetryException, + ExecutePlanResponseReattachableIterator, +) if should_test_connect: import pandas as pd @@ -120,6 +126,191 @@ class SparkConnectClientTestCase(unittest.TestCase): self.assertEqual(client._session_id, chan.session_id) +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientReattachTestCase(unittest.TestCase): + def setUp(self) -> None: + self.request = proto.ExecutePlanRequest() + self.policy = { + "max_retries": 3, + "backoff_multiplier": 4.0, + "initial_backoff": 10, + "max_backoff": 10, + "jitter": 10, + "min_jitter_threshold": 10, + } + self.response = proto.ExecutePlanResponse( + response_id="1", + ) + self.finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete(), + response_id="2", + ) + + def _stub_with(self, execute=None, attach=None): + return MockSparkConnectStub( + execute_ops=ResponseGenerator(execute) if execute is not None else None, + attach_ops=ResponseGenerator(attach) if attach is not None else None, + ) + + def assertEventually(self, callable, timeout_ms=1000): + """Helper method that will continuously evaluate the callable to not raise an + exception.""" + import time + + limit = time.monotonic_ns() + timeout_ms * 1000 * 1000 + while time.monotonic_ns() < limit: + try: + callable() + break + except Exception: + time.sleep(0.1) + callable() + + def test_basic_flow(self): + stub = self._stub_with([self.response, self.finished]) + ite = 
ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check_all(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check_all, timeout_ms=1000) + + def test_fail_during_execute(self): + def fatal(): + raise TestException("Fatal") + + stub = self._stub_with([self.response, fatal]) + with self.assertRaises(TestException): + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + def test_fail_and_retry_during_execute(self): + def non_fatal(): + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + + stub = self._stub_with( + [self.response, non_fatal], [self.response, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + def test_fail_and_retry_during_reattach(self): + count = 0 + + def non_fatal(): + nonlocal count + if count < 2: + count += 1 + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + else: + return proto.ExecutePlanResponse() + + stub = self._stub_with( + [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(2, stub.attach_calls) + 
self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + +class TestException(grpc.RpcError, grpc.Call): + """Exception mock to test retryable exceptions.""" + + def __init__(self, msg, code=grpc.StatusCode.INTERNAL): + self.msg = msg + self._code = code + + def code(self): + return self._code + + def __str__(self): + return self.msg + + def trailing_metadata(self): + return () + + +class ResponseGenerator(Generator): + """This class is used to generate values that are returned by the streaming + iterator of the GRPC stub.""" + + def __init__(self, funs): + self._funs = funs + self._iterator = iter(self._funs) + + def send(self, value: Any) -> proto.ExecutePlanResponse: + val = next(self._iterator) + if callable(val): + return val() + else: + return val + + def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: + super().throw(type, value, traceback) + + def close(self) -> None: + return super().close() + + +class MockSparkConnectStub: + """Simple mock class for the GRPC stub used by the re-attachable execution.""" + + def __init__(self, execute_ops=None, attach_ops=None): + self._execute_ops = execute_ops + self._attach_ops = attach_ops + # Call counters + self.execute_calls = 0 + self.release_calls = 0 + self.release_until_calls = 0 + self.attach_calls = 0 + + def ExecutePlan(self, *args, **kwargs): + self.execute_calls += 1 + return self._execute_ops + + def ReattachExecute(self, *args, **kwargs): + self.attach_calls += 1 + return self._attach_ops + + def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): + if req.HasField("release_all"): + self.release_calls += 1 + elif req.HasField("release_until"): + print("increment") + self.release_until_calls += 1 + + class MockService: # Simplest mock of the SparkConnectService. 
# If this needs more complex logic, it needs to be replaced with Python mocking. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org