This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 002a9bf3e617 [SPARK-51758][SS] Apply late record filtering based on 
watermark only if timeMode is passed as EventTime to the transformWithState 
operator
002a9bf3e617 is described below

commit 002a9bf3e61742933ee14cebbf3f2da4863ed34b
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Tue Apr 15 06:33:11 2025 +0900

    [SPARK-51758][SS] Apply late record filtering based on watermark only if 
timeMode is passed as EventTime to the transformWithState operator
    
    ### What changes were proposed in this pull request?
    Apply late record filtering based on watermark only if timeMode is passed 
as EventTime to the transformWithState operator
    
    ### Why are the changes needed?
    Without this, we might filter records even if timeMode is passed as 
None/ProcessingTime, which might be counterintuitive for users.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    Added unit tests
    
    ```
    ===== POSSIBLE THREAD LEAK IN SUITE 
o.a.s.sql.streaming.TransformWithStateValidationSuite, threads: 
ForkJoinPool.commonPool-worker-2 (daemon=true), files-client-8-1 (daemon=true), 
rpc-boss-3-1 (daemon=true), shuffle-boss-6-1 (daemon=true), 
ForkJoinPool.commonPool-worker-1 (daemon=true), Cleaner-0 (daemon=true), 
ForkJoinPool.commonPool-worker-3 (daemon=true) =====
    [info] Run completed in 15 seconds, 292 milliseconds.
    [info] Total number of tests run: 4
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    [success] Total time: 44 s, co
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #50550 from anishshri-db/task/SPARK-51758.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
    (cherry picked from commit bb1a63a8c361d8748b6eab1634271b9ebdc3e619)
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../pandas/test_pandas_transform_with_state.py     | 61 ++++++++++++++++++-
 .../streaming/TransformWithStateInPandasExec.scala |  2 +-
 .../streaming/TransformWithStateExec.scala         |  2 +-
 .../sql/streaming/TransformWithStateSuite.scala    | 70 ++++++++++++++++++++++
 4 files changed, 131 insertions(+), 4 deletions(-)

diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index dc104704e169..44ea90ab6659 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -17,6 +17,7 @@
 
 import json
 import os
+import sys
 import time
 import tempfile
 from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
@@ -532,7 +533,9 @@ class TransformWithStateInPandasTestsMixin:
             ProcTimeStatefulProcessor(), check_results
         )
 
-    def _test_transform_with_state_in_pandas_event_time(self, 
stateful_processor, check_results):
+    def _test_transform_with_state_in_pandas_event_time(
+        self, stateful_processor, check_results, time_mode="eventtime"
+    ):
         import pyspark.sql.functions as f
 
         input_path = tempfile.mkdtemp()
@@ -547,6 +550,7 @@ class TransformWithStateInPandasTestsMixin:
 
         def prepare_batch3(input_path):
             with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 2\n")
                 fw.write("a, 11\n")
                 fw.write("a, 13\n")
                 fw.write("a, 15\n")
@@ -577,7 +581,7 @@ class TransformWithStateInPandasTestsMixin:
                 statefulProcessor=stateful_processor,
                 outputStructType=output_schema,
                 outputMode="Update",
-                timeMode="eventtime",
+                timeMode=time_mode,
             )
             .writeStream.queryName(query_name)
             .foreachBatch(check_results)
@@ -625,6 +629,32 @@ class TransformWithStateInPandasTestsMixin:
             EventTimeStatefulProcessor(), check_results
         )
 
+    def test_transform_with_state_with_wmark_and_non_event_time(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                # watermark for late event = 0 and min event = 20
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="20"),
+                }
+            elif batch_id == 1:
+                # watermark for late event = 0 and min event = 4
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="4"),
+                }
+            else:
+                # watermark for late event = 10 and min event = 2 with no 
filtering
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="2"),
+                }
+
+        self._test_transform_with_state_in_pandas_event_time(
+            MinEventTimeStatefulProcessor(), check_results, "None"
+        )
+
+        self._test_transform_with_state_in_pandas_event_time(
+            MinEventTimeStatefulProcessor(), check_results, "ProcessingTime"
+        )
+
     def _test_transform_with_state_init_state_in_pandas(
         self,
         stateful_processor,
@@ -1611,6 +1641,33 @@ class EventTimeStatefulProcessor(StatefulProcessor):
         pass
 
 
+# A stateful processor that outputs the min event time it has seen.
+class MinEventTimeStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", StringType(), True)])
+        self.handle = handle
+        self.min_state = handle.getValueState("min_state", state_schema)
+
+    def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]:
+        timestamp_list = []
+        for pdf in rows:
+            # int64 will represent timestamp in nanosecond, restore to second
+            timestamp_list.extend((pdf["eventTime"].astype("int64") // 
10**9).tolist())
+
+        if self.min_state.exists():
+            cur_min = int(self.min_state.get()[0])
+        else:
+            cur_min = sys.maxsize
+        min_event_time = str(min(cur_min, min(timestamp_list)))
+
+        self.min_state.update((min_event_time,))
+
+        yield pd.DataFrame({"id": key, "timestamp": min_event_time})
+
+    def close(self) -> None:
+        pass
+
+
 # A stateful processor that output the accumulation of count of input rows; 
register
 # processing timer and clear the counter if timer expires.
 class ProcTimeStatefulProcessor(StatefulProcessor):
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
index e77035e31ccb..909de5103c91 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
@@ -380,7 +380,7 @@ case class TransformWithStateInPandasExec(
 
     // If timeout is based on event time, then filter late data based on 
watermark
     val filteredIter = watermarkPredicateForDataForLateEvents match {
-      case Some(predicate) =>
+      case Some(predicate) if timeMode == TimeMode.EventTime() =>
         applyRemovingRowsOlderThanWatermark(dataIterator, predicate)
       case _ =>
         dataIterator
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 6e0502e18659..3443ca21535c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -381,7 +381,7 @@ case class TransformWithStateExec(
 
     // If timeout is based on event time, then filter late data based on 
watermark
     val filteredIter = watermarkPredicateForDataForLateEvents match {
-      case Some(predicate) =>
+      case Some(predicate) if timeMode == TimeMode.EventTime() =>
         applyRemovingRowsOlderThanWatermark(iter, predicate)
       case _ =>
         iter
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index e98e47cae427..6f1da588eb53 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -611,6 +611,30 @@ class MaxEventTimeStatefulProcessor
   }
 }
 
+class MinEventTimeStatefulProcessor
+  extends StatefulProcessor[String, (String, Long), (String, Int)]
+  with Logging {
+  @transient var _minEventTimeState: ValueState[Long] = _
+
+  override def init(
+    outputMode: OutputMode,
+    timeMode: TimeMode): Unit = {
+    _minEventTimeState = getHandle.getValueState[Long]("minEventTimeState",
+      Encoders.scalaLong, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+    key: String,
+    inputRows: Iterator[(String, Long)],
+    timerValues: TimerValues): Iterator[(String, Int)] = {
+    val valuesSeq = inputRows.toSeq
+    val minEventTimeSec = math.min(valuesSeq.map(_._2).min,
+      Option(_minEventTimeState.get()).getOrElse(Long.MaxValue))
+    _minEventTimeState.update(minEventTimeSec)
+    Iterator((key, minEventTimeSec.toInt))
+  }
+}
+
 class RunningCountMostRecentStatefulProcessor
   extends StatefulProcessor[String, (String, String), (String, String, String)]
   with Logging {
@@ -2358,4 +2382,50 @@ class TransformWithStateValidationSuite extends 
StateStoreMetricsTest {
     assert(ex1.getMessage.contains("Failed to find time values"))
     TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(), 
Some(10L))
   }
+
+  Seq(TimeMode.None(), TimeMode.ProcessingTime()).foreach { timeMode =>
+    test(s"transformWithState - using watermark but time mode as $timeMode 
should not perform " +
+      s"late record filtering") {
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString
+      ) {
+        val inputData = MemoryStream[(String, Int)]
+        val result =
+          inputData.toDS()
+            .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+            .withWatermark("eventTime", "10 seconds")
+            .as[(String, Long)]
+            .groupByKey(_._1)
+            .transformWithState(
+              new MinEventTimeStatefulProcessor(),
+              timeMode,
+              OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = new 
StreamManualClock),
+
+          AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+          AdvanceManualClock(1 * 1000),
+          // Min event time = 15. Watermark = 15 - 10 = 5.
+          CheckNewAnswer(("a", 11)), // Output = min event time of a
+
+          AddData(inputData, ("a", 4)), // Add data older than watermark for 
"a"
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", 4)), // Data should not get filtered and output 
will be 4
+
+          AddData(inputData, ("a", 1)), // Add data older than watermark for 
"a"
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", 1)), // Data should not get filtered and output 
will be 1
+
+          AddData(inputData, ("a", 85)), // Add data newer than watermark for 
"a"
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", 1)), // Min event time should still be 1
+          StopStream
+        )
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to