This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b6190a3db97 [SPARK-45056][PYTHON][SS][CONNECT] Termination tests for 
streamingQueryListener and foreachBatch
b6190a3db97 is described below

commit b6190a3db974c19b6a0c4fe7af75531d67755074
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Thu Sep 14 11:23:44 2023 +0900

    [SPARK-45056][PYTHON][SS][CONNECT] Termination tests for 
streamingQueryListener and foreachBatch
    
    ### What changes were proposed in this pull request?
    
    Add termination tests for StreamingQueryListener and foreachBatch.
    
    The behavior is mimicked by creating, on the server side, the same query that
    would have been created if the same Python query were run on the client side.
    For example, for foreachBatch, a Python foreachBatch function is normally
    serialized using CloudPickleSerializer and passed to the server. Here we start
    another Python process on the server, call the same CloudPickleSerializer, pass
    the resulting bytes to the server, and construct a `SimplePythonFunction`
    accordingly.
    
    Refactored the code a bit for testing purposes.
    
    ### Why are the changes needed?
    
    Necessary tests
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Test-only addition.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42779 from WweiL/SPARK-44435-followup-termination-tests.
    
    Lead-authored-by: Wei Liu <wei....@databricks.com>
    Co-authored-by: Wei Liu <z920631...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build_and_test.yml               |   6 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   3 +-
 .../planner/StreamingQueryListenerHelper.scala     |   8 +-
 .../service/SparkConnectSessionHodlerSuite.scala   | 205 +++++++++++++++++++++
 .../spark/api/python/PythonWorkerFactory.scala     |   5 +
 .../spark/api/python/StreamingPythonRunner.scala   |  26 ++-
 .../connect/streaming/test_parity_foreach_batch.py |   1 -
 .../streaming/test_streaming_foreach_batch.py      |   2 -
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |   4 +-
 9 files changed, 242 insertions(+), 18 deletions(-)

diff --git a/.github/workflows/build_and_test.yml 
b/.github/workflows/build_and_test.yml
index 21809564497..25c95bd607d 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -256,14 +256,14 @@ jobs:
       # We should install one Python that is higher then 3+ for SQL and Yarn 
because:
       # - SQL component also has Python related tests, for example, 
IntegratedUDFTestUtils.
       # - Yarn has a Python specific test too, for example, YarnClusterSuite.
-      if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') 
&& !contains(matrix.modules, 'sql-'))
+      if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') 
&& !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect')
       with:
         python-version: 3.8
         architecture: x64
     - name: Install Python packages (Python 3.8)
-      if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 
'sql-'))
+      if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 
'sql-')) || contains(matrix.modules, 'connect')
       run: |
