This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 68d8e65c782 [SPARK-44424][CONNECT][PYTHON] Python client for 
reattaching to existing execute in Spark Connect
68d8e65c782 is described below

commit 68d8e65c7829a4a41f8c159c9b30c34cd623da47
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Aug 3 08:11:02 2023 +0900

    [SPARK-44424][CONNECT][PYTHON] Python client for reattaching to existing 
execute in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement the Python client side for 
https://github.com/apache/spark/pull/42228.
    
    Basically, this PR ports the changes made to `ExecutePlanResponseReattachableIterator` and `SparkConnectClient` in that PR to PySpark, for symmetry.
    
    ### Why are the changes needed?
    
    To enable the feature added in https://github.com/apache/spark/pull/42228 in the Python client as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see https://github.com/apache/spark/pull/42228.
    
    ### How was this patch tested?
    
    Existing unit tests, since this PR enables the feature by default. It was also tested manually with end-to-end (E2E) tests.
    
    Closes #42235 from HyukjinKwon/SPARK-44599.
    
    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/core.py          | 207 +++++++++++-------
 python/pyspark/sql/connect/client/reattach.py      | 237 +++++++++++++++++++++
 python/pyspark/sql/connect/session.py              |   2 +-
 python/pyspark/sql/session.py                      |   2 +
 .../sql/tests/connect/client/test_client.py        |  16 +-
 python/pyspark/testing/connectutils.py             |   4 +
 6 files changed, 386 insertions(+), 82 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 0288bbc6508..d9def40ebe8 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -20,6 +20,7 @@ __all__ = [
     "getLogLevel",
 ]
 
+from pyspark.sql.connect.client.reattach import 
ExecutePlanResponseReattachableIterator
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -50,6 +51,7 @@ from typing import (
     Generator,
     Type,
     TYPE_CHECKING,
+    Sequence,
 )
 
 import pandas as pd
@@ -558,8 +560,6 @@ class ConfigResult:
 class SparkConnectClient(object):
     """
     Conceptually the remote spark session that communicates with the server
-
-    .. versionadded:: 3.4.0
     """
 
     @classmethod
@@ -572,24 +572,40 @@ class SparkConnectClient(object):
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
-        userId: Optional[str] = None,
-        channelOptions: Optional[List[Tuple[str, Any]]] = None,
-        retryPolicy: Optional[Dict[str, Any]] = None,
+        user_id: Optional[str] = None,
+        channel_options: Optional[List[Tuple[str, Any]]] = None,
+        retry_policy: Optional[Dict[str, Any]] = None,
+        use_reattachable_execute: bool = True,
     ):
         """
         Creates a new SparkSession for the Spark Connect interface.
 
         Parameters
         ----------
-        connection: Union[str,ChannelBuilder]
+        connection : str or :class:`ChannelBuilder`
             Connection string that is used to extract the connection 
parameters and configure
             the GRPC connection. Or instance of ChannelBuilder that creates 
GRPC connection.
             Defaults to `sc://localhost`.
-        userId : Optional[str]
+        user_id : str, optional
             Optional unique user ID that is used to differentiate multiple 
users and
             isolate their Spark Sessions. If the `user_id` is not set, will 
default to
             the $USER environment. Defining the user ID as part of the 
connection string
             takes precedence.
+        channel_options: list of tuple, optional
+            Additional options that can be passed to the GRPC channel 
construction.
+        retry_policy: dict of str and any, optional
+            Additional configuration for retrying. There are four 
configurations as below
+                * ``max_retries``
+                    Maximum number of tries default 15
+                * ``backoff_multiplier``
+                    Backoff multiplier for the policy. Default: 4(ms)
+                * ``initial_backoff``
+                    Backoff to wait before the first retry. Default: 50(ms)
+                * ``max_backoff``
+                    Maximum backoff controls the maximum amount of time to 
wait before retrying
+                    a failed request. Default: 60000(ms).
+        use_reattachable_execute: bool
+            Enable reattachable execution.
         """
         self.thread_local = threading.local()
 
