This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a58362ecbbf7 [SPARK-46722][CONNECT] Add a test regarding backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)
a58362ecbbf7 is described below

commit a58362ecbbf7c4e5d5f848411834cf2a9ef298b3
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Tue Jan 16 12:23:02 2024 +0900

    [SPARK-46722][CONNECT] Add a test regarding backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add a backward compatibility check for StreamingQueryListener in Spark Connect (both Scala and PySpark), specifically covering listeners that do and do not implement onQueryIdle.
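
    For context, below is a minimal sketch of the two listener shapes the check distinguishes (illustrative only; the class names are hypothetical and are not the test classes added by this patch): a pre-3.5 listener that leaves onQueryIdle at its default no-op, and a 3.5+ listener that overrides it.

    ```scala
    import org.apache.spark.sql.streaming.StreamingQueryListener
    import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

    // Pre-3.5 shape (hypothetical example): only the original three callbacks are
    // overridden; onQueryIdle, added in Spark 3.5 with a default no-op body, is
    // inherited. Such listeners must keep working against Spark Connect.
    class LegacyStyleListener extends StreamingQueryListener {
      override def onQueryStarted(event: QueryStartedEvent): Unit = {}
      override def onQueryProgress(event: QueryProgressEvent): Unit = {}
      override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
    }

    // 3.5+ shape (hypothetical example): additionally overrides onQueryIdle.
    class CurrentStyleListener extends StreamingQueryListener {
      override def onQueryStarted(event: QueryStartedEvent): Unit = {}
      override def onQueryProgress(event: QueryProgressEvent): Unit = {}
      override def onQueryIdle(event: QueryIdleEvent): Unit = {}
      override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
    }
    ```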
    
    ### Why are the changes needed?
    
    We missed adding a backward compatibility test when introducing onQueryIdle, which led to an issue in PySpark (https://issues.apache.org/jira/browse/SPARK-45631). We added the compatibility test in PySpark but didn't add it in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Modified UTs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44736 from HeartSaVioR/SPARK-46722.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/streaming/ClientStreamingQuerySuite.scala  |  88 ++++++++++----
 .../connect/streaming/test_parity_listener.py      | 133 +++++++++++++--------
 2 files changed, 142 insertions(+), 79 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index 91c562c0f98b..fd989b5da35c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.streaming
 import java.io.{File, FileWriter}
 import java.util.concurrent.TimeUnit
 
-import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
 import org.scalatest.concurrent.Eventually.eventually
@@ -32,7 +31,7 @@ import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf, window}
-import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{QueryTest, SQLHelper}
 import org.apache.spark.util.SparkFileUtils
 
@@ -354,9 +353,15 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
   }
 
   test("streaming query listener") {
+    testStreamingQueryListener(new EventCollectorV1, "_v1")
+    testStreamingQueryListener(new EventCollectorV2, "_v2")
+  }
+
+  private def testStreamingQueryListener(
+      listener: StreamingQueryListener,
+      tablePostfix: String): Unit = {
     assert(spark.streams.listListeners().length == 0)
 
-    val listener = new EventCollector
     spark.streams.addListener(listener)
 
     val q = spark.readStream
@@ -370,11 +375,21 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
       q.processAllAvailable()
       eventually(timeout(30.seconds)) {
         assert(q.isActive)
-        checkAnswer(spark.table("my_listener_table").toDF(), Seq(Row(1, 2), Row(4, 5)))
+
+        assert(!spark.table(s"listener_start_events$tablePostfix").toDF().isEmpty)
+        assert(!spark.table(s"listener_progress_events$tablePostfix").toDF().isEmpty)
       }
     } finally {
       q.stop()
-      spark.sql("DROP TABLE IF EXISTS my_listener_table")
+
+      eventually(timeout(30.seconds)) {
+        assert(!q.isActive)
+        assert(!spark.table(s"listener_terminated_events$tablePostfix").toDF().isEmpty)
+      }
+
+      spark.sql(s"DROP TABLE IF EXISTS listener_start_events$tablePostfix")
+      spark.sql(s"DROP TABLE IF EXISTS listener_progress_events$tablePostfix")
+      spark.sql(s"DROP TABLE IF EXISTS 
listener_terminated_events$tablePostfix")
     }
 
     // List listeners after adding a new listener, length should be 1.
@@ -382,7 +397,7 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
     assert(listeners.length == 1)
 
     // Add listener1 as another instance of EventCollector and validate
-    val listener1 = new EventCollector
+    val listener1 = new EventCollectorV2
     spark.streams.addListener(listener1)
     assert(spark.streams.listListeners().length == 2)
     spark.streams.removeListener(listener1)
@@ -462,35 +477,56 @@ case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
 
-class EventCollector extends StreamingQueryListener {
-  @volatile var startEvent: QueryStartedEvent = null
-  @volatile var terminationEvent: QueryTerminatedEvent = null
-  @volatile var idleEvent: QueryIdleEvent = null
+abstract class EventCollector extends StreamingQueryListener {
+  private lazy val spark = SparkSession.builder().getOrCreate()
 
-  private val _progressEvents = new mutable.Queue[StreamingQueryProgress]
+  protected def tablePostfix: String
 
-  def progressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized {
-    _progressEvents.clone().toSeq
+  protected def handleOnQueryStarted(event: QueryStartedEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_start_events$tablePostfix")
   }
 
-  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
-    startEvent = event
-    val spark = SparkSession.builder().getOrCreate()
-    val df = spark.createDataFrame(Seq((1, 2), (4, 5)))
-    df.write.saveAsTable("my_listener_table")
+  protected def handleOnQueryProgress(event: QueryProgressEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_progress_events$tablePostfix")
   }
 
-  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
-    _progressEvents += event.progress
+  protected def handleOnQueryTerminated(event: QueryTerminatedEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_terminated_events$tablePostfix")
   }
+}
 
-  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
-    idleEvent = event
-  }
+/**
+ * V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
+ * `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+ */
+class EventCollectorV1 extends EventCollector {
+  override protected def tablePostfix: String = "_v1"
 
-  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
-    terminationEvent = event
-  }
+  override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+  override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+    handleOnQueryTerminated(event)
+}
+
+/**
+ * V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+ */
+class EventCollectorV2 extends EventCollector {
+  override protected def tablePostfix: String = "_v2"
+
+  override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+  override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+  override def onQueryIdle(event: QueryIdleEvent): Unit = {}
+
+  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+    handleOnQueryTerminated(event)
 }
 
 class ForeachBatchFn(val viewName: String)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4fc040642bed..412f49a3960b 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -26,16 +26,36 @@ from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class TestListener(StreamingQueryListener):
+# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
+# `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+class TestListenerV1(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_start_events")
+        df.write.mode("append").saveAsTable("listener_start_events_v1")
 
     def onQueryProgress(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_progress_events")
+        df.write.mode("append").saveAsTable("listener_progress_events_v1")
+
+    def onQueryTerminated(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events_v1")
+
+
+# V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+class TestListenerV2(StreamingQueryListener):
+    def onQueryStarted(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events_v2")
+
+    def onQueryProgress(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events_v2")
 
     def onQueryIdle(self, event):
         pass
@@ -43,60 +63,67 @@ class TestListener(StreamingQueryListener):
     def onQueryTerminated(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_terminated_events")
+        df.write.mode("append").saveAsTable("listener_terminated_events_v2")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
     def test_listener_events(self):
-        test_listener = TestListener()
-
-        try:
-            self.spark.streams.addListener(test_listener)
-
-            # This ensures the read socket on the server won't crash (i.e. because of timeout)
-            # when there hasn't been a new event for a long time
-            time.sleep(30)
-
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
-
-            self.assertTrue(q.isActive)
-            # ensure at least one batch is ran
-            while q.lastProgress is None or q.lastProgress["batchId"] == 0:
-                time.sleep(5)
-            q.stop()
-            self.assertFalse(q.isActive)
-
-            time.sleep(60)  # Sleep to make sure listener_terminated_events is written successfully
-
-            start_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_start_events").collect()[0][0]
-            )
-
-            progress_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_progress_events").collect()[0][0]
-            )
-
-            terminated_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_terminated_events").collect()[0][0]
-            )
-
-            self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
-
-        finally:
-            self.spark.streams.removeListener(test_listener)
-
-            # Remove again to verify this won't throw any error
-            self.spark.streams.removeListener(test_listener)
+        def verify(test_listener, table_postfix):
+            try:
+                self.spark.streams.addListener(test_listener)
+
+                # This ensures the read socket on the server won't crash (i.e. because of timeout)
+                # when there hasn't been a new event for a long time
+                time.sleep(30)
+
+                df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+                df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+                df_stateful = df_observe.groupBy().count()  # make query stateful
+                q = (
+                    df_stateful.writeStream.format("noop")
+                    .queryName("test")
+                    .outputMode("complete")
+                    .start()
+                )
+
+                self.assertTrue(q.isActive)
+                # ensure at least one batch is run
+                while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+                    time.sleep(5)
+                q.stop()
+                self.assertFalse(q.isActive)
+
+                # Sleep to make sure listener_terminated_events is written successfully
+                time.sleep(60)
+
+                start_table_name = "listener_start_events" + table_postfix
+                progress_tbl_name = "listener_progress_events" + table_postfix
+                terminated_tbl_name = "listener_terminated_events" + table_postfix
+
+                start_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(start_table_name).collect()[0][0]
+                )
+
+                progress_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(progress_tbl_name).collect()[0][0]
+                )
+
+                terminated_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(terminated_tbl_name).collect()[0][0]
+                )
+
+                self.check_start_event(start_event)
+                self.check_progress_event(progress_event)
+                self.check_terminated_event(terminated_event)
+
+            finally:
+                self.spark.streams.removeListener(test_listener)
+
+                # Remove again to verify this won't throw any error
+                self.spark.streams.removeListener(test_listener)
+
+        verify(TestListenerV1(), "_v1")
+        verify(TestListenerV2(), "_v2")
 
     def test_accessing_spark_session(self):
         spark = self.spark


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