-        python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas 
scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3'
+        python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas 
scipy unittest-xml-reporting 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 
'protobuf==3.20.3'
         python3.8 -m pip list
     # Run the tests.
     - name: Run tests
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b8ab5539b30..24dee006f0b 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3131,8 +3131,7 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         val listener = if (command.getAddListener.hasPythonListenerPayload) {
           new PythonStreamingQueryListener(
             
transformPythonFunction(command.getAddListener.getPythonListenerPayload),
-            sessionHolder,
-            pythonExec)
+            sessionHolder)
         } else {
           val listenerPacket = Utils
             .deserialize[StreamingListenerPacket](
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index 9b2a931ec4a..01339a8a1b4 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -26,15 +26,13 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
  * instance of this class starts a python process, inside which has the python 
handling logic.
  * When a new event is received, it is serialized to json, and passed to the 
python process.
  */
-class PythonStreamingQueryListener(
-    listener: SimplePythonFunction,
-    sessionHolder: SessionHolder,
-    pythonExec: String)
+class PythonStreamingQueryListener(listener: SimplePythonFunction, 
sessionHolder: SessionHolder)
     extends StreamingQueryListener {
 
   private val port = SparkConnectService.localPort
   private val connectUrl = 
s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-  private val runner = StreamingPythonRunner(
+  // Scoped for testing
+  private[connect] val runner = StreamingPythonRunner(
     listener,
     connectUrl,
     sessionHolder.sessionId,
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala
index 51b78886819..c5874e10ead 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala
@@ -17,7 +17,20 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.sys.process.Process
+
+import com.google.common.collect.Lists
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.sql.IntegratedUDFTestUtils
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, 
StreamingForeachBatchHelper}
 import org.apache.spark.sql.test.SharedSparkSession
 
 class SparkConnectSessionHolderSuite extends SharedSparkSession {
@@ -79,4 +92,196 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
       sessionHolder.getDataFrameOrThrow(key1)
     }
   }
+
+  private def streamingForeachBatchFunction(pysparkPythonPath: String): 
Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          IntegratedUDFTestUtils.pythonExec,
+          "-c",
+          "from pyspark.serializers import CloudPickleSerializer; " +
+            s"f = open('$path', 'wb');" +
+            "f.write(CloudPickleSerializer().dumps((" +
+            "lambda df, batchId: batchId)))"),
+        None,
+        "PYTHONPATH" -> pysparkPythonPath).!!
+      binaryFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def streamingQueryListenerFunction(pysparkPythonPath: String): 
Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    val pythonScript =
+      """
+        |from pyspark.sql.streaming.listener import StreamingQueryListener
+        |
+        |class MyListener(StreamingQueryListener):
+        |    def onQueryStarted(e):
+        |        pass
+        |
+        |    def onQueryIdle(e):
+        |        pass
+        |
+        |    def onQueryProgress(e):
+        |        pass
+        |
+        |    def onQueryTerminated(e):
+        |        pass
+        |
+        |listener = MyListener()
+      """.stripMargin
+    withTempPath { codePath =>
+      Files.write(codePath.toPath, 
pythonScript.getBytes(StandardCharsets.UTF_8))
+      withTempPath { path =>
+        Process(
+          Seq(
+            IntegratedUDFTestUtils.pythonExec,
+            "-c",
+            "from pyspark.serializers import CloudPickleSerializer; " +
+              s"f = open('$path', 'wb');" +
+              s"exec(open('$codePath', 'r').read());" +
+              "f.write(CloudPickleSerializer().dumps(listener))"),
+          None,
+          "PYTHONPATH" -> pysparkPythonPath).!!
+        binaryFunc = Files.readAllBytes(path.toPath)
+      }
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def dummyPythonFunction(sessionHolder: SessionHolder)(
+      fcn: String => Array[Byte]): SimplePythonFunction = {
+    val sparkPythonPath =
+      
s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}"
+
+    SimplePythonFunction(
+      command = fcn(sparkPythonPath),
+      envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
+      pythonIncludes = 
sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+      pythonExec = IntegratedUDFTestUtils.pythonExec,
+      pythonVer = IntegratedUDFTestUtils.pythonVer,
+      broadcastVars = Lists.newArrayList(),
+      accumulator = null)
+  }
+
+  test("python foreachBatch process: process terminates after query is 
stopped") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
+      val (fn1, cleaner1) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, 
sessionHolder)
+      val (fn2, cleaner2) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, 
sessionHolder)
+
+      val query1 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q1")
+        .foreachBatch(fn1)
+        .start()
+
+      val query2 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q2")
+        .foreachBatch(fn2)
+        .start()
+
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query1, cleaner1)
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query2, cleaner2)
+
+      val (runner1, runner2) = (cleaner1.runner, cleaner2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+      // stop query1
+      query1.stop()
+      // assert query1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // stop query2
+      query2.stop()
+      eventually(timeout(30.seconds)) {
+        // assert query2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+      }
+
+      assert(spark.streams.active.isEmpty) // no running query
+      assert(spark.streams.listListeners().length == 1) // only process 
termination listener
+    } finally {
+      SparkConnectService.stop()
+      // remove process termination listener
+      spark.streams.removeListener(spark.streams.listListeners()(0))
+    }
+  }
+
+  test("python listener process: process terminates after listener is 
removed") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+      val id1 = "listener_removeListener_test_1"
+      val id2 = "listener_removeListener_test_2"
+      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+
+      sessionHolder.cacheListenerById(id1, listener1)
+      spark.streams.addListener(listener1)
+      sessionHolder.cacheListenerById(id2, listener2)
+      spark.streams.addListener(listener2)
+
+      val (runner1, runner2) = (listener1.runner, listener2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+
+      // remove listener1
+      spark.streams.removeListener(listener1)
+      sessionHolder.removeCachedListener(id1)
+      // assert listener1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // remove listener2
+      spark.streams.removeListener(listener2)
+      sessionHolder.removeCachedListener(id2)
+      eventually(timeout(30.seconds)) {
+        // assert listener2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+        // all listeners are removed
+        assert(spark.streams.listListeners().isEmpty)
+      }
+    } finally {
+      SparkConnectService.stop()
+    }
+  }
 }
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 8888c97041e..d0776eb2cc7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -413,6 +413,11 @@ private[spark] class PythonWorkerFactory(
       }
     }
   }