@@ -597,7 +613,7 @@ class SparkConnectClient(object):
         self._builder = (
             connection
             if isinstance(connection, ChannelBuilder)
-            else ChannelBuilder(connection, channelOptions)
+            else ChannelBuilder(connection, channel_options)
         )
         self._user_id = None
         self._retry_policy = {
@@ -606,8 +622,8 @@ class SparkConnectClient(object):
             "initial_backoff": 50,
             "max_backoff": 60000,
         }
-        if retryPolicy:
-            self._retry_policy.update(retryPolicy)
+        if retry_policy:
+            self._retry_policy.update(retry_policy)
 
         # Generate a unique session ID for this client. This UUID must be 
unique to allow
         # concurrent Spark sessions of the same user. If the channel is 
closed, creating
@@ -615,8 +631,8 @@ class SparkConnectClient(object):
         self._session_id = str(uuid.uuid4())
         if self._builder.userId is not None:
             self._user_id = self._builder.userId
-        elif userId is not None:
-            self._user_id = userId
+        elif user_id is not None:
+            self._user_id = user_id
         else:
             self._user_id = os.getenv("USER", None)
 
@@ -624,8 +640,17 @@ class SparkConnectClient(object):
         self._closed = False
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
         self._artifact_manager = ArtifactManager(self._user_id, 
self._session_id, self._channel)
+        self._use_reattachable_execute = use_reattachable_execute
         # Configure logging for the SparkConnect client.
 
+    def disable_reattachable_execute(self) -> "SparkConnectClient":
+        self._use_reattachable_execute = False
+        return self
+
+    def enable_reattachable_execute(self) -> "SparkConnectClient":
+        self._use_reattachable_execute = True
+        return self
+
     def register_udf(
         self,
         function: Any,
@@ -741,7 +766,7 @@ class SparkConnectClient(object):
         return resources
 
     def _build_observed_metrics(
-        self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
+        self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
     ) -> Iterator[PlanObservedMetrics]:
         return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in 
metrics)
 
