This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 808fd88bf4f [SPARK-42941][SS][CONNECT][3.5] Python StreamingQueryListener
808fd88bf4f is described below

commit 808fd88bf4fb96d3277641897f2a8fffdeb77f73
Author: bogao007 <bo....@databricks.com>
AuthorDate: Wed Aug 2 09:02:25 2023 +0900

    [SPARK-42941][SS][CONNECT][3.5] Python StreamingQueryListener
    
    ### What changes were proposed in this pull request?
    
    Implement the Python streaming query listener and the addListener and removeListener methods.
    A follow-up, [SPARK-44516](https://issues.apache.org/jira/browse/SPARK-44516), is filed to
    actually terminate the query listener process when removeListener is called;
    [SPARK-44516](https://issues.apache.org/jira/browse/SPARK-44516) depends on
    [SPARK-44433](https://issues.apache.org/jira/browse/SPARK-44433).
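
    For illustration, a minimal sketch of what the client now does in addListener (the
    `build_add_listener_command` helper below is hypothetical; the proto fields are the ones added
    in this PR): the listener is pickled into a PythonUDF payload and sent together with a
    client-generated id, which removeListener later reuses.

    ```
    # Sketch of the client-side flow, assuming the Connect proto bindings added in this PR.
    import sys
    import uuid

    import pyspark.sql.connect.proto as pb2
    from pyspark.serializers import CloudPickleSerializer

    def build_add_listener_command(listener):
        # Generate an id on the client so removeListener can later refer to the
        # same listener without re-serializing it.
        listener._id = str(uuid.uuid4())
        cmd = pb2.StreamingQueryManagerCommand()
        udf = pb2.PythonUDF()
        udf.command = CloudPickleSerializer().dumps(listener)
        udf.python_ver = "%d.%d" % sys.version_info[:2]
        cmd.add_listener.python_listener_payload.CopyFrom(udf)
        cmd.add_listener.id = listener._id
        return cmd
    ```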
    
    ### Why are the changes needed?
    
    Structured Streaming development for Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now use StreamingQueryListener with Spark Connect.
    
    ### How was this patch tested?
    
    Manual testing and a new unit test.
    
    **addListener:**
    ```
    # Client side:
    >>> from pyspark.sql.streaming.listener import StreamingQueryListener; from pyspark.sql.streaming.listener import (QueryStartedEvent, QueryProgressEvent, QueryTerminatedEvent, QueryIdleEvent)
    >>> class MyListener(StreamingQueryListener):
    ...     def onQueryStarted(self, event: QueryStartedEvent) -> None: print("hi, event query id is: " + str(event.id)); df=self.spark.createDataFrame(["10","11","13"], "string").toDF("age"); df.write.saveAsTable("tbllistener1")
    ...     def onQueryProgress(self, event: QueryProgressEvent) -> None: pass
    ...     def onQueryIdle(self, event: QueryIdleEvent) -> None: pass
    ...     def onQueryTerminated(self, event: QueryTerminatedEvent) -> None: pass
    ...
    >>> spark.streams.addListener(MyListener())
    >>> q = spark.readStream.format("rate").load().writeStream.format("console").start()
    >>> q.stop()
    >>> spark.read.table("tbllistener1").collect()
    [Row(age='13'), Row(age='10'), Row(age='11')]
    
    # Server side:
    ##### event_type received from python process is 0
    hi, event query id is: dd7ba1c4-6c8f-4369-9c3c-5dede22b8a2f
    ```
    
    **removeListener:**
    
    ```
    # Client side:
    >>> listener = MyListener(); spark.streams.addListener(listener)
    >>> spark.streams.removeListener(listener)
    
    # Server side:
    # Nothing is actually printed on the server side: the listener is removed from the
    # server-side StreamingQueryManager and from the cache in sessionHolder, but the
    # Python process still keeps running. Follow-up SPARK-44516 is filed to stop this process.
    ```
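
    For context, on the server side each added listener gets its own Python worker process; the JVM
    writes each event as the event JSON (a UTF string) followed by an int event-type code
    (0 = started, 1 = progress, 2 = idle, 3 = terminated). A simplified sketch of the dispatch done
    by listener_worker.py (the `dispatch` helper name is illustrative):

    ```
    # Simplified dispatch of one event; the real worker reads these values from
    # the socket connected to the JVM in a loop.
    import json

    from pyspark.sql.streaming.listener import (
        QueryStartedEvent,
        QueryProgressEvent,
        QueryIdleEvent,
        QueryTerminatedEvent,
    )

    def dispatch(listener, event_json, event_type):
        event = json.loads(event_json)
        if event_type == 0:
            listener.onQueryStarted(QueryStartedEvent.fromJson(event))
        elif event_type == 1:
            listener.onQueryProgress(QueryProgressEvent.fromJson(event))
        elif event_type == 2:
            listener.onQueryIdle(QueryIdleEvent.fromJson(event))
        elif event_type == 3:
            listener.onQueryTerminated(QueryTerminatedEvent.fromJson(event))
    ```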
    
    Closes #42250 from bogao007/3.5-branch-sync.
    
    Lead-authored-by: bogao007 <bo....@databricks.com>
    Co-authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/streaming/StreamingQueryManager.scala      |   6 +-
 .../sql/streaming/ClientStreamingQuerySuite.scala  |   2 +-
 .../src/main/protobuf/spark/connect/commands.proto |   2 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  30 +--
 .../planner/StreamingForeachBatchHelper.scala      |   9 +-
 .../planner/StreamingQueryListenerHelper.scala     |  69 +++++++
 .../spark/api/python/PythonWorkerFactory.scala     |   6 +-
 .../spark/api/python/StreamingPythonRunner.scala   |   5 +-
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/proto/commands_pb2.py   |  44 ++--
 python/pyspark/sql/connect/proto/commands_pb2.pyi  |  37 +++-
 python/pyspark/sql/connect/streaming/query.py      |  31 +--
 .../sql/connect/streaming/worker/__init__.py       |  18 ++
 .../streaming/worker/foreachBatch_worker.py}       |  18 +-
 .../connect/streaming/worker/listener_worker.py}   |  53 +++--
 python/pyspark/sql/streaming/listener.py           |  29 ++-
 python/pyspark/sql/streaming/query.py              |  12 ++
 .../connect/streaming/test_parity_listener.py      |  90 ++++++++
 .../sql/tests/streaming/test_streaming_listener.py | 228 +++++++++++----------
 python/setup.py                                    |   1 +
 20 files changed, 484 insertions(+), 207 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 91744460440..d16638e5945 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -156,7 +156,8 @@ class StreamingQueryManager private[sql] (sparkSession: 
SparkSession) extends Lo
     executeManagerCmd(
       _.getAddListenerBuilder
         .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
-          .serialize(StreamingListenerPacket(id, listener)))))
+          .serialize(StreamingListenerPacket(id, listener))))
+        .setId(id))
   }
 
   /**
@@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: 
SparkSession) extends Lo
     val id = getIdByListener(listener)
     executeManagerCmd(
       _.getRemoveListenerBuilder
-        .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
-          .serialize(StreamingListenerPacket(id, listener)))))
+        .setId(id))
     removeCachedListener(id)
   }
 
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index bc778f02480..f9e6e686495 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -294,7 +294,7 @@ class ClientStreamingQuerySuite extends QueryTest with 
SQLHelper with Logging {
       spark.sql("DROP TABLE IF EXISTS my_listener_table")
     }
 
-    // List listeners after adding a new listener, length should be 2.
+    // List listeners after adding a new listener, length should be 1.
     val listeners = spark.streams.listListeners()
     assert(listeners.length == 1)
 
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 4c4233124d8..49b25f099bf 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -365,6 +365,8 @@ message StreamingQueryManagerCommand {
 
   message StreamingQueryListenerCommand {
     bytes listener_payload = 1;
+    optional PythonUDF python_listener_payload = 2;
+    string id = 3;
   }
 }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 61e5d9de914..f9a1e44516e 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2831,7 +2831,7 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
           StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, 
sessionHolder)
 
         case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
-          throw InvalidPlanInput("Unexpected") // Unreachable
+          throw InvalidPlanInput("Unexpected foreachBatch function") // 
Unreachable
       }
 
       writer.foreachBatch(foreachBatchFn)
@@ -3074,23 +3074,27 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         respBuilder.setResetTerminated(true)
 
       case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER =>
-        val listenerPacket = Utils
-          .deserialize[StreamingListenerPacket](
-            command.getAddListener.getListenerPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-        val listener: StreamingQueryListener = listenerPacket.listener
-          .asInstanceOf[StreamingQueryListener]
-        val id: String = listenerPacket.id
+        val listener = if (command.getAddListener.hasPythonListenerPayload) {
+          new PythonStreamingQueryListener(
+            
transformPythonFunction(command.getAddListener.getPythonListenerPayload),
+            sessionHolder,
+            pythonExec)
+        } else {
+          val listenerPacket = Utils
+            .deserialize[StreamingListenerPacket](
+              command.getAddListener.getListenerPayload.toByteArray,
+              Utils.getContextOrSparkClassLoader)
+
+          listenerPacket.listener.asInstanceOf[StreamingQueryListener]
+        }
+
+        val id = command.getAddListener.getId
         sessionHolder.cacheListenerById(id, listener)
         session.streams.addListener(listener)
         respBuilder.setAddListener(true)
 
       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
-        val listenerId = Utils
-          .deserialize[StreamingListenerPacket](
-            command.getRemoveListener.getListenerPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-          .id
+        val listenerId = command.getRemoveListener.getId
         val listener: StreamingQueryListener = 
sessionHolder.getListenerOrThrow(listenerId)
         session.streams.removeListener(listener)
         sessionHolder.removeCachedListener(listenerId)
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 9770ac4cee5..3b9ae483cf1 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner
 
 import java.util.UUID
 
-import org.apache.spark.api.python.PythonRDD
-import org.apache.spark.api.python.SimplePythonFunction
-import org.apache.spark.api.python.StreamingPythonRunner
+import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, 
StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -90,7 +88,10 @@ object StreamingForeachBatchHelper extends Logging {
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
     val runner = StreamingPythonRunner(pythonFn, connectUrl)
-    val (dataOut, dataIn) = runner.init(sessionHolder.sessionId)
+    val (dataOut, dataIn) =
+      runner.init(
+        sessionHolder.sessionId,
+        "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
new file mode 100644
index 00000000000..d915bc93496
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, 
StreamingPythonRunner}
+import org.apache.spark.sql.connect.service.{SessionHolder, 
SparkConnectService}
+import org.apache.spark.sql.streaming.StreamingQueryListener
+
+/**
+ * A helper class for handling StreamingQueryListener-related functionality in Spark Connect. Each
+ * instance of this class starts a Python process that contains the Python handling logic. When a
+ * new event is received, it is serialized to JSON and passed to the Python process.
+ */
+class PythonStreamingQueryListener(
+    listener: SimplePythonFunction,
+    sessionHolder: SessionHolder,
+    pythonExec: String)
+    extends StreamingQueryListener {
+
+  val port = SparkConnectService.localPort
+  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  val runner = StreamingPythonRunner(listener, connectUrl)
+
+  val (dataOut, _) =
+    runner.init(sessionHolder.sessionId, 
"pyspark.sql.connect.streaming.worker.listener_worker")
+
+  override def onQueryStarted(event: 
StreamingQueryListener.QueryStartedEvent): Unit = {
+    PythonRDD.writeUTF(event.json, dataOut)
+    dataOut.writeInt(0)
+    dataOut.flush()
+  }
+
+  override def onQueryProgress(event: 
StreamingQueryListener.QueryProgressEvent): Unit = {
+    PythonRDD.writeUTF(event.json, dataOut)
+    dataOut.writeInt(1)
+    dataOut.flush()
+  }
+
+  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit 
= {
+    PythonRDD.writeUTF(event.json, dataOut)
+    dataOut.writeInt(2)
+    dataOut.flush()
+  }
+
+  override def onQueryTerminated(event: 
StreamingQueryListener.QueryTerminatedEvent): Unit = {
+    PythonRDD.writeUTF(event.json, dataOut)
+    dataOut.writeInt(3)
+    dataOut.flush()
+  }
+
+  // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes.
+  // Similar to foreachBatch when we need to exit the process when the query 
ends.
+  // In listener semantics, we need to exit the process when removeListener is 
called.
+}
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 6039f8d232b..d5d97b74d11 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -110,9 +110,9 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
     }
   }
 
-  /** Creates a Python worker with `pyspark.streaming_worker` module. */
-  def createStreamingWorker(): (Socket, Option[Int]) = {
-    createSimpleWorker("pyspark.streaming_worker")
+  /** Creates a Python worker with streaming worker module. */
+  def createStreamingWorker(streamingWorkerModule: String): (Socket, 
Option[Int]) = {
+    createSimpleWorker(streamingWorkerModule)
   }
 
   /**
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index faf462a1990..c02871ee145 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -48,9 +48,8 @@ private[spark] class StreamingPythonRunner(func: 
PythonFunction, connectUrl: Str
    * Initializes the Python worker for streaming functions. Sets up Spark 
Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String): (DataOutputStream, DataInputStream) = {
+  def init(sessionId: String, workerModule: String): (DataOutputStream, 
DataInputStream) = {
     logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: 
$pythonExec")
-
     val env = SparkEnv.get
 
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
@@ -62,7 +61,7 @@ private[spark] class StreamingPythonRunner(func: 
PythonFunction, connectUrl: Str
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
     val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, 
envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker()
+    val (worker: Socket, _) = 
pythonWorkerFactory.createStreamingWorker(workerModule)
 
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 79c3f8f26b1..4005c317e62 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -863,6 +863,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
+        "pyspark.sql.tests.connect.streaming.test_parity_listener",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
         "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py 
b/python/pyspark/sql/connect/proto/commands_pb2.py
index 6f03d80e669..90911e382bf 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as 
spark_dot_connect_dot_rel
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
 [...]
+    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -115,27 +115,27 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 
6330
     _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6386
     _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6404
-    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7101
+    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7233
     _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start 
= 6935
     _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 
7014
-    
_STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 
7016
-    
_STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 
7090
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7104
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8180
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7712
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7839
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 
7841
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end 
= 7956
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start 
= 7958
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 
8017
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start
 = 8019
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end
 = 8094
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start
 = 8096
-    
_STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end
 = 8165
-    _GETRESOURCESCOMMAND._serialized_start = 8182
-    _GETRESOURCESCOMMAND._serialized_end = 8203
-    _GETRESOURCESCOMMANDRESULT._serialized_start = 8206
-    _GETRESOURCESCOMMANDRESULT._serialized_end = 8418
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8322
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8418
+    
_STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 
7017
+    
_STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 
7222
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7236
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8312
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7844
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7971
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 
7973
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end 
= 8088
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start 
= 8090
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 
8149
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start
 = 8151
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end
 = 8226
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start
 = 8228
+    
_STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end
 = 8297
+    _GETRESOURCESCOMMAND._serialized_start = 8314
+    _GETRESOURCESCOMMAND._serialized_end = 8335
+    _GETRESOURCESCOMMANDRESULT._serialized_start = 8338
+    _GETRESOURCESCOMMANDRESULT._serialized_end = 8550
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8454
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8550
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi 
b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index be423ea036e..f3dca7ab4bb 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -1372,15 +1372,50 @@ class 
StreamingQueryManagerCommand(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int
+        PYTHON_LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int
+        ID_FIELD_NUMBER: builtins.int
         listener_payload: builtins.bytes
+        @property
+        def python_listener_payload(
+            self,
+        ) -> pyspark.sql.connect.proto.expressions_pb2.PythonUDF: ...
+        id: builtins.str
         def __init__(
             self,
             *,
             listener_payload: builtins.bytes = ...,
+            python_listener_payload: 
pyspark.sql.connect.proto.expressions_pb2.PythonUDF
+            | None = ...,
+            id: builtins.str = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_python_listener_payload",
+                b"_python_listener_payload",
+                "python_listener_payload",
+                b"python_listener_payload",
+            ],
+        ) -> builtins.bool: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["listener_payload", 
b"listener_payload"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_python_listener_payload",
+                b"_python_listener_payload",
+                "id",
+                b"id",
+                "listener_payload",
+                b"listener_payload",
+                "python_listener_payload",
+                b"python_listener_payload",
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self,
+            oneof_group: typing_extensions.Literal[
+                "_python_listener_payload", b"_python_listener_payload"
+            ],
+        ) -> typing_extensions.Literal["python_listener_payload"] | None: ...
 
     ACTIVE_FIELD_NUMBER: builtins.int
     GET_QUERY_FIELD_NUMBER: builtins.int
diff --git a/python/pyspark/sql/connect/streaming/query.py 
b/python/pyspark/sql/connect/streaming/query.py
index e5aa881c990..59e98e7bc30 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -21,6 +21,9 @@ from typing import TYPE_CHECKING, Any, cast, Dict, List, 
Optional
 
 from pyspark.errors import StreamingQueryException, PySparkValueError
 import pyspark.sql.connect.proto as pb2
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.connect import proto
+from pyspark.sql.streaming import StreamingQueryListener
 from pyspark.sql.streaming.query import (
     StreamingQuery as PySparkStreamingQuery,
     StreamingQueryManager as PySparkStreamingQueryManager,
@@ -226,25 +229,27 @@ class StreamingQueryManager:
         cmd = pb2.StreamingQueryManagerCommand()
         cmd.reset_terminated = True
         self._execute_streaming_query_manager_cmd(cmd)
-        return None
 
     resetTerminated.__doc__ = 
PySparkStreamingQueryManager.resetTerminated.__doc__
 
-    def addListener(self, listener: Any) -> None:
-        # TODO(SPARK-42941): Change listener type to Connect 
StreamingQueryListener
-        # and implement below
-        raise NotImplementedError("addListener() is not implemented.")
+    def addListener(self, listener: StreamingQueryListener) -> None:
+        listener._init_listener_id()
+        cmd = pb2.StreamingQueryManagerCommand()
+        expr = proto.PythonUDF()
+        expr.command = CloudPickleSerializer().dumps(listener)
+        expr.python_ver = "%d.%d" % sys.version_info[:2]
+        cmd.add_listener.python_listener_payload.CopyFrom(expr)
+        cmd.add_listener.id = listener._id
+        self._execute_streaming_query_manager_cmd(cmd)
 
-    # TODO(SPARK-42941): uncomment below
-    # addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__
+    addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__
 
-    def removeListener(self, listener: Any) -> None:
-        # TODO(SPARK-42941): Change listener type to Connect 
StreamingQueryListener
-        # and implement below
-        raise NotImplementedError("removeListener() is not implemented.")
+    def removeListener(self, listener: StreamingQueryListener) -> None:
+        cmd = pb2.StreamingQueryManagerCommand()
+        cmd.remove_listener.id = listener._id
+        self._execute_streaming_query_manager_cmd(cmd)
 
-    # TODO(SPARK-42941): uncomment below
-    # removeListener.__doc__ = 
PySparkStreamingQueryManager.removeListener.__doc__
+    removeListener.__doc__ = 
PySparkStreamingQueryManager.removeListener.__doc__
 
     def _execute_streaming_query_manager_cmd(
         self, cmd: pb2.StreamingQueryManagerCommand
diff --git a/python/pyspark/sql/connect/streaming/worker/__init__.py 
b/python/pyspark/sql/connect/streaming/worker/__init__.py
new file mode 100644
index 00000000000..a5c98019891
--- /dev/null
+++ b/python/pyspark/sql/connect/streaming/worker/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Streaming Server-side Worker"""
diff --git a/python/pyspark/streaming_worker.py 
b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
similarity index 86%
copy from python/pyspark/streaming_worker.py
copy to python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index a818880a984..054788539f2 100644
--- a/python/pyspark/streaming_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -16,7 +16,8 @@
 #
 
 """
-A worker for streaming foreachBatch and query listener in Spark Connect.
+A worker for streaming foreachBatch in Spark Connect.
+Usually this is run on the driver side of the Spark Connect Server.
 """
 import os
 
@@ -29,20 +30,23 @@ from pyspark.serializers import (
 )
 from pyspark import worker
 from pyspark.sql import SparkSession
+from typing import IO
 
 pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
-def main(infile, outfile):  # type: ignore[no-untyped-def]
-    log_name = "Streaming ForeachBatch worker"
+def main(infile: IO, outfile: IO) -> None:
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
     session_id = utf8_deserializer.loads(infile)
 
-    print(f"{log_name} is starting with url {connect_url} and sessionId 
{session_id}.")
+    print(
+        "Streaming foreachBatch worker is starting with "
+        f"url {connect_url} and sessionId {session_id}."
+    )
 
     spark_connect_session = 
SparkSession.builder.remote(connect_url).getOrCreate()
-    spark_connect_session._client._session_id = session_id
+    spark_connect_session._client._session_id = session_id  # type: 
ignore[attr-defined]
 
     # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
@@ -52,6 +56,8 @@ def main(infile, outfile):  # type: ignore[no-untyped-def]
 
     outfile.flush()
 
+    log_name = "Streaming ForeachBatch worker"
+
     def process(df_id, batch_id):  # type: ignore[no-untyped-def]
         print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
         batch_df = spark_connect_session._create_remote_dataframe(df_id)
@@ -67,8 +73,6 @@ def main(infile, outfile):  # type: ignore[no-untyped-def]
 
 
 if __name__ == "__main__":
-    print("Starting streaming worker")
-
     # Read information about how to connect back to the JVM from the 
environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
diff --git a/python/pyspark/streaming_worker.py 
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
similarity index 55%
rename from python/pyspark/streaming_worker.py
rename to python/pyspark/sql/connect/streaming/worker/listener_worker.py
index a818880a984..8eb310461b6 100644
--- a/python/pyspark/streaming_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -16,59 +16,76 @@
 #
 
 """
-A worker for streaming foreachBatch and query listener in Spark Connect.
+A worker for streaming query listener in Spark Connect.
+Usually this is run on the driver side of the Spark Connect Server.
 """
 import os
+import json
 
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
+    read_int,
     write_int,
-    read_long,
     UTF8Deserializer,
     CPickleSerializer,
 )
 from pyspark import worker
 from pyspark.sql import SparkSession
+from typing import IO
+
+from pyspark.sql.streaming.listener import (
+    QueryStartedEvent,
+    QueryProgressEvent,
+    QueryTerminatedEvent,
+    QueryIdleEvent,
+)
 
 pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
-def main(infile, outfile):  # type: ignore[no-untyped-def]
-    log_name = "Streaming ForeachBatch worker"
+def main(infile: IO, outfile: IO) -> None:
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
     session_id = utf8_deserializer.loads(infile)
 
-    print(f"{log_name} is starting with url {connect_url} and sessionId 
{session_id}.")
+    print(
+        "Streaming query listener worker is starting with "
+        f"url {connect_url} and sessionId {session_id}."
+    )
 
     spark_connect_session = 
SparkSession.builder.remote(connect_url).getOrCreate()
-    spark_connect_session._client._session_id = session_id
+    spark_connect_session._client._session_id = session_id  # type: 
ignore[attr-defined]
 
     # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
-    func = worker.read_command(pickle_ser, infile)
+    listener = worker.read_command(pickle_ser, infile)
     write_int(0, outfile)  # Indicate successful initialization
 
     outfile.flush()
 
-    def process(df_id, batch_id):  # type: ignore[no-untyped-def]
-        print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
-        batch_df = spark_connect_session._create_remote_dataframe(df_id)
-        func(batch_df, batch_id)
-        print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
+    listener._set_spark_session(spark_connect_session)
+    assert listener.spark == spark_connect_session
+
+    def process(listener_event_str, listener_event_type):  # type: 
ignore[no-untyped-def]
+        listener_event = json.loads(listener_event_str)
+        if listener_event_type == 0:
+            listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event))
+        elif listener_event_type == 1:
+            
listener.onQueryProgress(QueryProgressEvent.fromJson(listener_event))
+        elif listener_event_type == 2:
+            listener.onQueryIdle(QueryIdleEvent.fromJson(listener_event))
+        elif listener_event_type == 3:
+            
listener.onQueryTerminated(QueryTerminatedEvent.fromJson(listener_event))
 
     while True:
-        df_ref_id = utf8_deserializer.loads(infile)
-        batch_id = read_long(infile)
-        process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate 
error to the user.
-        write_int(0, outfile)
+        event = utf8_deserializer.loads(infile)
+        event_type = read_int(infile)
+        process(event, int(event_type))  # TODO(SPARK-44463): Propagate error 
to the user.
         outfile.flush()
 
 
 if __name__ == "__main__":
-    print("Starting streaming worker")
-
     # Read information about how to connect back to the JVM from the 
environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
diff --git a/python/pyspark/sql/streaming/listener.py 
b/python/pyspark/sql/streaming/listener.py
index 198af0c9cbe..225ad6d45af 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -62,6 +62,21 @@ class StreamingQueryListener(ABC):
     >>> spark.streams.addListener(MyListener())
     """
 
+    def _set_spark_session(
+        self, spark: "SparkSession"  # type: ignore[name-defined] # noqa: F821
+    ) -> None:
+        self._sparkSession = spark
+
+    @property
+    def spark(self) -> Optional["SparkSession"]:  # type: ignore[name-defined] 
# noqa: F821
+        if hasattr(self, "_sparkSession"):
+            return self._sparkSession
+        else:
+            return None
+
+    def _init_listener_id(self) -> None:
+        self._id = str(uuid.uuid4())
+
     @abstractmethod
     def onQueryStarted(self, event: "QueryStartedEvent") -> None:
         """
@@ -463,8 +478,8 @@ class StreamingQueryProgress:
             timestamp=j["timestamp"],
             batchId=j["batchId"],
             batchDuration=j["batchDuration"],
-            durationMs=dict(j["durationMs"]),
-            eventTime=dict(j["eventTime"]),
+            durationMs=dict(j["durationMs"]) if "durationMs" in j else {},
+            eventTime=dict(j["eventTime"]) if "eventTime" in j else {},
             stateOperators=[StateOperatorProgress.fromJson(s) for s in 
j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
@@ -474,7 +489,9 @@ class StreamingQueryProgress:
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no 
nested rows
                 for k, row_dict in j["observedMetrics"].items()
-            },
+            }
+            if "observedMetrics" in j
+            else {},
         )
 
     @property
@@ -696,7 +713,7 @@ class StateOperatorProgress:
             numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"],
             numShufflePartitions=j["numShufflePartitions"],
             numStateStoreInstances=j["numStateStoreInstances"],
-            customMetrics=dict(j["customMetrics"]),
+            customMetrics=dict(j["customMetrics"]) if "customMetrics" in j 
else {},
         )
 
     @property
@@ -831,7 +848,7 @@ class SourceProgress:
             numInputRows=j["numInputRows"],
             inputRowsPerSecond=j["inputRowsPerSecond"],
             processedRowsPerSecond=j["processedRowsPerSecond"],
-            metrics=dict(j["metrics"]),
+            metrics=dict(j["metrics"]) if "metrics" in j else {},
         )
 
     @property
@@ -951,7 +968,7 @@ class SinkProgress:
             jdict=j,
             description=j["description"],
             numOutputRows=j["numOutputRows"],
-            metrics=j["metrics"],
+            metrics=dict(j["metrics"]) if "metrics" in j else {},
         )
 
     @property
diff --git a/python/pyspark/sql/streaming/query.py 
b/python/pyspark/sql/streaming/query.py
index 443e7dbee39..db104e30755 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -618,12 +618,24 @@ class StreamingQueryManager:
 
         .. versionadded:: 3.4.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Parameters
         ----------
         listener : :class:`StreamingQueryListener`
             A :class:`StreamingQueryListener` to receive up-calls for life 
cycle events of
             :class:`~pyspark.sql.streaming.StreamingQuery`.
 
+        Notes
+        -----
+        This function behaves differently in Spark Connect mode.
+        In Connect, the provided function doesn't have access to variables defined outside of it.
+        Also in Connect, you need to use `self.spark` to access the Spark session.
+        Using `spark` would throw an exception.
+        In short, if you want to use the Spark session inside the listener,
+        please use `self.spark` in Connect mode, and use `spark` otherwise.
+
         Examples
         --------
         >>> from pyspark.sql.streaming import StreamingQueryListener
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py 
b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
new file mode 100644
index 00000000000..547462d4da6
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import time
+
+from pyspark.sql.tests.streaming.test_streaming_listener import 
StreamingListenerTestsMixin
+from pyspark.sql.streaming.listener import StreamingQueryListener, 
QueryStartedEvent
+from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+def get_start_event_schema():
+    return StructType(
+        [
+            StructField("id", StringType(), True),
+            StructField("runId", StringType(), True),
+            StructField("name", StringType(), True),
+            StructField("timestamp", StringType(), True),
+        ]
+    )
+
+
+class TestListener(StreamingQueryListener):
+    def onQueryStarted(self, event):
+        df = self.spark.createDataFrame(
+            data=[(str(event.id), str(event.runId), event.name, 
event.timestamp)],
+            schema=get_start_event_schema(),
+        )
+        df.write.saveAsTable("listener_start_events")
+
+    def onQueryProgress(self, event):
+        pass
+
+    def onQueryIdle(self, event):
+        pass
+
+    def onQueryTerminated(self, event):
+        pass
+
+
+class StreamingListenerParityTests(StreamingListenerTestsMixin, 
ReusedConnectTestCase):
+    def test_listener_events(self):
+        test_listener = TestListener()
+
+        try:
+            self.spark.streams.addListener(test_listener)
+
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
+            q = df.writeStream.format("noop").queryName("test").start()
+
+            self.assertTrue(q.isActive)
+            time.sleep(10)
+            q.stop()
+
+            start_event = QueryStartedEvent.fromJson(
+                
self.spark.read.table("listener_start_events").collect()[0].asDict()
+            )
+
+            self.check_start_event(start_event)
+
+        finally:
+            self.spark.streams.removeListener(test_listener)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.streaming.test_parity_listener import *  # 
noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py 
b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 2bd6d2c6668..cbbdc2955e5 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -33,119 +33,7 @@ from pyspark.sql.streaming.listener import (
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class StreamingListenerTests(ReusedSQLTestCase):
-    def test_number_of_public_methods(self):
-        msg = (
-            "New field or method was detected in JVM side. If you added a new 
public "
-            "field or method, implement that in the corresponding Python class 
too."
-            "Otherwise, fix the number on the assert here."
-        )
-
-        def get_number_of_public_methods(clz):
-            return len(
-                
self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName(
-                    clz, True, False
-                ).getMethods()
-            )
-
-        self.assertEquals(
-            get_number_of_public_methods(
-                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"
-            ),
-            15,
-            msg,
-        )
-        self.assertEquals(
-            get_number_of_public_methods(
-                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"
-            ),
-            12,
-            msg,
-        )
-        self.assertEquals(
-            get_number_of_public_methods(
-                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"
-            ),
-            15,
-            msg,
-        )
-        self.assertEquals(
-            
get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"),
-            38,
-            msg,
-        )
-        self.assertEquals(
-            
get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"),
-            27,
-            msg,
-        )
-        self.assertEquals(
-            
get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 
21, msg
-        )
-        self.assertEquals(
-            
get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 
19, msg
-        )
-
-    def test_listener_events(self):
-        start_event = None
-        progress_event = None
-        terminated_event = None
-
-        class TestListener(StreamingQueryListener):
-            def onQueryStarted(self, event):
-                nonlocal start_event
-                start_event = event
-
-            def onQueryProgress(self, event):
-                nonlocal progress_event
-                progress_event = event
-
-            def onQueryIdle(self, event):
-                pass
-
-            def onQueryTerminated(self, event):
-                nonlocal terminated_event
-                terminated_event = event
-
-        test_listener = TestListener()
-
-        try:
-            self.spark.streams.addListener(test_listener)
-
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
-
-            # check successful stateful query
-            df_stateful = df.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
-            self.assertTrue(q.isActive)
-            time.sleep(10)
-            q.stop()
-
-            # Make sure all events are empty
-            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
-
-            self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
-
-            # Check query terminated with exception
-            from pyspark.sql.functions import col, udf
-
-            bad_udf = udf(lambda x: 1 / 0)
-            q = 
df.select(bad_udf(col("value"))).writeStream.format("noop").start()
-            time.sleep(5)
-            q.stop()
-            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
-            self.check_terminated_event(terminated_event, "ZeroDivisionError")
-
-        finally:
-            self.spark.streams.removeListener(test_listener)
-
+class StreamingListenerTestsMixin:
     def check_start_event(self, event):
         """Check QueryStartedEvent"""
         self.assertTrue(isinstance(event, QueryStartedEvent))
@@ -304,6 +192,120 @@ class StreamingListenerTests(ReusedSQLTestCase):
         self.assertTrue(isinstance(progress.numOutputRows, int))
         self.assertTrue(isinstance(progress.metrics, dict))
 
+
+class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
+    def test_number_of_public_methods(self):
+        msg = (
+            "New field or method was detected in JVM side. If you added a new 
public "
+            "field or method, implement that in the corresponding Python class 
too."
+            "Otherwise, fix the number on the assert here."
+        )
+
+        def get_number_of_public_methods(clz):
+            return len(
+                
self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName(
+                    clz, True, False
+                ).getMethods()
+            )
+
+        self.assertEquals(
+            get_number_of_public_methods(
+                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"
+            ),
+            15,
+            msg,
+        )
+        self.assertEquals(
+            get_number_of_public_methods(
+                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"
+            ),
+            12,
+            msg,
+        )
+        self.assertEquals(
+            get_number_of_public_methods(
+                
"org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"
+            ),
+            15,
+            msg,
+        )
+        self.assertEquals(
+            
get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"),
+            38,
+            msg,
+        )
+        self.assertEquals(
+            
get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"),
+            27,
+            msg,
+        )
+        self.assertEquals(
+            
get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 
21, msg
+        )
+        self.assertEquals(
+            
get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 
19, msg
+        )
+
+    def test_listener_events(self):
+        start_event = None
+        progress_event = None
+        terminated_event = None
+
+        class TestListener(StreamingQueryListener):
+            def onQueryStarted(self, event):
+                nonlocal start_event
+                start_event = event
+
+            def onQueryProgress(self, event):
+                nonlocal progress_event
+                progress_event = event
+
+            def onQueryIdle(self, event):
+                pass
+
+            def onQueryTerminated(self, event):
+                nonlocal terminated_event
+                terminated_event = event
+
+        test_listener = TestListener()
+
+        try:
+            self.spark.streams.addListener(test_listener)
+
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
+
+            # check successful stateful query
+            df_stateful = df.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
+            self.assertTrue(q.isActive)
+            time.sleep(10)
+            q.stop()
+
+            # Make sure all events are empty
+            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
+
+            self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
+
+            # Check query terminated with exception
+            from pyspark.sql.functions import col, udf
+
+            bad_udf = udf(lambda x: 1 / 0)
+            q = 
df.select(bad_udf(col("value"))).writeStream.format("noop").start()
+            time.sleep(5)
+            q.stop()
+            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
+            self.check_terminated_event(terminated_event, "ZeroDivisionError")
+
+        finally:
+            self.spark.streams.removeListener(test_listener)
+
     def test_remove_listener(self):
         # SPARK-38804: Test StreamingQueryManager.removeListener
         class TestListener(StreamingQueryListener):
diff --git a/python/setup.py b/python/setup.py
index 4d14dfd3cb9..b8e4c9a40e0 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -254,6 +254,7 @@ try:
             "pyspark.sql.protobuf",
             "pyspark.sql.streaming",
             "pyspark.streaming",
+            "pyspark.sql.connect.streaming.worker",
             "pyspark.bin",
             "pyspark.sbin",
             "pyspark.jars",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