+
+  def isWorkerStopped(worker: PythonWorker): Boolean = {
+    assert(!useDaemon, "isWorkerStopped() is not supported for daemon mode")
+    simpleWorkers.get(worker).exists(!_.isAlive)
+  }
 }
 
 private[spark] object PythonWorkerFactory {
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 1dae3aa19ab..bd2a8a01cac 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -108,10 +108,30 @@ private[spark] class StreamingPythonRunner(
    * Stops the Python worker.
    */
   def stop(): Unit = {
-    pythonWorker.foreach { worker =>
+    logInfo(s"Stopping streaming runner for sessionId: $sessionId, module: 
$workerModule.")
+
+    try {
       pythonWorkerFactory.foreach { factory =>
-        factory.stopWorker(worker)
-        factory.stop()
+        pythonWorker.foreach { worker =>
+          factory.stopWorker(worker)
+          factory.stop()
+        }
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception when trying to kill worker", e)
+    }
+  }
+
+  /**
+   * Returns whether the Python worker has been stopped.
+   * @return Some(true) if the Python worker has been stopped.
+   *         None if either the Python worker or the Python worker factory is 
not initialized.
+   */
+  def isWorkerStopped(): Option[Boolean] = {
+    pythonWorkerFactory.flatMap { factory =>
+      pythonWorker.map { worker =>
+        factory.isWorkerStopped(worker)
       }
     }
   }
diff --git 
a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py 
b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index e4577173687..c174bd53f8e 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -31,7 +31,6 @@ class 
StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo
     def test_streaming_foreach_batch_graceful_stop(self):
         super().test_streaming_foreach_batch_graceful_stop()
 
-    # class StreamingForeachBatchParityTests(ReusedConnectTestCase):
     def test_accessing_spark_session(self):
         spark = self.spark
 
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py 
b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
index af2831ef193..84cd42b342f 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
@@ -135,8 +135,6 @@ class StreamingTestsForeachBatchMixin:
         df = df.union(df)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        # write to delta table?
-
     @staticmethod
     def my_test_function_2():
         return 2
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 05f71500a0f..bc7a68732a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -97,14 +97,14 @@ import org.apache.spark.sql.types.{DataType, IntegerType, 
NullType, StringType,
 object IntegratedUDFTestUtils extends SQLHelper {
   import scala.sys.process._
 
-  private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+  private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
 
   // Note that we will directly refer pyspark's source, not the zip from a 
regular build.
   // It is possible the test is being ran without the build.
   private lazy val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath
   private lazy val py4jPath = Paths.get(
     sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
-  private lazy val pysparkPythonPath = s"$py4jPath:$sourcePath"
+  private[spark] lazy val pysparkPythonPath = s"$py4jPath:$sourcePath"
 
   private lazy val isPythonAvailable: Boolean = 
TestUtils.testCommandAvailable(pythonExec)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to