@@ -1065,17 +1090,29 @@ class SparkConnectClient(object):
 
         """
         logger.info("Execute")
+
+        def handle_response(b: pb2.ExecutePlanResponse) -> None:
+            if b.session_id != self._session_id:
+                raise SparkConnectException(
+                    "Received incorrect session identifier for request: "
+                    f"{b.session_id} != {self._session_id}"
+                )
+
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
-            ):
-                with attempt:
-                    for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-                        if b.session_id != self._session_id:
-                            raise SparkConnectException(
-                                "Received incorrect session identifier for 
request: "
-                                f"{b.session_id} != {self._session_id}"
-                            )
+            if self._use_reattachable_execute:
+                # Don't use retryHandler - own retry handling is inside.
+                generator = ExecutePlanResponseReattachableIterator(
+                    req, self._stub, self._retry_policy, 
self._builder.metadata()
+                )
+                for b in generator:
+                    handle_response(b)
+            else:
+                for attempt in Retrying(
+                    can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+                ):
+                    with attempt:
+                        for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                            handle_response(b)
         except Exception as error:
             self._handle_error(error)
 
@@ -1092,58 +1129,77 @@ class SparkConnectClient(object):
     ]:
         logger.info("ExecuteAndFetchAsIterator")
 
+        def handle_response(
+            b: pb2.ExecutePlanResponse,
+        ) -> Iterator[
+            Union[
+                "pa.RecordBatch",
+                StructType,
+                PlanMetrics,
+                PlanObservedMetrics,
+                Dict[str, Any],
+            ]
+        ]:
+            if b.session_id != self._session_id:
+                raise SparkConnectException(
+                    "Received incorrect session identifier for request: "
+                    f"{b.session_id} != {self._session_id}"
+                )
+            if b.HasField("metrics"):
+                logger.debug("Received metric batch.")
+                yield from self._build_metrics(b.metrics)
+            if b.observed_metrics:
+                logger.debug("Received observed metric batch.")
+                yield from self._build_observed_metrics(b.observed_metrics)
+            if b.HasField("schema"):
+                logger.debug("Received the schema.")
+                dt = types.proto_schema_to_pyspark_data_type(b.schema)
+                assert isinstance(dt, StructType)
+                yield dt
+            if b.HasField("sql_command_result"):
+                logger.debug("Received the SQL command result.")
+                yield {"sql_command_result": b.sql_command_result.relation}
+            if b.HasField("write_stream_operation_start_result"):
+                field = "write_stream_operation_start_result"
+                yield {field: b.write_stream_operation_start_result}
+            if b.HasField("streaming_query_command_result"):
+                yield {"streaming_query_command_result": 
b.streaming_query_command_result}
+            if b.HasField("streaming_query_manager_command_result"):
+                cmd_result = b.streaming_query_manager_command_result
+                yield {"streaming_query_manager_command_result": cmd_result}
+            if b.HasField("get_resources_command_result"):
+                resources = {}
+                for key, resource in 
b.get_resources_command_result.resources.items():
+                    name = resource.name
+                    addresses = [address for address in resource.addresses]
+                    resources[key] = ResourceInformation(name, addresses)
+                yield {"get_resources_command_result": resources}
+            if b.HasField("arrow_batch"):
+                logger.debug(
+                    f"Received arrow batch rows={b.arrow_batch.row_count} "
+                    f"size={len(b.arrow_batch.data)}"
+                )
+
+                with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                    for batch in reader:
+                        assert isinstance(batch, pa.RecordBatch)
+                        yield batch
+
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
-            ):
-                with attempt:
-                    for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-                        if b.session_id != self._session_id:
-                            raise SparkConnectException(
-                                "Received incorrect session identifier for 
request: "
-                                f"{b.session_id} != {self._session_id}"
-                            )
-                        if b.HasField("metrics"):
-                            logger.debug("Received metric batch.")
-                            yield from self._build_metrics(b.metrics)
-                        if b.observed_metrics:
-                            logger.debug("Received observed metric batch.")
-                            yield from 
self._build_observed_metrics(b.observed_metrics)
-                        if b.HasField("schema"):
-                            logger.debug("Received the schema.")
-                            dt = 
types.proto_schema_to_pyspark_data_type(b.schema)
-                            assert isinstance(dt, StructType)
-                            yield dt
-                        if b.HasField("sql_command_result"):
-                            logger.debug("Received the SQL command result.")
-                            yield {"sql_command_result": 
b.sql_command_result.relation}
-                        if b.HasField("write_stream_operation_start_result"):
-                            field = "write_stream_operation_start_result"
-                            yield {field: 
b.write_stream_operation_start_result}
-                        if b.HasField("streaming_query_command_result"):
-                            yield {
-                                "streaming_query_command_result": 
b.streaming_query_command_result
-                            }
-                        if 
b.HasField("streaming_query_manager_command_result"):
-                            cmd_result = 
b.streaming_query_manager_command_result
-                            yield {"streaming_query_manager_command_result": 
cmd_result}
-                        if b.HasField("get_resources_command_result"):
-                            resources = {}
-                            for key, resource in 
b.get_resources_command_result.resources.items():
-                                name = resource.name
-                                addresses = [address for address in 
resource.addresses]
-                                resources[key] = ResourceInformation(name, 
addresses)
-                            yield {"get_resources_command_result": resources}
-                        if b.HasField("arrow_batch"):
-                            logger.debug(
-                                f"Received arrow batch 
rows={b.arrow_batch.row_count} "
-                                f"size={len(b.arrow_batch.data)}"
-                            )
-
-                            with pa.ipc.open_stream(b.arrow_batch.data) as 
reader:
-                                for batch in reader:
-                                    assert isinstance(batch, pa.RecordBatch)
-                                    yield batch
+            if self._use_reattachable_execute:
+                # Don't use retryHandler - own retry handling is inside.
+                generator = ExecutePlanResponseReattachableIterator(
+                    req, self._stub, self._retry_policy, 
self._builder.metadata()
+                )
+                for b in generator:
+                    yield from handle_response(b)
+            else:
+                for attempt in Retrying(
+                    can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+                ):
+                    with attempt:
+                        for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                            yield from handle_response(b)
         except Exception as error:
             self._handle_error(error)
 
@@ -1502,6 +1558,9 @@ class AttemptManager:
             self._retry_state.set_done()
             return None
 
+    def is_first_try(self) -> bool:
+        return self._retry_state._count == 0
+
 
 class Retrying:
     """
