This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b6190a3db97 [SPARK-45056][PYTHON][SS][CONNECT] Termination tests for streamingQueryListener and foreachBatch b6190a3db97 is described below commit b6190a3db974c19b6a0c4fe7af75531d67755074 Author: Wei Liu <wei....@databricks.com> AuthorDate: Thu Sep 14 11:23:44 2023 +0900 [SPARK-45056][PYTHON][SS][CONNECT] Termination tests for streamingQueryListener and foreachBatch ### What changes were proposed in this pull request? Add termination tests for StreamingQueryListener and foreachBatch. The behavior is mimicked by creating, on the server side, the same query that would have been created if the same Python query were run on the client side. For example, in foreachBatch, a Python foreachBatch function is serialized using CloudPickleSerializer and passed to the server side; here we start another Python process on the server, call the same CloudPickleSerializer, pass the bytes to the server, and construct `SimplePythonFunction` accordingly. The code was also refactored slightly for testing purposes. ### Why are the changes needed? Necessary tests ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Test-only addition ### Was this patch authored or co-authored using generative AI tooling? No Closes #42779 from WweiL/SPARK-44435-followup-termination-tests. 
Lead-authored-by: Wei Liu <wei....@databricks.com> Co-authored-by: Wei Liu <z920631...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .github/workflows/build_and_test.yml | 6 +- .../sql/connect/planner/SparkConnectPlanner.scala | 3 +- .../planner/StreamingQueryListenerHelper.scala | 8 +- .../service/SparkConnectSessionHodlerSuite.scala | 205 +++++++++++++++++++++ .../spark/api/python/PythonWorkerFactory.scala | 5 + .../spark/api/python/StreamingPythonRunner.scala | 26 ++- .../connect/streaming/test_parity_foreach_batch.py | 1 - .../streaming/test_streaming_foreach_batch.py | 2 - .../apache/spark/sql/IntegratedUDFTestUtils.scala | 4 +- 9 files changed, 242 insertions(+), 18 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 21809564497..25c95bd607d 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -256,14 +256,14 @@ jobs: # We should install one Python that is higher then 3+ for SQL and Yarn because: # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. # - Yarn has a Python specific test too, for example, YarnClusterSuite. 
- if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) + if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') with: python-version: 3.8 architecture: x64 - name: Install Python packages (Python 3.8) - if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) + if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3' + python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3' python3.8 -m pip list # Run the tests. - name: Run tests diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index b8ab5539b30..24dee006f0b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3131,8 +3131,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val listener = if (command.getAddListener.hasPythonListenerPayload) { new PythonStreamingQueryListener( transformPythonFunction(command.getAddListener.getPythonListenerPayload), - sessionHolder, - pythonExec) + sessionHolder) } else { val listenerPacket = Utils .deserialize[StreamingListenerPacket]( diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala index 9b2a931ec4a..01339a8a1b4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -26,15 +26,13 @@ import org.apache.spark.sql.streaming.StreamingQueryListener * instance of this class starts a python process, inside which has the python handling logic. * When a new event is received, it is serialized to json, and passed to the python process. */ -class PythonStreamingQueryListener( - listener: SimplePythonFunction, - sessionHolder: SessionHolder, - pythonExec: String) +class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder: SessionHolder) extends StreamingQueryListener { private val port = SparkConnectService.localPort private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" - private val runner = StreamingPythonRunner( + // Scoped for testing + private[connect] val runner = StreamingPythonRunner( listener, connectUrl, sessionHolder.sessionId, diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala index 51b78886819..c5874e10ead 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala @@ -17,7 +17,20 @@ package org.apache.spark.sql.connect.service +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.sys.process.Process + +import 
com.google.common.collect.Lists +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.sql.IntegratedUDFTestUtils import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, StreamingForeachBatchHelper} import org.apache.spark.sql.test.SharedSparkSession class SparkConnectSessionHolderSuite extends SharedSparkSession { @@ -79,4 +92,196 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { sessionHolder.getDataFrameOrThrow(key1) } } + + private def streamingForeachBatchFunction(pysparkPythonPath: String): Array[Byte] = { + var binaryFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + IntegratedUDFTestUtils.pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + "f.write(CloudPickleSerializer().dumps((" + + "lambda df, batchId: batchId)))"), + None, + "PYTHONPATH" -> pysparkPythonPath).!! 
+ binaryFunc = Files.readAllBytes(path.toPath) + } + assert(binaryFunc != null) + binaryFunc + } + + private def streamingQueryListenerFunction(pysparkPythonPath: String): Array[Byte] = { + var binaryFunc: Array[Byte] = null + val pythonScript = + """ + |from pyspark.sql.streaming.listener import StreamingQueryListener + | + |class MyListener(StreamingQueryListener): + | def onQueryStarted(e): + | pass + | + | def onQueryIdle(e): + | pass + | + | def onQueryProgress(e): + | pass + | + | def onQueryTerminated(e): + | pass + | + |listener = MyListener() + """.stripMargin + withTempPath { codePath => + Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8)) + withTempPath { path => + Process( + Seq( + IntegratedUDFTestUtils.pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + s"exec(open('$codePath', 'r').read());" + + "f.write(CloudPickleSerializer().dumps(listener))"), + None, + "PYTHONPATH" -> pysparkPythonPath).!! 
+ binaryFunc = Files.readAllBytes(path.toPath) + } + } + assert(binaryFunc != null) + binaryFunc + } + + private def dummyPythonFunction(sessionHolder: SessionHolder)( + fcn: String => Array[Byte]): SimplePythonFunction = { + val sparkPythonPath = + s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}" + + SimplePythonFunction( + command = fcn(sparkPythonPath), + envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava, + pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, + pythonExec = IntegratedUDFTestUtils.pythonExec, + pythonVer = IntegratedUDFTestUtils.pythonVer, + broadcastVars = Lists.newArrayList(), + accumulator = null) + } + + test("python foreachBatch process: process terminates after query is stopped") { + // scalastyle:off assume + assume(IntegratedUDFTestUtils.shouldTestPythonUDFs) + // scalastyle:on assume + + val sessionHolder = SessionHolder.forTesting(spark) + try { + SparkConnectService.start(spark.sparkContext) + + val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction) + val (fn1, cleaner1) = + StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) + val (fn2, cleaner2) = + StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) + + val query1 = spark.readStream + .format("rate") + .load() + .writeStream + .format("memory") + .queryName("foreachBatch_termination_test_q1") + .foreachBatch(fn1) + .start() + + val query2 = spark.readStream + .format("rate") + .load() + .writeStream + .format("memory") + .queryName("foreachBatch_termination_test_q2") + .foreachBatch(fn2) + .start() + + sessionHolder.streamingForeachBatchRunnerCleanerCache + .registerCleanerForQuery(query1, cleaner1) + sessionHolder.streamingForeachBatchRunnerCleanerCache + .registerCleanerForQuery(query2, cleaner2) + + val (runner1, runner2) = (cleaner1.runner, cleaner2.runner) + + // assert both python processes are running + 
assert(!runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + // stop query1 + query1.stop() + // assert query1's python process is not running + eventually(timeout(30.seconds)) { + assert(runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + } + + // stop query2 + query2.stop() + eventually(timeout(30.seconds)) { + // assert query2's python process is not running + assert(runner2.isWorkerStopped().get) + } + + assert(spark.streams.active.isEmpty) // no running query + assert(spark.streams.listListeners().length == 1) // only process termination listener + } finally { + SparkConnectService.stop() + // remove process termination listener + spark.streams.removeListener(spark.streams.listListeners()(0)) + } + } + + test("python listener process: process terminates after listener is removed") { + // scalastyle:off assume + assume(IntegratedUDFTestUtils.shouldTestPythonUDFs) + // scalastyle:on assume + + val sessionHolder = SessionHolder.forTesting(spark) + try { + SparkConnectService.start(spark.sparkContext) + + val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction) + + val id1 = "listener_removeListener_test_1" + val id2 = "listener_removeListener_test_2" + val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + + sessionHolder.cacheListenerById(id1, listener1) + spark.streams.addListener(listener1) + sessionHolder.cacheListenerById(id2, listener2) + spark.streams.addListener(listener2) + + val (runner1, runner2) = (listener1.runner, listener2.runner) + + // assert both python processes are running + assert(!runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + + // remove listener1 + spark.streams.removeListener(listener1) + sessionHolder.removeCachedListener(id1) + // assert listener1's python process is not running + eventually(timeout(30.seconds)) { + 
assert(runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + } + + // remove listener2 + spark.streams.removeListener(listener2) + sessionHolder.removeCachedListener(id2) + eventually(timeout(30.seconds)) { + // assert listener2's python process is not running + assert(runner2.isWorkerStopped().get) + // all listeners are removed + assert(spark.streams.listListeners().isEmpty) + } + } finally { + SparkConnectService.stop() + } + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 8888c97041e..d0776eb2cc7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -413,6 +413,11 @@ private[spark] class PythonWorkerFactory( } } } + + def isWorkerStopped(worker: PythonWorker): Boolean = { + assert(!useDaemon, "isWorkerStopped() is not supported for daemon mode") + simpleWorkers.get(worker).exists(!_.isAlive) + } } private[spark] object PythonWorkerFactory { diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index 1dae3aa19ab..bd2a8a01cac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -108,10 +108,30 @@ private[spark] class StreamingPythonRunner( * Stops the Python worker. 
*/ def stop(): Unit = { - pythonWorker.foreach { worker => + logInfo(s"Stopping streaming runner for sessionId: $sessionId, module: $workerModule.") + + try { pythonWorkerFactory.foreach { factory => - factory.stopWorker(worker) - factory.stop() + pythonWorker.foreach { worker => + factory.stopWorker(worker) + factory.stop() + } + } + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + + /** + * Returns whether the Python worker has been stopped. + * @return Some(true) if the Python worker has been stopped. + * None if either the Python worker or the Python worker factory is not initialized. + */ + def isWorkerStopped(): Option[Boolean] = { + pythonWorkerFactory.flatMap { factory => + pythonWorker.map { worker => + factory.isWorkerStopped(worker) } } } diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py index e4577173687..c174bd53f8e 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py @@ -31,7 +31,6 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo def test_streaming_foreach_batch_graceful_stop(self): super().test_streaming_foreach_batch_graceful_stop() - # class StreamingForeachBatchParityTests(ReusedConnectTestCase): def test_accessing_spark_session(self): spark = self.spark diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py index af2831ef193..84cd42b342f 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py @@ -135,8 +135,6 @@ class StreamingTestsForeachBatchMixin: df = df.union(df) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - # write to delta table? 
- @staticmethod def my_test_function_2(): return 2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 05f71500a0f..bc7a68732a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -97,14 +97,14 @@ import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, object IntegratedUDFTestUtils extends SQLHelper { import scala.sys.process._ - private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") // Note that we will directly refer pyspark's source, not the zip from a regular build. // It is possible the test is being ran without the build. private lazy val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath private lazy val py4jPath = Paths.get( sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath - private lazy val pysparkPythonPath = s"$py4jPath:$sourcePath" + private[spark] lazy val pysparkPythonPath = s"$py4jPath:$sourcePath" private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org