This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 360a3f9023d0 [SPARK-45733][PYTHON][TESTS][FOLLOWUP] Skip `pyspark.sql.tests.connect.client.test_client` if not should_test_connect
360a3f9023d0 is described below

commit 360a3f9023d08812e3f3c44af9cdac644c5d67b2
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Apr 2 22:30:08 2024 -0700

    [SPARK-45733][PYTHON][TESTS][FOLLOWUP] Skip `pyspark.sql.tests.connect.client.test_client` if not should_test_connect
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of the following.
    - https://github.com/apache/spark/pull/43591
    
    ### Why are the changes needed?
    
    This test requires `pandas` which is an optional dependency in Apache Spark.
    
    ```
    $ python/run-tests --modules=pyspark-connect --parallelism=1 --python-executables=python3.10 --testnames 'pyspark.sql.tests.connect.client.test_client'
    Running PySpark tests. Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log
    Will test against the following Python executables: ['python3.10']
    Will test the following Python tests: ['pyspark.sql.tests.connect.client.test_client']
    python3.10 python_implementation is CPython
    python3.10 version is: Python 3.10.13
    Starting test(python3.10): pyspark.sql.tests.connect.client.test_client (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/216a8716-3a1f-4cf9-9c7c-63087f29f892/python3.10__pyspark.sql.tests.connect.client.test_client__tydue4ck.log)
    Traceback (most recent call last):
      File "/Users/dongjoon/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/Users/dongjoon/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/Users/dongjoon/APACHE/spark-merge/python/pyspark/sql/tests/connect/client/test_client.py", line 137, in <module>
        class TestPolicy(DefaultPolicy):
    NameError: name 'DefaultPolicy' is not defined
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Pass the CIs and manually test without `pandas`.
    
    ```
    $ pip3 uninstall pandas
    $ python/run-tests --modules=pyspark-connect --parallelism=1 --python-executables=python3.10 --testnames 'pyspark.sql.tests.connect.client.test_client'
    Running PySpark tests. Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log
    Will test against the following Python executables: ['python3.10']
    Will test the following Python tests: ['pyspark.sql.tests.connect.client.test_client']
    python3.10 python_implementation is CPython
    python3.10 version is: Python 3.10.13
    Starting test(python3.10): pyspark.sql.tests.connect.client.test_client (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/acf07ed5-938a-4272-87e1-47e3bf8b988e/python3.10__pyspark.sql.tests.connect.client.test_client__sfdosnek.log)
    Finished test(python3.10): pyspark.sql.tests.connect.client.test_client (0s) ... 13 tests were skipped
    Tests passed in 0 seconds
    
    Skipped tests in pyspark.sql.tests.connect.client.test_client with python3.10:
        test_basic_flow (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.002s)
        test_fail_and_retry_during_execute (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
        test_fail_and_retry_during_reattach (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
        test_fail_during_execute (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
        test_channel_builder (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_channel_builder_with_session (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_interrupt_all (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_is_closed (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_properties (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_retry (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_retry_client_unit (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_user_agent_default (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
        test_user_agent_passthrough (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45830 from dongjoon-hyun/SPARK-45733.
    
    Authored-by: Dongjoon Hyun <dh...@apple.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/tests/connect/client/test_client.py | 225 ++++++++++-----------
 1 file changed, 110 insertions(+), 115 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 23e89ce8a1f9..b96fc44d50a7 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -36,6 +36,116 @@ if should_test_connect:
     from pyspark.errors import RetriesExceeded
     import pyspark.sql.connect.proto as proto
 
+    class TestPolicy(DefaultPolicy):
+        def __init__(self):
+            super().__init__(
+                max_retries=3,
+                backoff_multiplier=4.0,
+                initial_backoff=10,
+                max_backoff=10,
+                jitter=10,
+                min_jitter_threshold=10,
+            )
+
+    class TestException(grpc.RpcError, grpc.Call):
+        """Exception mock to test retryable exceptions."""
+
+        def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+            self.msg = msg
+            self._code = code
+
+        def code(self):
+            return self._code
+
+        def __str__(self):
+            return self.msg
+
+        def trailing_metadata(self):
+            return ()
+
+    class ResponseGenerator(Generator):
+        """This class is used to generate values that are returned by the streaming
+        iterator of the GRPC stub."""
+
+        def __init__(self, funs):
+            self._funs = funs
+            self._iterator = iter(self._funs)
+
+        def send(self, value: Any) -> proto.ExecutePlanResponse:
+            val = next(self._iterator)
+            if callable(val):
+                return val()
+            else:
+                return val
+
+        def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
+            super().throw(type, value, traceback)
+
+        def close(self) -> None:
+            return super().close()
+
+    class MockSparkConnectStub:
+        """Simple mock class for the GRPC stub used by the re-attachable execution."""
+
+        def __init__(self, execute_ops=None, attach_ops=None):
+            self._execute_ops = execute_ops
+            self._attach_ops = attach_ops
+            # Call counters
+            self.execute_calls = 0
+            self.release_calls = 0
+            self.release_until_calls = 0
+            self.attach_calls = 0
+
+        def ExecutePlan(self, *args, **kwargs):
+            self.execute_calls += 1
+            return self._execute_ops
+
+        def ReattachExecute(self, *args, **kwargs):
+            self.attach_calls += 1
+            return self._attach_ops
+
+        def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs):
+            if req.HasField("release_all"):
+                self.release_calls += 1
+            elif req.HasField("release_until"):
+                print("increment")
+                self.release_until_calls += 1
+
+    class MockService:
+        # Simplest mock of the SparkConnectService.
+        # If this needs more complex logic, it needs to be replaced with Python mocking.
+
+        req: Optional[proto.ExecutePlanRequest]
+
+        def __init__(self, session_id: str):
+            self._session_id = session_id
+            self.req = None
+
+        def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
+            self.req = req
+            resp = proto.ExecutePlanResponse()
+            resp.session_id = self._session_id
+
+            pdf = pd.DataFrame(data={"col1": [1, 2]})
+            schema = pa.Schema.from_pandas(pdf)
+            table = pa.Table.from_pandas(pdf)
+            sink = pa.BufferOutputStream()
+
+            writer = pa.ipc.new_stream(sink, schema=schema)
+            writer.write(table)
+            writer.close()
+
+            buf = sink.getvalue()
+            resp.arrow_batch.data = buf.to_pybytes()
+            resp.arrow_batch.row_count = 2
+            return [resp]
+
+        def Interrupt(self, req: proto.InterruptRequest, metadata):
+            self.req = req
+            resp = proto.InterruptResponse()
+            resp.session_id = self._session_id
+            return resp
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectClientTestCase(unittest.TestCase):
@@ -134,18 +244,6 @@ class SparkConnectClientTestCase(unittest.TestCase):
         self.assertEqual(client._session_id, chan.session_id)
 
 
-class TestPolicy(DefaultPolicy):
-    def __init__(self):
-        super().__init__(
-            max_retries=3,
-            backoff_multiplier=4.0,
-            initial_backoff=10,
-            max_backoff=10,
-            jitter=10,
-            min_jitter_threshold=10,
-        )
-
-
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectClientReattachTestCase(unittest.TestCase):
     def setUp(self) -> None:
@@ -243,109 +341,6 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
         eventually(timeout=1, catch_assertions=True)(check)()
 
 
-class TestException(grpc.RpcError, grpc.Call):
-    """Exception mock to test retryable exceptions."""
-
-    def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
-        self.msg = msg
-        self._code = code
-
-    def code(self):
-        return self._code
-
-    def __str__(self):
-        return self.msg
-
-    def trailing_metadata(self):
-        return ()
-
-
-class ResponseGenerator(Generator):
-    """This class is used to generate values that are returned by the streaming
-    iterator of the GRPC stub."""
-
-    def __init__(self, funs):
-        self._funs = funs
-        self._iterator = iter(self._funs)
-
-    def send(self, value: Any) -> proto.ExecutePlanResponse:
-        val = next(self._iterator)
-        if callable(val):
-            return val()
-        else:
-            return val
-
-    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
-        super().throw(type, value, traceback)
-
-    def close(self) -> None:
-        return super().close()
-
-
-class MockSparkConnectStub:
-    """Simple mock class for the GRPC stub used by the re-attachable execution."""
-
-    def __init__(self, execute_ops=None, attach_ops=None):
-        self._execute_ops = execute_ops
-        self._attach_ops = attach_ops
-        # Call counters
-        self.execute_calls = 0
-        self.release_calls = 0
-        self.release_until_calls = 0
-        self.attach_calls = 0
-
-    def ExecutePlan(self, *args, **kwargs):
-        self.execute_calls += 1
-        return self._execute_ops
-
-    def ReattachExecute(self, *args, **kwargs):
-        self.attach_calls += 1
-        return self._attach_ops
-
-    def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs):
-        if req.HasField("release_all"):
-            self.release_calls += 1
-        elif req.HasField("release_until"):
-            print("increment")
-            self.release_until_calls += 1
-
-
-class MockService:
-    # Simplest mock of the SparkConnectService.
-    # If this needs more complex logic, it needs to be replaced with Python mocking.
-
-    req: Optional[proto.ExecutePlanRequest]
-
-    def __init__(self, session_id: str):
-        self._session_id = session_id
-        self.req = None
-
-    def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
-        self.req = req
-        resp = proto.ExecutePlanResponse()
-        resp.session_id = self._session_id
-
-        pdf = pd.DataFrame(data={"col1": [1, 2]})
-        schema = pa.Schema.from_pandas(pdf)
-        table = pa.Table.from_pandas(pdf)
-        sink = pa.BufferOutputStream()
-
-        writer = pa.ipc.new_stream(sink, schema=schema)
-        writer.write(table)
-        writer.close()
-
-        buf = sink.getvalue()
-        resp.arrow_batch.data = buf.to_pybytes()
-        resp.arrow_batch.row_count = 2
-        return [resp]
-
-    def Interrupt(self, req: proto.InterruptRequest, metadata):
-        self.req = req
-        resp = proto.InterruptResponse()
-        resp.session_id = self._session_id
-        return resp
-
-
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401
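
The patch boils down to the usual guard pattern for optional test dependencies: helpers that need the optional packages are defined only under `if should_test_connect:`, and the test classes are decorated with `@unittest.skipIf(not should_test_connect, connect_requirement_message)`, so merely importing the module can no longer fail with `NameError` when `pandas` is missing. Below is a minimal, self-contained sketch of that pattern; `have_pandas`, `pandas_requirement_message`, `HelperFrameBuilder`, and `ExampleTestCase` are illustrative stand-ins, not names from the patch.

```
# Sketch of the optional-dependency guard pattern (hypothetical names, pandas as the example).
import importlib.util
import unittest

# Detect the optional dependency without importing it unconditionally.
have_pandas = importlib.util.find_spec("pandas") is not None
pandas_requirement_message = None if have_pandas else "pandas is not installed"

if have_pandas:
    # Everything that needs the optional dependency lives under the guard,
    # so importing this module never raises NameError/ImportError.
    import pandas as pd

    class HelperFrameBuilder:
        """Hypothetical helper that is only meaningful when pandas is available."""

        def build(self):
            return pd.DataFrame({"col1": [1, 2]})


@unittest.skipIf(not have_pandas, pandas_requirement_message)
class ExampleTestCase(unittest.TestCase):
    def test_build_frame(self):
        # Runs only when pandas is installed; otherwise the whole class is reported as skipped.
        self.assertEqual(len(HelperFrameBuilder().build()), 2)


if __name__ == "__main__":
    unittest.main()
```

Without `pandas` installed, running such a module reports its tests as skipped (as in the `13 tests were skipped` output above) instead of failing at import time with `NameError: name 'DefaultPolicy' is not defined`.
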
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org