diff --git a/python/pyspark/sql/connect/client/reattach.py 
b/python/pyspark/sql/connect/client/reattach.py
new file mode 100644
index 00000000000..4d4cce0ca44
--- /dev/null
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+import warnings
+import uuid
+from collections.abc import Generator
+from typing import Optional, Dict, Any, Iterator, Iterable, Tuple
+from multiprocessing.pool import ThreadPool
+import os
+
+import pyspark.sql.connect.proto as pb2
+import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
+
+
+class ExecutePlanResponseReattachableIterator(Generator):
+    """
+    Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
+
+    It can handle situations when:
+      - the ExecutePlanResponse stream was broken by retryable network error 
(governed by
+        retryPolicy)
+      - the ExecutePlanResponse was gracefully ended by the server without a 
ResultComplete
+        message; this tells the client that there is more, and it should 
reattach to continue.
+
+    Initial iterator is the result of an ExecutePlan on the request, but it 
can be reattached with
+    ReattachExecute request. ReattachExecute request is provided the 
responseId of last returned
+    ExecutePlanResponse on the iterator to return a new iterator from server 
that continues after
+    that.
+
+    In reattachable execute the server does buffer some responses in case the 
client needs to
+    backtrack. To let server release this buffer sooner, this iterator 
asynchronously sends
+    ReleaseExecute RPCs that instruct the server to release responses that it 
already processed.
+
+    Note: If the initial ExecutePlan did not even reach the server and 
execution didn't start,
+    the ReattachExecute can still fail with 
INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole
+    operation.
+    """
+
+    _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8)
+
+    def __init__(
+        self,
+        request: pb2.ExecutePlanRequest,
+        stub: grpc_lib.SparkConnectServiceStub,
+        retry_policy: Dict[str, Any],
+        metadata: Iterable[Tuple[str, str]],
+    ):
+        self._request = request
+        self._retry_policy = retry_policy
+        if request.operation_id:
+            self._operation_id = request.operation_id
+        else:
+            # Add operation id, if not present.
+            # with operationId set by the client, the client can use it to try 
to reattach on error
+            # even before getting the first response. If the operation in fact 
didn't even reach the
+            # server, that will end with INVALID_HANDLE.OPERATION_NOT_FOUND 
error.
+            self._operation_id = str(uuid.uuid4())
+
+        self._stub = stub
+        request.request_options.append(
+            pb2.ExecutePlanRequest.RequestOption(
+                reattach_options=pb2.ReattachOptions(reattachable=True)
+            )
+        )
+        request.operation_id = self._operation_id
+        self._initial_request = request
+
+        # ResponseId of the last response returned by next()
+        self._last_returned_response_id: Optional[str] = None
+
+        # True after ResponseComplete message was seen in the stream.
+        # Server will always send this message at the end of the stream, if 
the underlying iterator
+        # finishes without producing one, another iterator needs to be 
reattached.
+        self._result_complete = False
+
+        # Initial iterator comes from ExecutePlan request.
+        # Note: This is not retried, because no error would ever be thrown 
here, and GRPC will only
+        # throw error on first self._has_next().
+        self._iterator: Iterator[pb2.ExecutePlanResponse] = iter(
+            self._stub.ExecutePlan(self._initial_request, metadata=metadata)
+        )
+
+        # Current item from this iterator.
+        self._current: Optional[pb2.ExecutePlanResponse] = None
+
+    def send(self, value: Any) -> pb2.ExecutePlanResponse:
+        # will trigger reattach in case the stream completed without 
result_complete
+        if not self._has_next():
+            raise StopIteration()
+
+        ret = self._current
+        assert ret is not None
+
+        self._last_returned_response_id = ret.response_id
+        if ret.HasField("result_complete"):
+            self._result_complete = True
+            self._release_execute(None)  # release all
+        else:
+            self._release_execute(self._last_returned_response_id)
+        self._current = None
+        return ret
+
+    def _has_next(self) -> bool:
+        from pyspark.sql.connect.client.core import SparkConnectClient
+        from pyspark.sql.connect.client.core import Retrying
+
+        if self._result_complete:
+            # After response complete response
+            return False
+        else:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+            ):
+                with attempt:
+                    # on first try, we use the existing iterator.
+                    if not attempt.is_first_try():
+                        # on retry, the iterator is borked, so we need a new 
one
+                        self._iterator = iter(
+                            
self._stub.ReattachExecute(self._create_reattach_execute_request())
+                        )
+
+                    if self._current is None:
+                        try:
+                            self._current = next(self._iterator)
+                        except StopIteration:
+                            pass
+
+                    has_next = self._current is not None
+
+                    # Graceful reattach:
+                    # If iterator ended, but there was no ResponseComplete, it 
means that
+                    # there is more, and we need to reattach. While 
ResponseComplete didn't
+                    # arrive, we keep reattaching.
+                    if not self._result_complete and not has_next:
+                        while not has_next:
+                            self._iterator = iter(
+                                
self._stub.ReattachExecute(self._create_reattach_execute_request())
+                            )
+                            # shouldn't change
+                            assert not self._result_complete
+                            try:
+                                self._current = next(self._iterator)
+                            except StopIteration:
+                                pass
+                            has_next = self._current is not None
+                    return has_next
+            return False
+
+    def _release_execute(self, until_response_id: Optional[str]) -> None:
+        """
+        Inform the server to release the execution.
+
+        This will send an asynchronous RPC which will not block this iterator, 
the iterator can
+        continue to be consumed.
+
+        Release with untilResponseId informs the server that the iterator has 
been consumed until
+        and including response with that responseId, and these responses can 
be freed.
+
+        Release with None means that the responses have been completely 
consumed and informs the
+        server that the completed execution can be completely freed.
+        """
+        from pyspark.sql.connect.client.core import SparkConnectClient
+        from pyspark.sql.connect.client.core import Retrying
+
+        request = self._create_release_execute_request(until_response_id)
+
+        def target() -> None:
+            try:
+                for attempt in Retrying(
+                    can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+                ):
+                    with attempt:
+                        self._stub.ReleaseExecute(request)
+            except Exception as e:
+                warnings.warn(f"ReleaseExecute failed with exception: {e}.")
+
+        
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
+
+    def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest:
+        reattach = pb2.ReattachExecuteRequest(
+            session_id=self._initial_request.session_id,
+            user_context=self._initial_request.user_context,
+            operation_id=self._initial_request.operation_id,
+        )
+
+        if self._initial_request.client_type:
+            reattach.client_type = self._initial_request.client_type
+
+        if self._last_returned_response_id:
+            reattach.last_response_id = self._last_returned_response_id
+
+        return reattach
+
+    def _create_release_execute_request(
+        self, until_response_id: Optional[str]
+    ) -> pb2.ReleaseExecuteRequest:
+        release = pb2.ReleaseExecuteRequest(
+            session_id=self._initial_request.session_id,
+            user_context=self._initial_request.user_context,
+            operation_id=self._initial_request.operation_id,
+        )
+
+        if self._initial_request.client_type:
+            release.client_type = self._initial_request.client_type
+
+        if not until_response_id:
+            
release.release_all.CopyFrom(pb2.ReleaseExecuteRequest.ReleaseAll())
+        else:
+            release.release_until.response_id = until_response_id
+
+        return release
+
+    def throw(self, type: Any = None, value: Any = None, traceback: Any = 
None) -> Any:
+        super().throw(type, value, traceback)
+
+    def close(self) -> None:
+        return super().close()
+
+    def __del__(self) -> None:
+        return self.close()
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 8cd39ba7a79..9bba0db05e4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -233,7 +233,7 @@ class SparkSession:
             the $USER environment. Defining the user ID as part of the 
