This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e7fc4003b246 [SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker e7fc4003b246 is described below commit e7fc4003b246bab743ab82d9e7bb77c0e2e5946e Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Sat Apr 13 10:30:23 2024 +0900 [SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker ### What changes were proposed in this pull request? This patch adds support for registering custom dispatch handlers when serializing objects using the provided Cloudpickle library. This is necessary to provide compatibility when executing ForEachBatch functions in structured streaming. A typical example of this behavior is the following test case: ```python def curried_function(df): def inner(batch_df, batch_id): df.createOrReplaceTempView("updates") batch_df.createOrReplaceTempView("batch_updates") return inner df = spark.readStream.format("text").load("python/test_support/sql/streaming") other_df = spark.range(100) df.writeStream.foreachBatch(curried_function(other_df)).start() ``` Here we curry a DataFrame into the function called during ForEachBatch, effectively passing state. Until now, serializing DataFrames and SparkSessions in Spark Connect was not possible because the SparkSession carries the open gRPC connection and the DataFrame itself overrides certain magic methods that make pickling fail. To make serializing Spark Sessions possible, we register a custom session constructor that simply returns the current active session during the serialization of the ForEachBatch function. When the ForEachBatch worker starts execution, it already creates and registers an active SparkSession. To serialize and reconstruct the DataFrame we simply have to pass in the session and the plan; the remaining attributes do not carry permanent state. 
To avoid modifying any global behavior, the serialization handlers are not registered for all cases but only when the ForEachBatch and ForEach handlers are called. This is to make sure that we don't unexpectedly change behavior. ### Why are the changes needed? Compatibility and Ease of Use ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added and updated tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46002 from grundprinzip/SPARK-47812. Lead-authored-by: Martin Grund <martin.gr...@databricks.com> Co-authored-by: Martin Grund <grundprin...@gmail.com> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/dataframe.py | 22 +++++++ python/pyspark/sql/connect/session.py | 37 ++++++++++++ .../streaming/worker/foreach_batch_worker.py | 15 ++++- .../connect/streaming/worker/listener_worker.py | 15 ++++- .../connect/streaming/test_parity_foreach_batch.py | 70 +++++++++++++++++----- .../connect/streaming/test_parity_listener.py | 23 ++----- .../pyspark/sql/tests/connect/test_parity_udtf.py | 18 +++++- 7 files changed, 163 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 1dddcc078810..f0dc412760a4 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -122,6 +122,28 @@ class DataFrame: self._support_repr_html = False self._cached_schema: Optional[StructType] = None + def __reduce__(self) -> Tuple: + """ + Custom method for serializing the DataFrame object using Pickle. Since the DataFrame + overrides "__getattr__" method, the default serialization method does not work. + + Returns + ------- + The tuple containing the information needed to reconstruct the object. 
+ + """ + return ( + DataFrame, + ( + self._plan, + self._session, + ), + { + "_support_repr_html": self._support_repr_html, + "_cached_schema": self._cached_schema, + }, + ) + def __repr__(self) -> str: if not self._support_repr_html: ( diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 07fe8a62f082..3be6c83cf13b 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -96,6 +96,7 @@ from pyspark.errors import ( PySparkRuntimeError, PySparkValueError, PySparkTypeError, + PySparkAssertionError, ) if TYPE_CHECKING: @@ -288,6 +289,26 @@ class SparkSession: def getActiveSession(cls) -> Optional["SparkSession"]: return getattr(cls._active_session, "session", None) + @classmethod + def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession": + """ + Internal use only. This method is called from the custom handler + generated by __reduce__. To avoid serializing a WeakRef, we create a + custom classmethod to instantiate the SparkSession. + """ + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_SESSION", + message_parameters={}, + ) + if session._session_id != session_id: + raise PySparkAssertionError( + "Expected session ID does not match active session ID: " + f"{session_id} != {session._session_id}" + ) + return session + getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ @classmethod @@ -1034,6 +1055,22 @@ class SparkSession: profile.__doc__ = PySparkSession.profile.__doc__ + def __reduce__(self) -> Tuple: + """ + This method is called when the object is pickled. It returns a tuple of the object's + constructor function, arguments to it and the local state of the object. + This function is supposed to only be used when the active spark session that is pickled + is the same active spark session that is unpickled. 
+ """ + + def creator(old_session_id: str) -> "SparkSession": + # We cannot perform the checks for session matching here because accessing the + # session ID property causes the serialization of a WeakRef and in turn breaks + # the serialization. + return SparkSession._getActiveSessionIfMatches(old_session_id) + + return creator, (self._session_id,) + SparkSession.__doc__ = PySparkSession.__doc__ diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py index c4cf52b9996d..92ed7a4aaff5 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py @@ -29,7 +29,7 @@ from pyspark.serializers import ( CPickleSerializer, ) from pyspark import worker -from pyspark.sql import SparkSession +from pyspark.sql.connect.session import SparkSession from pyspark.util import handle_worker_exception from typing import IO from pyspark.worker_util import check_python_version @@ -38,9 +38,16 @@ pickle_ser = CPickleSerializer() utf8_deserializer = UTF8Deserializer() +spark = None + + def main(infile: IO, outfile: IO) -> None: + global spark check_python_version(infile) + # Enable Spark Connect Mode + os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1" + connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] session_id = utf8_deserializer.loads(infile) @@ -49,8 +56,11 @@ def main(infile: IO, outfile: IO) -> None: f"url {connect_url} and sessionId {session_id}." ) + # To attach to the existing SparkSession, we're setting the session_id in the URL. 
+ connect_url = connect_url + ";session_id=" + session_id spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() - spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] + assert spark_connect_session.session_id == session_id + spark = spark_connect_session # TODO(SPARK-44461): Enable Process Isolation @@ -62,6 +72,7 @@ def main(infile: IO, outfile: IO) -> None: log_name = "Streaming ForeachBatch worker" def process(df_id, batch_id): # type: ignore[no-untyped-def] + global spark print(f"{log_name} Started batch {batch_id} with DF id {df_id}") batch_df = spark_connect_session._create_remote_dataframe(df_id) func(batch_df, batch_id) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index 69e0d8a46248..d3efb5894fc0 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -30,7 +30,7 @@ from pyspark.serializers import ( CPickleSerializer, ) from pyspark import worker -from pyspark.sql import SparkSession +from pyspark.sql.connect.session import SparkSession from pyspark.util import handle_worker_exception from typing import IO @@ -46,9 +46,16 @@ pickle_ser = CPickleSerializer() utf8_deserializer = UTF8Deserializer() +spark = None + + def main(infile: IO, outfile: IO) -> None: + global spark check_python_version(infile) + # Enable Spark Connect Mode + os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1" + connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] session_id = utf8_deserializer.loads(infile) @@ -57,8 +64,11 @@ def main(infile: IO, outfile: IO) -> None: f"url {connect_url} and sessionId {session_id}." ) + # To attach to the existing SparkSession, we're setting the session_id in the URL. 
+ connect_url = connect_url + ";session_id=" + session_id spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() - spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] + assert spark_connect_session.session_id == session_id + spark = spark_connect_session # TODO(SPARK-44461): Enable Process Isolation @@ -71,6 +81,7 @@ def main(infile: IO, outfile: IO) -> None: assert listener.spark == spark_connect_session def process(listener_event_str, listener_event_type): # type: ignore[no-untyped-def] + global spark listener_event = json.loads(listener_event_str) if listener_event_type == 0: listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event)) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py index 30f7bb8c2df9..4598cbbdca4e 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py @@ -30,33 +30,73 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo def test_streaming_foreach_batch_graceful_stop(self): super().test_streaming_foreach_batch_graceful_stop() + def test_nested_dataframes(self): + def curried_function(df): + def inner(batch_df, batch_id): + df.createOrReplaceTempView("updates") + batch_df.createOrReplaceTempView("batch_updates") + + return inner + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + other_df = self.spark.range(100) + q = df.writeStream.foreachBatch(curried_function(other_df)).start() + q.processAllAvailable() + collected = self.spark.sql("select * from batch_updates").collect() + self.assertTrue(len(collected), 2) + self.assertEqual(100, self.spark.sql("select * from updates").count()) + finally: + if q: + q.stop() + + def test_pickling_error(self): + class NoPickle: + def __reduce__(self): + raise 
ValueError("No pickle") + + no_pickle = NoPickle() + + def func(df, _): + print(no_pickle) + df.count() + + with self.assertRaises(PySparkPicklingError): + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + def test_accessing_spark_session(self): spark = self.spark def func(df, _): - spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect() + spark.createDataFrame([("you", "can"), ("serialize", "spark")]).createOrReplaceTempView( + "test_accessing_spark_session" + ) - error_thrown = False try: - self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() - except PySparkPicklingError as e: - self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") - error_thrown = True - self.assertTrue(error_thrown) + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + self.assertEqual(2, spark.table("test_accessing_spark_session").count()) + finally: + if q: + q.stop() def test_accessing_spark_session_through_df(self): - dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")]) + dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")]) def func(df, _): - dataframe.collect() + dataframe.createOrReplaceTempView("test_accessing_spark_session_through_df") - error_thrown = False try: - self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() - except PySparkPicklingError as e: - self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") - error_thrown = True - self.assertTrue(error_thrown) + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + self.assertEqual(2, 
self.spark.table("test_accessing_spark_session_through_df").count()) + finally: + if q: + q.stop() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index f5ffa0154df1..a15e4547f67a 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -19,7 +19,6 @@ import unittest import time import pyspark.cloudpickle -from pyspark.errors import PySparkPicklingError from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin from pyspark.sql.streaming.listener import StreamingQueryListener from pyspark.sql.functions import count, lit @@ -138,7 +137,9 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes class TestListener(StreamingQueryListener): def onQueryStarted(self, event): - spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect() + spark.createDataFrame( + [("you", "can"), ("serialize", "spark")] + ).createOrReplaceTempView("test_accessing_spark_session") def onQueryProgress(self, event): pass @@ -149,16 +150,10 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes def onQueryTerminated(self, event): pass - error_thrown = False - try: - self.spark.streams.addListener(TestListener()) - except PySparkPicklingError as e: - self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") - error_thrown = True - self.assertTrue(error_thrown) + self.spark.streams.addListener(TestListener()) def test_accessing_spark_session_through_df(self): - dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")]) + dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")]) class TestListener(StreamingQueryListener): def onQueryStarted(self, event): @@ -173,13 +168,7 @@ class 
StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes def onQueryTerminated(self, event): pass - error_thrown = False - try: - self.spark.streams.addListener(TestListener()) - except PySparkPicklingError as e: - self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") - error_thrown = True - self.assertTrue(error_thrown) + self.spark.streams.addListener(TestListener()) if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 02570ac9efa7..5071b69060a1 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -28,7 +28,7 @@ if should_test_connect: from pyspark.util import is_remote_only from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions.connect import SparkConnectGrpcException +from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): @@ -76,6 +76,10 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): def test_udtf_with_analyze_using_file(self): super().test_udtf_with_analyze_using_file() + @unittest.skip("pyspark-connect can serialize SparkSession, but fails on executor") + def test_udtf_access_spark_session(self): + super().test_udtf_access_spark_session() + def _add_pyfile(self, path): self.spark.addArtifacts(path, pyfile=True) @@ -99,6 +103,18 @@ class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): finally: super(ArrowUDTFParityTests, cls).tearDownClass() + def test_udtf_access_spark_session_connect(self): + df = self.spark.range(10) + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + df.collect() + yield 1, + + with self.assertRaisesRegex(PythonException, "NO_ACTIVE_SESSION"): + 
TestUDTF().collect() + if __name__ == "__main__": import unittest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org