This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 808fd88bf4f [SPARK-42941][SS][CONNECT][3.5] Python StreamingQueryListener 808fd88bf4f is described below commit 808fd88bf4fb96d3277641897f2a8fffdeb77f73 Author: bogao007 <bo....@databricks.com> AuthorDate: Wed Aug 2 09:02:25 2023 +0900 [SPARK-42941][SS][CONNECT][3.5] Python StreamingQueryListener ### What changes were proposed in this pull request? Implement the python streaming query listener and the addListener method and removeListener method, follow up filed in: [SPARK-44516](https://issues.apache.org/jira/browse/SPARK-44516) to actually terminate the query listener process when removeListener is called. [SPARK-44516](https://issues.apache.org/jira/browse/SPARK-44516) depends on [SPARK-44433](https://issues.apache.org/jira/browse/SPARK-44433). ### Why are the changes needed? SS Connect development ### Does this PR introduce _any_ user-facing change? Yes now they can use connect listener ### How was this patch tested? Manual test and added unit test **addListener:** ``` # Client side: >>> from pyspark.sql.streaming.listener import StreamingQueryListener;from pyspark.sql.streaming.listener import (QueryStartedEvent, QueryProgressEvent, QueryTerminatedEvent, QueryIdleEvent) >>> class MyListener(StreamingQueryListener): ... def onQueryStarted(self, event: QueryStartedEvent) -> None: print("hi, event query id is: " + str(event.id)); df=self.spark.createDataFrame(["10","11","13"], "string").toDF("age"); df.write.saveAsTable("tbllistener1") ... def onQueryProgress(self, event: QueryProgressEvent) -> None: pass ... def onQueryIdle(self, event: QueryIdleEvent) -> None: pass ... def onQueryTerminated(self, event: QueryTerminatedEvent) -> None: pass ... 
>>> spark.streams.addListener(MyListener()) >>> q = spark.readStream.format("rate").load().writeStream.format("console").start() >>> q.stop() >>> spark.read.table("tbllistener1").collect() [Row(age='13'), Row(age='10'), Row(age='11')] # Server side: ##### event_type received from python process is 0 hi, event query id is: dd7ba1c4-6c8f-4369-9c3c-5dede22b8a2f ``` **removeListener:** ``` # Client side: >>> listener = MyListener(); spark.streams.addListener(listener) >>> spark.streams.removeListener(listener) # Server side: # nothing to print actually, the listener is removed from server side StreamingQueryManager and cache in sessionHolder, but the process still hangs there. Follow up SPARK-44516 filed to stop this process ``` Closes #42250 from bogao007/3.5-branch-sync. Lead-authored-by: bogao007 <bo....@databricks.com> Co-authored-by: Wei Liu <wei....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/streaming/StreamingQueryManager.scala | 6 +- .../sql/streaming/ClientStreamingQuerySuite.scala | 2 +- .../src/main/protobuf/spark/connect/commands.proto | 2 + .../sql/connect/planner/SparkConnectPlanner.scala | 30 +-- .../planner/StreamingForeachBatchHelper.scala | 9 +- .../planner/StreamingQueryListenerHelper.scala | 69 +++++++ .../spark/api/python/PythonWorkerFactory.scala | 6 +- .../spark/api/python/StreamingPythonRunner.scala | 5 +- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/connect/proto/commands_pb2.py | 44 ++-- python/pyspark/sql/connect/proto/commands_pb2.pyi | 37 +++- python/pyspark/sql/connect/streaming/query.py | 31 +-- .../sql/connect/streaming/worker/__init__.py | 18 ++ .../streaming/worker/foreachBatch_worker.py} | 18 +- .../connect/streaming/worker/listener_worker.py} | 53 +++-- python/pyspark/sql/streaming/listener.py | 29 ++- python/pyspark/sql/streaming/query.py | 12 ++ .../connect/streaming/test_parity_listener.py | 90 ++++++++ .../sql/tests/streaming/test_streaming_listener.py | 228 +++++++++++---------- 
python/setup.py | 1 + 20 files changed, 484 insertions(+), 207 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 91744460440..d16638e5945 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -156,7 +156,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo executeManagerCmd( _.getAddListenerBuilder .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils - .serialize(StreamingListenerPacket(id, listener))))) + .serialize(StreamingListenerPacket(id, listener)))) + .setId(id)) } /** @@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo val id = getIdByListener(listener) executeManagerCmd( _.getRemoveListenerBuilder - .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils - .serialize(StreamingListenerPacket(id, listener))))) + .setId(id)) removeCachedListener(id) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index bc778f02480..f9e6e686495 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -294,7 +294,7 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging { spark.sql("DROP TABLE IF EXISTS my_listener_table") } - // List listeners after adding a new listener, length should be 2. 
+ // List listeners after adding a new listener, length should be 1. val listeners = spark.streams.listListeners() assert(listeners.length == 1) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index 4c4233124d8..49b25f099bf 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -365,6 +365,8 @@ message StreamingQueryManagerCommand { message StreamingQueryListenerCommand { bytes listener_payload = 1; + optional PythonUDF python_listener_payload = 2; + string id = 3; } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 61e5d9de914..f9a1e44516e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2831,7 +2831,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => - throw InvalidPlanInput("Unexpected") // Unreachable + throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable } writer.foreachBatch(foreachBatchFn) @@ -3074,23 +3074,27 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { respBuilder.setResetTerminated(true) case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER => - val listenerPacket = Utils - .deserialize[StreamingListenerPacket]( - command.getAddListener.getListenerPayload.toByteArray, - Utils.getContextOrSparkClassLoader) - val listener: StreamingQueryListener = 
listenerPacket.listener - .asInstanceOf[StreamingQueryListener] - val id: String = listenerPacket.id + val listener = if (command.getAddListener.hasPythonListenerPayload) { + new PythonStreamingQueryListener( + transformPythonFunction(command.getAddListener.getPythonListenerPayload), + sessionHolder, + pythonExec) + } else { + val listenerPacket = Utils + .deserialize[StreamingListenerPacket]( + command.getAddListener.getListenerPayload.toByteArray, + Utils.getContextOrSparkClassLoader) + + listenerPacket.listener.asInstanceOf[StreamingQueryListener] + } + + val id = command.getAddListener.getId sessionHolder.cacheListenerById(id, listener) session.streams.addListener(listener) respBuilder.setAddListener(true) case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER => - val listenerId = Utils - .deserialize[StreamingListenerPacket]( - command.getRemoveListener.getListenerPayload.toByteArray, - Utils.getContextOrSparkClassLoader) - .id + val listenerId = command.getRemoveListener.getId val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId) session.streams.removeListener(listener) sessionHolder.removeCachedListener(listenerId) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 9770ac4cee5..3b9ae483cf1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner import java.util.UUID -import org.apache.spark.api.python.PythonRDD -import org.apache.spark.api.python.SimplePythonFunction -import org.apache.spark.api.python.StreamingPythonRunner +import org.apache.spark.api.python.{PythonRDD, 
SimplePythonFunction, StreamingPythonRunner} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.service.SessionHolder @@ -90,7 +88,10 @@ object StreamingForeachBatchHelper extends Logging { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" val runner = StreamingPythonRunner(pythonFn, connectUrl) - val (dataOut, dataIn) = runner.init(sessionHolder.sessionId) + val (dataOut, dataIn) = + runner.init( + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.foreachBatch_worker") val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala new file mode 100644 index 00000000000..d915bc93496 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.planner + +import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner} +import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService} +import org.apache.spark.sql.streaming.StreamingQueryListener + +/** + * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each + * instance of this class starts a python process, which contains the python handling logic. + * When a new event is received, it is serialized to json, and passed to the python process. + */ +class PythonStreamingQueryListener( + listener: SimplePythonFunction, + sessionHolder: SessionHolder, + pythonExec: String) + extends StreamingQueryListener { + + val port = SparkConnectService.localPort + val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + val runner = StreamingPythonRunner(listener, connectUrl) + + val (dataOut, _) = + runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker") + + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(0) + dataOut.flush() + } + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(1) + dataOut.flush() + } + + override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(2) + dataOut.flush() + } + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(3) + dataOut.flush() + } + + // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes. + // Similar to foreachBatch when we need to exit the process when the query ends. 
+ // In listener semantics, we need to exit the process when removeListener is called. +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 6039f8d232b..d5d97b74d11 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -110,9 +110,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } - /** Creates a Python worker with `pyspark.streaming_worker` module. */ - def createStreamingWorker(): (Socket, Option[Int]) = { - createSimpleWorker("pyspark.streaming_worker") + /** Creates a Python worker with streaming worker module. */ + def createStreamingWorker(streamingWorkerModule: String): (Socket, Option[Int]) = { + createSimpleWorker(streamingWorkerModule) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index faf462a1990..c02871ee145 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -48,9 +48,8 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str * Initializes the Python worker for streaming functions. Sets up Spark Connect session * to be used with the functions. 
*/ - def init(sessionId: String): (DataOutputStream, DataInputStream) = { + def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = { logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") - val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -62,7 +61,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap) - val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker() + val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker(workerModule) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 79c3f8f26b1..4005c317e62 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -863,6 +863,7 @@ pyspark_connect = Module( "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_client", "pyspark.sql.tests.connect.streaming.test_parity_streaming", + "pyspark.sql.tests.connect.streaming.test_parity_listener", "pyspark.sql.tests.connect.streaming.test_parity_foreach", "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch", "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state", diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 6f03d80e669..90911e382bf 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -115,27 +115,27 @@ if _descriptor._USE_C_DESCRIPTORS == False: _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6330 _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6386 _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6404 - _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7101 + _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7233 _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6935 _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7014 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7016 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7090 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7104 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8180 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7712 - 
_STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7839 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7841 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7956 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7958 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8017 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8019 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8094 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8096 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8165 - _GETRESOURCESCOMMAND._serialized_start = 8182 - _GETRESOURCESCOMMAND._serialized_end = 8203 - _GETRESOURCESCOMMANDRESULT._serialized_start = 8206 - _GETRESOURCESCOMMANDRESULT._serialized_end = 8418 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8322 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8418 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7017 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7222 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7236 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8312 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7844 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7971 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7973 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 8088 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 8090 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8149 + 
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8151 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8226 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8228 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8297 + _GETRESOURCESCOMMAND._serialized_start = 8314 + _GETRESOURCESCOMMAND._serialized_end = 8335 + _GETRESOURCESCOMMANDRESULT._serialized_start = 8338 + _GETRESOURCESCOMMANDRESULT._serialized_end = 8550 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8454 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8550 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index be423ea036e..f3dca7ab4bb 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -1372,15 +1372,50 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int + PYTHON_LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int + ID_FIELD_NUMBER: builtins.int listener_payload: builtins.bytes + @property + def python_listener_payload( + self, + ) -> pyspark.sql.connect.proto.expressions_pb2.PythonUDF: ... + id: builtins.str def __init__( self, *, listener_payload: builtins.bytes = ..., + python_listener_payload: pyspark.sql.connect.proto.expressions_pb2.PythonUDF + | None = ..., + id: builtins.str = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_python_listener_payload", + b"_python_listener_payload", + "python_listener_payload", + b"python_listener_payload", + ], + ) -> builtins.bool: ... 
def ClearField( - self, field_name: typing_extensions.Literal["listener_payload", b"listener_payload"] + self, + field_name: typing_extensions.Literal[ + "_python_listener_payload", + b"_python_listener_payload", + "id", + b"id", + "listener_payload", + b"listener_payload", + "python_listener_payload", + b"python_listener_payload", + ], ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal[ + "_python_listener_payload", b"_python_listener_payload" + ], + ) -> typing_extensions.Literal["python_listener_payload"] | None: ... ACTIVE_FIELD_NUMBER: builtins.int GET_QUERY_FIELD_NUMBER: builtins.int diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index e5aa881c990..59e98e7bc30 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -21,6 +21,9 @@ from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional from pyspark.errors import StreamingQueryException, PySparkValueError import pyspark.sql.connect.proto as pb2 +from pyspark.serializers import CloudPickleSerializer +from pyspark.sql.connect import proto +from pyspark.sql.streaming import StreamingQueryListener from pyspark.sql.streaming.query import ( StreamingQuery as PySparkStreamingQuery, StreamingQueryManager as PySparkStreamingQueryManager, @@ -226,25 +229,27 @@ class StreamingQueryManager: cmd = pb2.StreamingQueryManagerCommand() cmd.reset_terminated = True self._execute_streaming_query_manager_cmd(cmd) - return None resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__ - def addListener(self, listener: Any) -> None: - # TODO(SPARK-42941): Change listener type to Connect StreamingQueryListener - # and implement below - raise NotImplementedError("addListener() is not implemented.") + def addListener(self, listener: StreamingQueryListener) -> None: + listener._init_listener_id() + cmd = pb2.StreamingQueryManagerCommand() + expr = 
proto.PythonUDF() + expr.command = CloudPickleSerializer().dumps(listener) + expr.python_ver = "%d.%d" % sys.version_info[:2] + cmd.add_listener.python_listener_payload.CopyFrom(expr) + cmd.add_listener.id = listener._id + self._execute_streaming_query_manager_cmd(cmd) - # TODO(SPARK-42941): uncomment below - # addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__ + addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__ - def removeListener(self, listener: Any) -> None: - # TODO(SPARK-42941): Change listener type to Connect StreamingQueryListener - # and implement below - raise NotImplementedError("removeListener() is not implemented.") + def removeListener(self, listener: StreamingQueryListener) -> None: + cmd = pb2.StreamingQueryManagerCommand() + cmd.remove_listener.id = listener._id + self._execute_streaming_query_manager_cmd(cmd) - # TODO(SPARK-42941): uncomment below - # removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__ + removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__ def _execute_streaming_query_manager_cmd( self, cmd: pb2.StreamingQueryManagerCommand diff --git a/python/pyspark/sql/connect/streaming/worker/__init__.py b/python/pyspark/sql/connect/streaming/worker/__init__.py new file mode 100644 index 00000000000..a5c98019891 --- /dev/null +++ b/python/pyspark/sql/connect/streaming/worker/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Spark Connect Streaming Server-side Worker""" diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py similarity index 86% copy from python/pyspark/streaming_worker.py copy to python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index a818880a984..054788539f2 100644 --- a/python/pyspark/streaming_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -16,7 +16,8 @@ # """ -A worker for streaming foreachBatch and query listener in Spark Connect. +A worker for streaming foreachBatch in Spark Connect. +Usually this is ran on the driver side of the Spark Connect Server. """ import os @@ -29,20 +30,23 @@ from pyspark.serializers import ( ) from pyspark import worker from pyspark.sql import SparkSession +from typing import IO pickle_ser = CPickleSerializer() utf8_deserializer = UTF8Deserializer() -def main(infile, outfile): # type: ignore[no-untyped-def] - log_name = "Streaming ForeachBatch worker" +def main(infile: IO, outfile: IO) -> None: connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] session_id = utf8_deserializer.loads(infile) - print(f"{log_name} is starting with url {connect_url} and sessionId {session_id}.") + print( + "Streaming foreachBatch worker is starting with " + f"url {connect_url} and sessionId {session_id}." 
+ ) spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() - spark_connect_session._client._session_id = session_id + spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] # TODO(SPARK-44460): Pass credentials. # TODO(SPARK-44461): Enable Process Isolation @@ -52,6 +56,8 @@ def main(infile, outfile): # type: ignore[no-untyped-def] outfile.flush() + log_name = "Streaming ForeachBatch worker" + def process(df_id, batch_id): # type: ignore[no-untyped-def] print(f"{log_name} Started batch {batch_id} with DF id {df_id}") batch_df = spark_connect_session._create_remote_dataframe(df_id) @@ -67,8 +73,6 @@ def main(infile, outfile): # type: ignore[no-untyped-def] if __name__ == "__main__": - print("Starting streaming worker") - # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py similarity index 55% rename from python/pyspark/streaming_worker.py rename to python/pyspark/sql/connect/streaming/worker/listener_worker.py index a818880a984..8eb310461b6 100644 --- a/python/pyspark/streaming_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -16,59 +16,76 @@ # """ -A worker for streaming foreachBatch and query listener in Spark Connect. +A worker for streaming query listener in Spark Connect. +Usually this is ran on the driver side of the Spark Connect Server. 
""" import os +import json from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( + read_int, write_int, - read_long, UTF8Deserializer, CPickleSerializer, ) from pyspark import worker from pyspark.sql import SparkSession +from typing import IO + +from pyspark.sql.streaming.listener import ( + QueryStartedEvent, + QueryProgressEvent, + QueryTerminatedEvent, + QueryIdleEvent, +) pickle_ser = CPickleSerializer() utf8_deserializer = UTF8Deserializer() -def main(infile, outfile): # type: ignore[no-untyped-def] - log_name = "Streaming ForeachBatch worker" +def main(infile: IO, outfile: IO) -> None: connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] session_id = utf8_deserializer.loads(infile) - print(f"{log_name} is starting with url {connect_url} and sessionId {session_id}.") + print( + "Streaming query listener worker is starting with " + f"url {connect_url} and sessionId {session_id}." + ) spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() - spark_connect_session._client._session_id = session_id + spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] # TODO(SPARK-44460): Pass credentials. 
# TODO(SPARK-44461): Enable Process Isolation - func = worker.read_command(pickle_ser, infile) + listener = worker.read_command(pickle_ser, infile) write_int(0, outfile) # Indicate successful initialization outfile.flush() - def process(df_id, batch_id): # type: ignore[no-untyped-def] - print(f"{log_name} Started batch {batch_id} with DF id {df_id}") - batch_df = spark_connect_session._create_remote_dataframe(df_id) - func(batch_df, batch_id) - print(f"{log_name} Completed batch {batch_id} with DF id {df_id}") + listener._set_spark_session(spark_connect_session) + assert listener.spark == spark_connect_session + + def process(listener_event_str, listener_event_type): # type: ignore[no-untyped-def] + listener_event = json.loads(listener_event_str) + if listener_event_type == 0: + listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event)) + elif listener_event_type == 1: + listener.onQueryProgress(QueryProgressEvent.fromJson(listener_event)) + elif listener_event_type == 2: + listener.onQueryIdle(QueryIdleEvent.fromJson(listener_event)) + elif listener_event_type == 3: + listener.onQueryTerminated(QueryTerminatedEvent.fromJson(listener_event)) while True: - df_ref_id = utf8_deserializer.loads(infile) - batch_id = read_long(infile) - process(df_ref_id, int(batch_id)) # TODO(SPARK-44463): Propagate error to the user. - write_int(0, outfile) + event = utf8_deserializer.loads(infile) + event_type = read_int(infile) + process(event, int(event_type)) # TODO(SPARK-44463): Propagate error to the user. outfile.flush() if __name__ == "__main__": - print("Starting streaming worker") - # Read information about how to connect back to the JVM from the environment. 
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 198af0c9cbe..225ad6d45af 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -62,6 +62,21 @@ class StreamingQueryListener(ABC): >>> spark.streams.addListener(MyListener()) """ + def _set_spark_session( + self, spark: "SparkSession" # type: ignore[name-defined] # noqa: F821 + ) -> None: + self._sparkSession = spark + + @property + def spark(self) -> Optional["SparkSession"]: # type: ignore[name-defined] # noqa: F821 + if hasattr(self, "_sparkSession"): + return self._sparkSession + else: + return None + + def _init_listener_id(self) -> None: + self._id = str(uuid.uuid4()) + @abstractmethod def onQueryStarted(self, event: "QueryStartedEvent") -> None: """ @@ -463,8 +478,8 @@ class StreamingQueryProgress: timestamp=j["timestamp"], batchId=j["batchId"], batchDuration=j["batchDuration"], - durationMs=dict(j["durationMs"]), - eventTime=dict(j["eventTime"]), + durationMs=dict(j["durationMs"]) if "durationMs" in j else {}, + eventTime=dict(j["eventTime"]) if "eventTime" in j else {}, stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], sources=[SourceProgress.fromJson(s) for s in j["sources"]], sink=SinkProgress.fromJson(j["sink"]), @@ -474,7 +489,9 @@ class StreamingQueryProgress: observedMetrics={ k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows for k, row_dict in j["observedMetrics"].items() - }, + } + if "observedMetrics" in j + else {}, ) @property @@ -696,7 +713,7 @@ class StateOperatorProgress: numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"], numShufflePartitions=j["numShufflePartitions"], numStateStoreInstances=j["numStateStoreInstances"], - customMetrics=dict(j["customMetrics"]), + customMetrics=dict(j["customMetrics"]) if "customMetrics" in j 
else {}, ) @property @@ -831,7 +848,7 @@ class SourceProgress: numInputRows=j["numInputRows"], inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], - metrics=dict(j["metrics"]), + metrics=dict(j["metrics"]) if "metrics" in j else {}, ) @property @@ -951,7 +968,7 @@ class SinkProgress: jdict=j, description=j["description"], numOutputRows=j["numOutputRows"], - metrics=j["metrics"], + metrics=dict(j["metrics"]) if "metrics" in j else {}, ) @property diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 443e7dbee39..db104e30755 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -618,12 +618,24 @@ class StreamingQueryManager: .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- listener : :class:`StreamingQueryListener` A :class:`StreamingQueryListener` to receive up-calls for life cycle events of :class:`~pyspark.sql.streaming.StreamingQuery`. + Notes + ----- + This function behaves differently in Spark Connect mode. + In Connect, the provided functions doesn't have access to variables defined outside of it. + Also in Connect, you need to use `self.spark` to access spark session. + Using `spark` would throw an exception. + In short, if you want to use spark session inside the listener, + please use `self.spark` in Connect mode, and use `spark` otherwise. + Examples -------- >>> from pyspark.sql.streaming import StreamingQueryListener diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py new file mode 100644 index 00000000000..547462d4da6 --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import time + +from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin +from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent +from pyspark.sql.types import StructType, StructField, StringType +from pyspark.testing.connectutils import ReusedConnectTestCase + + +def get_start_event_schema(): + return StructType( + [ + StructField("id", StringType(), True), + StructField("runId", StringType(), True), + StructField("name", StringType(), True), + StructField("timestamp", StringType(), True), + ] + ) + + +class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + df = self.spark.createDataFrame( + data=[(str(event.id), str(event.runId), event.name, event.timestamp)], + schema=get_start_event_schema(), + ) + df.write.saveAsTable("listener_start_events") + + def onQueryProgress(self, event): + pass + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + pass + + +class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): + def test_listener_events(self): + test_listener = TestListener() + + try: + self.spark.streams.addListener(test_listener) + + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + q = 
df.writeStream.format("noop").queryName("test").start() + + self.assertTrue(q.isActive) + time.sleep(10) + q.stop() + + start_event = QueryStartedEvent.fromJson( + self.spark.read.table("listener_start_events").collect()[0].asDict() + ) + + self.check_start_event(start_event) + + finally: + self.spark.streams.removeListener(test_listener) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_listener import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 2bd6d2c6668..cbbdc2955e5 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -33,119 +33,7 @@ from pyspark.sql.streaming.listener import ( from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingListenerTests(ReusedSQLTestCase): - def test_number_of_public_methods(self): - msg = ( - "New field or method was detected in JVM side. If you added a new public " - "field or method, implement that in the corresponding Python class too." - "Otherwise, fix the number on the assert here." 
- ) - - def get_number_of_public_methods(clz): - return len( - self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName( - clz, True, False - ).getMethods() - ) - - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" - ), - 15, - msg, - ) - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" - ), - 12, - msg, - ) - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" - ), - 15, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"), - 38, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"), - 27, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 21, msg - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 19, msg - ) - - def test_listener_events(self): - start_event = None - progress_event = None - terminated_event = None - - class TestListener(StreamingQueryListener): - def onQueryStarted(self, event): - nonlocal start_event - start_event = event - - def onQueryProgress(self, event): - nonlocal progress_event - progress_event = event - - def onQueryIdle(self, event): - pass - - def onQueryTerminated(self, event): - nonlocal terminated_event - terminated_event = event - - test_listener = TestListener() - - try: - self.spark.streams.addListener(test_listener) - - df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - - # check successful stateful query - df_stateful = df.groupBy().count() # make query stateful - q = ( - df_stateful.writeStream.format("noop") - .queryName("test") - .outputMode("complete") - .start() - ) - self.assertTrue(q.isActive) - time.sleep(10) 
- q.stop() - - # Make sure all events are empty - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - - self.check_start_event(start_event) - self.check_progress_event(progress_event) - self.check_terminated_event(terminated_event) - - # Check query terminated with exception - from pyspark.sql.functions import col, udf - - bad_udf = udf(lambda x: 1 / 0) - q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() - time.sleep(5) - q.stop() - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - self.check_terminated_event(terminated_event, "ZeroDivisionError") - - finally: - self.spark.streams.removeListener(test_listener) - +class StreamingListenerTestsMixin: def check_start_event(self, event): """Check QueryStartedEvent""" self.assertTrue(isinstance(event, QueryStartedEvent)) @@ -304,6 +192,120 @@ class StreamingListenerTests(ReusedSQLTestCase): self.assertTrue(isinstance(progress.numOutputRows, int)) self.assertTrue(isinstance(progress.metrics, dict)) + +class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase): + def test_number_of_public_methods(self): + msg = ( + "New field or method was detected in JVM side. If you added a new public " + "field or method, implement that in the corresponding Python class too." + "Otherwise, fix the number on the assert here." 
+ ) + + def get_number_of_public_methods(clz): + return len( + self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName( + clz, True, False + ).getMethods() + ) + + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" + ), + 15, + msg, + ) + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" + ), + 12, + msg, + ) + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" + ), + 15, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"), + 38, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"), + 27, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 21, msg + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 19, msg + ) + + def test_listener_events(self): + start_event = None + progress_event = None + terminated_event = None + + class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + nonlocal start_event + start_event = event + + def onQueryProgress(self, event): + nonlocal progress_event + progress_event = event + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + nonlocal terminated_event + terminated_event = event + + test_listener = TestListener() + + try: + self.spark.streams.addListener(test_listener) + + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + + # check successful stateful query + df_stateful = df.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) + self.assertTrue(q.isActive) + time.sleep(10) 
+ q.stop() + + # Make sure all events are empty + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + + self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) + + # Check query terminated with exception + from pyspark.sql.functions import col, udf + + bad_udf = udf(lambda x: 1 / 0) + q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() + time.sleep(5) + q.stop() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + self.check_terminated_event(terminated_event, "ZeroDivisionError") + + finally: + self.spark.streams.removeListener(test_listener) + def test_remove_listener(self): # SPARK-38804: Test StreamingQueryManager.removeListener class TestListener(StreamingQueryListener): diff --git a/python/setup.py b/python/setup.py index 4d14dfd3cb9..b8e4c9a40e0 100755 --- a/python/setup.py +++ b/python/setup.py @@ -254,6 +254,7 @@ try: "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.streaming", + "pyspark.sql.connect.streaming.worker", "pyspark.bin", "pyspark.sbin", "pyspark.jars", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org