connection string
             takes precedence.
         """
-        self._client = SparkConnectClient(connection=connection, userId=userId)
+        self._client = SparkConnectClient(connection=connection, 
user_id=userId)
         self._session_id = self._client._session_id
 
     def table(self, tableName: str) -> DataFrame:
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 834b0307238..ede6318782e 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1788,6 +1788,8 @@ class SparkSession(SparkConversionMixin):
 
         Notes
         -----
+        This API is unstable, and a developer API. It returns non-API instance
+        :class:`SparkConnectClient`.
         This is an API dedicated to Spark Connect client only. With regular 
Spark Session, it throws
         an exception.
         """
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 5c39d4502f5..9276b88e153 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -30,7 +30,7 @@ if should_test_connect:
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectClientTestCase(unittest.TestCase):
     def test_user_agent_passthrough(self):
-        client = SparkConnectClient("sc://foo/;user_agent=bar")
+        client = SparkConnectClient("sc://foo/;user_agent=bar", 
use_reattachable_execute=False)
         mock = MockService(client._session_id)
         client._stub = mock
 
@@ -41,7 +41,7 @@ class SparkConnectClientTestCase(unittest.TestCase):
         self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ 
python/[^ ]+$")
 
     def test_user_agent_default(self):
-        client = SparkConnectClient("sc://foo/")
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
         mock = MockService(client._session_id)
         client._stub = mock
 
