This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 002a9bf3e617 [SPARK-51758][SS] Apply late record filtering based on
watermark only if timeMode is passed as EventTime to the transformWithState
operator
002a9bf3e617 is described below
commit 002a9bf3e61742933ee14cebbf3f2da4863ed34b
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Tue Apr 15 06:33:11 2025 +0900
[SPARK-51758][SS] Apply late record filtering based on watermark only if
timeMode is passed as EventTime to the transformWithState operator
### What changes were proposed in this pull request?
Apply late record filtering based on watermark only if timeMode is passed
as EventTime to the transformWithState operator
### Why are the changes needed?
Without this, we might filter records even if timeMode is passed as
None/ProcessingTime, which might be counterintuitive for users.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Added unit tests
```
===== POSSIBLE THREAD LEAK IN SUITE
o.a.s.sql.streaming.TransformWithStateValidationSuite, threads:
ForkJoinPool.commonPool-worker-2 (daemon=true), files-client-8-1 (daemon=true),
rpc-boss-3-1 (daemon=true), shuffle-boss-6-1 (daemon=true),
ForkJoinPool.commonPool-worker-1 (daemon=true), Cleaner-0 (daemon=true),
ForkJoinPool.commonPool-worker-3 (daemon=true) =====
[info] Run completed in 15 seconds, 292 milliseconds.
[info] Total number of tests run: 4
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 44 s, co
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #50550 from anishshri-db/task/SPARK-51758.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
(cherry picked from commit bb1a63a8c361d8748b6eab1634271b9ebdc3e619)
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../pandas/test_pandas_transform_with_state.py | 61 ++++++++++++++++++-
.../streaming/TransformWithStateInPandasExec.scala | 2 +-
.../streaming/TransformWithStateExec.scala | 2 +-
.../sql/streaming/TransformWithStateSuite.scala | 70 ++++++++++++++++++++++
4 files changed, 131 insertions(+), 4 deletions(-)
diff --git
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index dc104704e169..44ea90ab6659 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -17,6 +17,7 @@
import json
import os
+import sys
import time
import tempfile
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
@@ -532,7 +533,9 @@ class TransformWithStateInPandasTestsMixin:
ProcTimeStatefulProcessor(), check_results
)
- def _test_transform_with_state_in_pandas_event_time(self,
stateful_processor, check_results):
+ def _test_transform_with_state_in_pandas_event_time(
+ self, stateful_processor, check_results, time_mode="eventtime"
+ ):
import pyspark.sql.functions as f
input_path = tempfile.mkdtemp()
@@ -547,6 +550,7 @@ class TransformWithStateInPandasTestsMixin:
def prepare_batch3(input_path):
with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 2\n")
fw.write("a, 11\n")
fw.write("a, 13\n")
fw.write("a, 15\n")
@@ -577,7 +581,7 @@ class TransformWithStateInPandasTestsMixin:
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
- timeMode="eventtime",
+ timeMode=time_mode,
)
.writeStream.queryName(query_name)
.foreachBatch(check_results)
@@ -625,6 +629,32 @@ class TransformWithStateInPandasTestsMixin:
EventTimeStatefulProcessor(), check_results
)
+ def test_transform_with_state_with_wmark_and_non_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ # watermark for late event = 0 and min event = 20
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20"),
+ }
+ elif batch_id == 1:
+ # watermark for late event = 0 and min event = 4
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="4"),
+ }
+ else:
+ # watermark for late event = 10 and min event = 2 with no
filtering
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="2"),
+ }
+
+ self._test_transform_with_state_in_pandas_event_time(
+ MinEventTimeStatefulProcessor(), check_results, "None"
+ )
+
+ self._test_transform_with_state_in_pandas_event_time(
+ MinEventTimeStatefulProcessor(), check_results, "ProcessingTime"
+ )
+
def _test_transform_with_state_init_state_in_pandas(
self,
stateful_processor,
@@ -1611,6 +1641,33 @@ class EventTimeStatefulProcessor(StatefulProcessor):
pass
+# A stateful processor that output the min event time it has seen.
+class MinEventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.min_state = handle.getValueState("min_state", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) ->
Iterator[pd.DataFrame]:
+ timestamp_list = []
+ for pdf in rows:
+ # int64 will represent timestamp in nanosecond, restore to second
+ timestamp_list.extend((pdf["eventTime"].astype("int64") //
10**9).tolist())
+
+ if self.min_state.exists():
+ cur_min = int(self.min_state.get()[0])
+ else:
+ cur_min = sys.maxsize
+ min_event_time = str(min(cur_min, min(timestamp_list)))
+
+ self.min_state.update((min_event_time,))
+
+ yield pd.DataFrame({"id": key, "timestamp": min_event_time})
+
+ def close(self) -> None:
+ pass
+
+
# A stateful processor that output the accumulation of count of input rows;
register
# processing timer and clear the counter if timer expires.
class ProcTimeStatefulProcessor(StatefulProcessor):
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
index e77035e31ccb..909de5103c91 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
@@ -380,7 +380,7 @@ case class TransformWithStateInPandasExec(
// If timeout is based on event time, then filter late data based on
watermark
val filteredIter = watermarkPredicateForDataForLateEvents match {
- case Some(predicate) =>
+ case Some(predicate) if timeMode == TimeMode.EventTime() =>
applyRemovingRowsOlderThanWatermark(dataIterator, predicate)
case _ =>
dataIterator
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 6e0502e18659..3443ca21535c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -381,7 +381,7 @@ case class TransformWithStateExec(
// If timeout is based on event time, then filter late data based on
watermark
val filteredIter = watermarkPredicateForDataForLateEvents match {
- case Some(predicate) =>
+ case Some(predicate) if timeMode == TimeMode.EventTime() =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
iter
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index e98e47cae427..6f1da588eb53 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -611,6 +611,30 @@ class MaxEventTimeStatefulProcessor
}
}
+class MinEventTimeStatefulProcessor
+ extends StatefulProcessor[String, (String, Long), (String, Int)]
+ with Logging {
+ @transient var _minEventTimeState: ValueState[Long] = _
+
+ override def init(
+ outputMode: OutputMode,
+ timeMode: TimeMode): Unit = {
+ _minEventTimeState = getHandle.getValueState[Long]("minEventTimeState",
+ Encoders.scalaLong, TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, Long)],
+ timerValues: TimerValues): Iterator[(String, Int)] = {
+ val valuesSeq = inputRows.toSeq
+ val minEventTimeSec = math.min(valuesSeq.map(_._2).min,
+ Option(_minEventTimeState.get()).getOrElse(Long.MaxValue))
+ _minEventTimeState.update(minEventTimeSec)
+ Iterator((key, minEventTimeSec.toInt))
+ }
+}
+
class RunningCountMostRecentStatefulProcessor
extends StatefulProcessor[String, (String, String), (String, String, String)]
with Logging {
@@ -2358,4 +2382,50 @@ class TransformWithStateValidationSuite extends
StateStoreMetricsTest {
assert(ex1.getMessage.contains("Failed to find time values"))
TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(),
Some(10L))
}
+
+ Seq(TimeMode.None(), TimeMode.ProcessingTime()).foreach { timeMode =>
+ test(s"transformWithState - using watermark but time mode as $timeMode
should not perform " +
+ s"late record filtering") {
+ withSQLConf(
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString
+ ) {
+ val inputData = MemoryStream[(String, Int)]
+ val result =
+ inputData.toDS()
+ .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Long)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new MinEventTimeStatefulProcessor(),
+ timeMode,
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = new
StreamManualClock),
+
+ AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+ AdvanceManualClock(1 * 1000),
+ // Min event time = 15. Watermark = 15 - 10 = 5.
+ CheckNewAnswer(("a", 11)), // Output = min event time of a
+
+ AddData(inputData, ("a", 4)), // Add data older than watermark for
"a"
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", 4)), // Data should not get filtered and output
will be 4
+
+ AddData(inputData, ("a", 1)), // Add data older than watermark for
"a"
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", 1)), // Data should not get filtered and output
will be 1
+
+ AddData(inputData, ("a", 85)), // Add data newer than watermark for
"a"
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", 1)), // Min event time should still be 1
+ StopStream
+ )
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]