This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c36c580477 [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Terminate listener process with `removeListener` and improvements
5c36c580477 is described below

commit 5c36c58047724885864cb781f17038a6b9c94513
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Fri Aug 4 09:14:05 2023 +0900

    [SPARK-44433][PYTHON][CONNECT][SS][FOLLOWUP] Terminate listener process with `removeListener` and improvements
    
    ### What changes were proposed in this pull request?
    
    This is a followup to #42116. It addresses the following issues:
    
    1. When `removeListener` is called on a listener, its Python process is now also stopped; before this PR it was left running.
    2. When `removeListener` is called multiple times on the same listener, subsequent calls are a no-op in non-Connect mode. Before this PR, Connect mode threw an error instead, which didn't align with the existing behavior; this PR fixes that (see the usage sketch after the traceback below).
    3. Set the socket timeout to None (infinite) for `foreachBatch_worker` and `listener_worker`, because there can be a long gap between microbatches. Without this, the socket times out and can no longer process new data, as in the traceback below:
    
    ```
    scala> Streaming query listener worker is starting with url sc://localhost:15002/;user_id=wei.liu and sessionId 886191f0-2b64-4c44-b067-de511f04b42d.
    Traceback (most recent call last):
      File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 95, in <module>
      File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/sql/connect/streaming/worker/listener_worker.py", line 82, in main
      File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 557, in loads
      File "/home/wei.liu/oss-spark/python/lib/pyspark.zip/pyspark/serializers.py", line 594, in read_int
      File "/usr/lib/python3.9/socket.py", line 704, in readinto
        return self._sock.recv_into(b)
    socket.timeout: timed out
    ```
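    
    For illustration, a minimal client-side sketch of the behavior in items 1 and 2 above. It assumes an existing Spark Connect session `spark`; `MyListener` is a hypothetical `StreamingQueryListener` subclass, not code from this PR:
    
    ```python
    from pyspark.sql.streaming import StreamingQueryListener
    
    class MyListener(StreamingQueryListener):
        def onQueryStarted(self, event):
            print("query started:", event.id)
    
        def onQueryProgress(self, event):
            print("progress:", event.progress.batchId)
    
        def onQueryTerminated(self, event):
            print("query terminated:", event.id)
    
    listener = MyListener()
    # In Spark Connect, adding the listener starts a dedicated Python worker process on the server.
    spark.streams.addListener(listener)
    
    # Removing the listener now also terminates that worker process (item 1).
    spark.streams.removeListener(listener)
    
    # Removing it again is a no-op instead of raising an error (item 2).
    spark.streams.removeListener(listener)
    ```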
    
    ### Why are the changes needed?
    
    Necessary improvements: terminate the listener's Python process on `removeListener`, make repeated removals a no-op, and prevent worker socket timeouts.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Manual test + unit test
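    The new unit test is in `test_parity_listener.py` (see diff below); it can typically be run locally with `python/run-tests --testnames pyspark.sql.tests.connect.streaming.test_parity_listener` (command given for reference, not part of this change).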
    
    Closes #42283 from WweiL/SPARK-44433-listener-process-termination.
    
    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/streaming/StreamingQueryListener.scala     | 28 -----------------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 12 +++++---
 .../planner/StreamingForeachBatchHelper.scala      | 10 +++---
 .../planner/StreamingQueryListenerHelper.scala     | 21 +++++++------
 .../spark/sql/connect/service/SessionHolder.scala  | 19 +++++++-----
 .../spark/api/python/StreamingPythonRunner.scala   | 36 ++++++++++++++++------
 .../streaming/worker/foreachBatch_worker.py        |  4 ++-
 .../connect/streaming/worker/listener_worker.py    |  4 ++-
 .../connect/streaming/test_parity_listener.py      |  7 +++++
 9 files changed, 77 insertions(+), 64 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index e2f3be02ad3..404bd1b078b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable {
   def onQueryTerminated(event: QueryTerminatedEvent): Unit
 }
 
-/**
- * Py4J allows a pure interface so this proxy is required.
- */
-private[spark] trait PythonStreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit
-
-  def onQueryProgress(event: QueryProgressEvent): Unit
-
-  def onQueryIdle(event: QueryIdleEvent): Unit
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit
-}
-
-private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener)
-    extends StreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)
-
-  def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)
-
-  override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
-}
-
 /**
  * Companion object of [[StreamingQueryListener]] that defines the listener events.
  * @since 3.5.0
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f4b33ae961a..7136476b515 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3097,10 +3097,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
         val listenerId = command.getRemoveListener.getId
-        val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId)
-        session.streams.removeListener(listener)
-        sessionHolder.removeCachedListener(listenerId)
-        respBuilder.setRemoveListener(true)
+        sessionHolder.getListener(listenerId) match {
+          case Some(listener) =>
+            session.streams.removeListener(listener)
+            sessionHolder.removeCachedListener(listenerId)
+            respBuilder.setRemoveListener(true)
+          case None =>
+            respBuilder.setRemoveListener(false)
+        }
 
       case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS =>
         respBuilder.getListListenersBuilder
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 998faf327d0..4f1037b86c9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging {
 
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-    val runner = StreamingPythonRunner(pythonFn, connectUrl)
+    val runner = StreamingPythonRunner(
+      pythonFn,
+      connectUrl,
+      sessionHolder.sessionId,
+      "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
     val (dataOut, dataIn) =
-      runner.init(
-        sessionHolder.sessionId,
-        "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
+      runner.init()
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index d915bc93496..9b2a931ec4a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 /**
  * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each
  * instance of this class starts a python process, inside which has the python handling logic.
- * When new a event is received, it is serialized to json, and passed to the python process.
+ * When a new event is received, it is serialized to json, and passed to the python process.
  */
 class PythonStreamingQueryListener(
     listener: SimplePythonFunction,
@@ -32,12 +32,15 @@ class PythonStreamingQueryListener(
     pythonExec: String)
     extends StreamingQueryListener {
 
-  val port = SparkConnectService.localPort
-  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-  val runner = StreamingPythonRunner(listener, connectUrl)
+  private val port = SparkConnectService.localPort
+  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private val runner = StreamingPythonRunner(
+    listener,
+    connectUrl,
+    sessionHolder.sessionId,
+    "pyspark.sql.connect.streaming.worker.listener_worker")
 
-  val (dataOut, _) =
-    runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker")
+  val (dataOut, _) = runner.init()
 
   override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
@@ -63,7 +66,7 @@ class PythonStreamingQueryListener(
     dataOut.flush()
   }
 
-  // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes.
-  // Similar to foreachBatch when we need to exit the process when the query ends.
-  // In listener semantics, we need to exit the process when removeListener is called.
+  private[spark] def stopListenerProcess(): Unit = {
+    runner.stop()
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 310bb9208c2..29134f0dc0d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock}
 import org.apache.spark.util.Utils
@@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   }
 
   /**
-   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw
-   * [[InvalidPlanInput]].
+   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return
+   * None.
    */
-  private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = {
+  private[connect] def getListener(id: String): Option[StreamingQueryListener] = {
     Option(listenerCache.get(id))
-      .getOrElse {
-        throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId")
-      }
   }
 
   /**
-   * Removes corresponding StreamingQueryListener by ID.
+   * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a
+   * Spark Connect PythonStreamingQueryListener.
    */
-  private[connect] def removeCachedListener(id: String): StreamingQueryListener = {
+  private[connect] def removeCachedListener(id: String): Unit = {
+    listenerCache.get(id) match {
+      case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess()
+      case _ => // do nothing
+    }
     listenerCache.remove(id)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index d4fd9485675..f14289f984a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH
 
 
 private[spark] object StreamingPythonRunner {
-  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
-    new StreamingPythonRunner(func, connectUrl)
+  def apply(
+      func: PythonFunction,
+      connectUrl: String,
+      sessionId: String,
+      workerModule: String
+  ): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl, sessionId, workerModule)
   }
 }
 
-private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
-  extends Logging {
+private[spark] class StreamingPythonRunner(
+    func: PythonFunction,
+    connectUrl: String,
+    sessionId: String,
+    workerModule: String) extends Logging {
   private val conf = SparkEnv.get.conf
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
 
   private val envVars: java.util.Map[String, String] = func.envVars
   private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[Socket] = None
   protected val pythonVer: String = func.pythonVer
 
   /**
   * Initializes the Python worker for streaming functions. Sets up Spark Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
-    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+  def init(): (DataOutputStream, DataInputStream) = {
+    logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec")
     val env = SparkEnv.get
 
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -60,9 +69,9 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
     conf.set(PYTHON_USE_DAEMON, false)
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val pythonWorkerFactory =
-      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker()
+    val (worker, _) = env.createPythonWorker(
+      pythonExec, workerModule, envVars.asScala.toMap)
+    pythonWorker = Some(worker)
 
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
@@ -85,4 +94,13 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
 
     (dataOut, dataIn)
   }
+
+  /**
+   * Stops the Python worker.
+   */
+  def stop(): Unit = {
+    pythonWorker.foreach { worker =>
+      SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+    }
+  }
 }
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index 054788539f2..48a9848de40 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -76,7 +76,9 @@ if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each micro batch.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 8eb310461b6..7aef911426d 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -89,7 +89,9 @@ if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each listener event.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 547462d4da6..4bf58bf7807 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -60,6 +60,10 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
         try:
             self.spark.streams.addListener(test_listener)
 
+            # This ensures the read socket on the server won't crash (i.e. because of timeout)
+            # when there hasn't been a new event for a long time
+            time.sleep(30)
+
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
             q = df.writeStream.format("noop").queryName("test").start()
 
@@ -76,6 +80,9 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
         finally:
             self.spark.streams.removeListener(test_listener)
 
+            # Remove again to verify this won't throw any error
+            self.spark.streams.removeListener(test_listener)
+
 
 if __name__ == "__main__":
     import unittest