@@ -54,11 +54,11 @@ class SparkConnectClientTestCase(unittest.TestCase):
         )
 
     def test_properties(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
         self.assertEqual(client.token, "bar")
         self.assertEqual(client.host, "foo")
 
-        client = SparkConnectClient("sc://foo/")
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
         self.assertIsNone(client.token)
 
     def test_channel_builder(self):
@@ -67,12 +67,14 @@ class SparkConnectClientTestCase(unittest.TestCase):
             def userId(self) -> Optional[str]:
                 return "abc"
 
-        client = SparkConnectClient(CustomChannelBuilder("sc://foo/"))
+        client = SparkConnectClient(
+            CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False
+        )
 
         self.assertEqual(client._user_id, "abc")
 
     def test_interrupt_all(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
         mock = MockService(client._session_id)
         client._stub = mock
 
@@ -80,7 +82,7 @@ class SparkConnectClientTestCase(unittest.TestCase):
         self.assertIsNotNone(mock.req, "Interrupt API was not called when 
expected")
 
     def test_is_closed(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
 
         self.assertFalse(client.is_closed)
         client.close()
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index b6145d0a006..ba81c783672 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -170,6 +170,10 @@ class ReusedConnectTestCase(unittest.TestCase, 
SQLTestUtils, PySparkErrorTestUti
         # Disable JVM stack trace in Spark Connect tests to prevent the
         # HTTP header size from exceeding the maximum allowed size.
         conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false")
+        # Make the server terminate reattachable streams every 1 second and 
123 bytes,
+        # to make the tests exercise reattach.
+        conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", 
"1s")
+        conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", 
"123")
         return conf
 
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to