This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 406d0e243cf [SPARK-40925][SQL][SS] Fix stateful operator late record 
filtering
406d0e243cf is described below

commit 406d0e243cfec9b29df946e1a0e20ed5fe25e152
Author: Alex Balikov <91913242+alex-bali...@users.noreply.github.com>
AuthorDate: Mon Oct 31 10:13:50 2022 +0900

    [SPARK-40925][SQL][SS] Fix stateful operator late record filtering
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the input late record filtering done by stateful operators to 
allow for chaining of stateful operators. Currently stateful operators are 
initialized with the current microbatch watermark and perform both input late 
record filtering and state eviction (e.g. producing aggregations) using the 
same watermark value. The state evicted (or aggregates produced) due to 
watermark advancing is behind the watermark and thus effectively late: if a 
following stateful operator consumes that output, its late record filtering drops those records as late.
    
    This PR provides two watermark values to the stateful operators - one from 
the previous microbatch to be used for late record filtering and the one from 
the current microbatch (as in the existing code) to be used for state eviction. 
This solves the above problem of the broken late record filtering.
    
    Note that this PR still does not solve the issue of time-interval stream 
join producing records delayed against the watermark. Therefore time-interval 
streaming join followed by stateful operators is still not supported. That will 
be fixed in a follow-up PR (and an SPIP) that effectively replaces the single global 
watermark with conceptual per-operator watermarks.
    
    Also, the stateful operator chains unblocked by this PR (e.g. a chain of 
window aggregations) are still blocked by the unsupported operations checker. 
The new test suite for these scenarios, MultiStatefulOperatorsSuite, has to 
explicitly disable the unsupported operations check. This will also be fixed in a 
follow-up PR.
    
    ### Why are the changes needed?
    
    The PR allows Spark Structured Streaming to support chaining of stateful 
operators e.g. chaining of time window aggregations which is a meaningful 
streaming scenario.
    
    ### Does this PR introduce _any_ user-facing change?
    
    With this PR, chains of stateful operators will be supported in Spark 
Structured Streaming.
    
    ### How was this patch tested?
    
    Added a new test suite - MultiStatefulOperatorsSuite
    
    Closes #38405 from alex-balikov/multiple_stateful-ops-base.
    
    Lead-authored-by: Alex Balikov 
<91913242+alex-bali...@users.noreply.github.com>
    Co-authored-by: Alex Balikov <alex.bali...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/catalyst/expressions/SessionWindow.scala   |  24 +-
 .../sql/catalyst/expressions/TimeWindow.scala      |  20 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 +
 .../spark/sql/execution/QueryExecution.scala       |   2 +-
 .../spark/sql/execution/SparkStrategies.scala      |   6 +-
 .../spark/sql/execution/aggregate/AggUtils.scala   |   9 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |  13 +-
 .../streaming/FlatMapGroupsWithStateExec.scala     |  27 +-
 .../execution/streaming/IncrementalExecution.scala |  44 ++-
 .../execution/streaming/MicroBatchExecution.scala  |   1 +
 .../sql/execution/streaming/OffsetSeqLog.scala     |   4 +
 .../streaming/StreamingSymmetricHashJoinExec.scala |  15 +-
 .../StreamingSymmetricHashJoinHelper.scala         |   6 +-
 .../streaming/continuous/ContinuousExecution.scala |   1 +
 .../streaming/sources/MicroBatchWrite.scala        |   2 +-
 .../execution/streaming/statefulOperators.scala    | 108 +++--
 .../streaming/FlatMapGroupsWithStateSuite.scala    |   2 +-
 .../streaming/MultiStatefulOperatorsSuite.scala    | 440 +++++++++++++++++++++
 .../spark/sql/streaming/StreamingJoinSuite.scala   |   2 +-
 19 files changed, 661 insertions(+), 81 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
index 77e8dfde87b..02273b0c461 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
@@ -68,11 +68,29 @@ case class SessionWindow(timeColumn: Expression, 
gapDuration: Expression) extend
   with Unevaluable
   with NonSQLExpression {
 
+  private def inputTypeOnTimeColumn: AbstractDataType = {
+    TypeCollection(
+      AnyTimestampType,
+      // Below two types cover both time window & session window, since they 
produce the same type
+      // of output as window column.
+      new StructType()
+        .add(StructField("start", TimestampType))
+        .add(StructField("end", TimestampType)),
+      new StructType()
+        .add(StructField("start", TimestampNTZType))
+        .add(StructField("end", TimestampNTZType))
+    )
+  }
+
+  // NOTE: if the window column is given as a time column, we resolve it to 
the point of time,
+  // which resolves to either TimestampType or TimestampNTZType. That means, 
timeColumn may not
+  // be "resolved", so it is safer not to rely on the data type of timeColumn 
directly.
+
   override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, 
AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn, 
AnyDataType)
   override def dataType: DataType = new StructType()
-    .add(StructField("start", timeColumn.dataType))
-    .add(StructField("end", timeColumn.dataType))
+    .add(StructField("start", children.head.dataType))
+    .add(StructField("end", children.head.dataType))
 
   // This expression is replaced in the analyzer.
   override lazy val resolved = false
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 93c1074dfbe..bc9b7de7464 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -96,8 +96,26 @@ case class TimeWindow(
     this(timeColumn, windowDuration, windowDuration)
   }
 
+  private def inputTypeOnTimeColumn: AbstractDataType = {
+    TypeCollection(
+      AnyTimestampType,
+      // Below two types cover both time window & session window, since they 
produce the same type
+      // of output as window column.
+      new StructType()
+        .add(StructField("start", TimestampType))
+        .add(StructField("end", TimestampType)),
+      new StructType()
+        .add(StructField("start", TimestampNTZType))
+        .add(StructField("end", TimestampNTZType))
+    )
+  }
+
+  // NOTE: if the window column is given as a time column, we resolve it to 
the point of time,
+  // which resolves to either TimestampType or TimestampNTZType. That means, 
timeColumn may not
+  // be "resolved", so it is safer not to rely on the data type of timeColumn 
directly.
+
   override def child: Expression = timeColumn
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn)
   override def dataType: DataType = new StructType()
     .add(StructField("start", child.dataType))
     .add(StructField("end", child.dataType))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0a60c6b0265..3854f3190a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1941,6 +1941,22 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val STATEFUL_OPERATOR_ALLOW_MULTIPLE =
+    buildConf("spark.sql.streaming.statefulOperator.allowMultiple")
+      .internal()
+      .doc("When true, multiple stateful operators are allowed to be present 
in a streaming " +
+        "pipeline. The support for multiple stateful operators introduces a 
minor (semantically " +
+        "correct) change in respect to late record filtering - late records 
are detected and " +
+        "filtered in respect to the watermark from the previous microbatch 
instead of the " +
+        "current one. This is a behavior change for Spark streaming pipelines 
and we allow " +
+        "users to revert to the previous behavior of late record filtering 
(late records are " +
+        "detected and filtered by comparing with the current microbatch 
watermark) by setting " +
+        "the flag value to false. In this mode, only a single stateful 
operator will be allowed " +
+        "in a streaming pipeline.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION =
     buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution")
       .internal()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 10b763b1b51..8bf5d3d317b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -229,7 +229,7 @@ class QueryExecution(
       // output mode does not matter since there is no `Sink`.
       new IncrementalExecution(
         sparkSession, logical, OutputMode.Append(), "<unknown>",
-        UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
+        UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0))
     } else {
       this
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 03e722a86fb..b96e47846fc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -678,7 +678,8 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         val execPlan = FlatMapGroupsWithStateExec(
           func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, 
dataAttr, sda, outputAttr,
           None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs 
= None,
-          eventTimeWatermark = None, planLater(initialState), hasInitialState, 
planLater(child)
+          eventTimeWatermarkForLateEvents = None, 
eventTimeWatermarkForEviction = None,
+          planLater(initialState), hasInitialState, planLater(child)
         )
         execPlan :: Nil
       case _ =>
@@ -697,7 +698,8 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         val stateVersion = 
conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
         val execPlan = python.FlatMapGroupsInPandasWithStateExec(
           func, groupAttr, outputAttr, stateType, None, stateVersion, 
outputMode, timeout,
-          batchTimestampMs = None, eventTimeWatermark = None, planLater(child)
+          batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
+          eventTimeWatermarkForEviction = None, planLater(child)
         )
         execPlan :: Nil
       case _ =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 579a00c7996..557f0e897ee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -369,7 +369,8 @@ object AggUtils {
         groupingAttributes,
         stateInfo = None,
         outputMode = None,
-        eventTimeWatermark = None,
+        eventTimeWatermarkForLateEvents = None,
+        eventTimeWatermarkForEviction = None,
         stateFormatVersion = stateFormatVersion,
         partialMerged2)
 
@@ -472,7 +473,8 @@ object AggUtils {
 
     // shuffle & sort happens here: most of details are also handled in this 
physical plan
     val restored = 
SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes,
-      sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = 
None,
+      sessionExpression.toAttribute, stateInfo = None,
+      eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = 
None,
       stateFormatVersion, partialMerged1)
 
     val mergedSessions = {
@@ -501,7 +503,8 @@ object AggUtils {
       sessionExpression.toAttribute,
       stateInfo = None,
       outputMode = None,
-      eventTimeWatermark = None,
+      eventTimeWatermarkForLateEvents = None,
+      eventTimeWatermarkForEviction = None,
       stateFormatVersion, mergedSessions)
 
     val finalAndCompleteAggregate: SparkPlan = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 3b096f07241..bc1a5ae17e4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -48,7 +48,8 @@ import org.apache.spark.util.CompletionIterator
  * @param outputMode the output mode of `functionExpr`
  * @param timeoutConf used to timeout groups that have not received data in a 
while
  * @param batchTimestampMs processing timestamp of the current batch.
- * @param eventTimeWatermark event time watermark for the current batch
+ * @param eventTimeWatermarkForLateEvents event time watermark for filtering 
late events
+ * @param eventTimeWatermarkForEviction event time watermark for state eviction
  * @param child logical plan of the underlying data
  */
 case class FlatMapGroupsInPandasWithStateExec(
@@ -61,9 +62,10 @@ case class FlatMapGroupsInPandasWithStateExec(
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
-    eventTimeWatermark: Option[Long],
+    eventTimeWatermarkForLateEvents: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
     child: SparkPlan)
   extends UnaryExecNode with PythonSQLMetrics with 
FlatMapGroupsWithStateExecBase {
 
   // TODO(SPARK-40444): Add the support of initial state.
   override protected val initialStateDeserializer: Expression = null
@@ -132,7 +134,7 @@ case class FlatMapGroupsInPandasWithStateExec(
       if (isTimeoutEnabled) {
         val timeoutThreshold = timeoutConf match {
           case ProcessingTimeTimeout => batchTimestampMs.get
-          case EventTimeTimeout => eventTimeWatermark.get
+          case EventTimeTimeout => eventTimeWatermarkForEviction.get
           case _ =>
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
@@ -176,7 +178,7 @@ case class FlatMapGroupsInPandasWithStateExec(
         val groupedState = GroupStateImpl.createForStreaming(
           Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r 
},
           batchTimestampMs.getOrElse(NO_TIMESTAMP),
-          eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+          eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
           timeoutConf,
           hasTimedOut = hasTimedOut,
           watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 790a652f211..138029e76c1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -54,8 +54,8 @@ trait FlatMapGroupsWithStateExecBase
   protected val outputMode: OutputMode
   protected val timeoutConf: GroupStateTimeout
   protected val batchTimestampMs: Option[Long]
-  val eventTimeWatermark: Option[Long]
-
+  val eventTimeWatermarkForLateEvents: Option[Long]
+  val eventTimeWatermarkForEviction: Option[Long]
   protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout
   protected val watermarkPresent: Boolean = child.output.exists {
     case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => 
true
@@ -96,7 +96,8 @@ trait FlatMapGroupsWithStateExecBase
         true  // Always run batches to process timeouts
       case EventTimeTimeout =>
         // Process another non-data batch only if the watermark has changed in 
this executed plan
-        eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > 
eventTimeWatermark.get
+        eventTimeWatermarkForEviction.isDefined &&
+          newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
       case _ =>
         false
     }
@@ -125,7 +126,7 @@ trait FlatMapGroupsWithStateExecBase
     var timeoutProcessingStartTimeNs = currentTimeNs
 
     // If timeout is based on event time, then filter late data based on 
watermark
-    val filteredIter = watermarkPredicateForData match {
+    val filteredIter = watermarkPredicateForDataForLateEvents match {
       case Some(predicate) if timeoutConf == EventTimeTimeout =>
         applyRemovingRowsOlderThanWatermark(iter, predicate)
       case _ =>
@@ -189,8 +190,12 @@ trait FlatMapGroupsWithStateExecBase
       case ProcessingTimeTimeout =>
         require(batchTimestampMs.nonEmpty)
       case EventTimeTimeout =>
-        require(eventTimeWatermark.nonEmpty) // watermark value has been 
populated
-        require(watermarkExpression.nonEmpty) // input schema has watermark 
attribute
+        // watermark value has been populated
+        require(eventTimeWatermarkForLateEvents.nonEmpty)
+        require(eventTimeWatermarkForEviction.nonEmpty)
+        // input schema has watermark attribute
+        require(watermarkExpressionForLateEvents.nonEmpty)
+        require(watermarkExpressionForEviction.nonEmpty)
       case _ =>
     }
 
@@ -310,7 +315,7 @@ trait FlatMapGroupsWithStateExecBase
       if (isTimeoutEnabled) {
         val timeoutThreshold = timeoutConf match {
           case ProcessingTimeTimeout => batchTimestampMs.get
-          case EventTimeTimeout => eventTimeWatermark.get
+          case EventTimeTimeout => eventTimeWatermarkForEviction.get
           case _ =>
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
@@ -354,7 +359,8 @@ trait FlatMapGroupsWithStateExecBase
  * @param outputMode the output mode of `func`
  * @param timeoutConf used to timeout groups that have not received data in a 
while
  * @param batchTimestampMs processing timestamp of the current batch.
- * @param eventTimeWatermark event time watermark for the current batch
+ * @param eventTimeWatermarkForLateEvents event time watermark for filtering 
late events
+ * @param eventTimeWatermarkForEviction event time watermark for state eviction
  * @param initialState the user specified initial state
  * @param hasInitialState indicates whether the initial state is provided or 
not
  * @param child the physical plan for the underlying data
@@ -375,7 +381,8 @@ case class FlatMapGroupsWithStateExec(
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
-    eventTimeWatermark: Option[Long],
+    eventTimeWatermarkForLateEvents: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
     initialState: SparkPlan,
     hasInitialState: Boolean,
     child: SparkPlan)
@@ -410,7 +417,7 @@ case class FlatMapGroupsWithStateExec(
       val groupState = GroupStateImpl.createForStreaming(
         Option(stateData.stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
-        eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+        eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
         timeoutConf,
         hasTimedOut,
         watermarkPresent)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index f386282a0b3..574709d05b0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -48,6 +48,7 @@ class IncrementalExecution(
     val queryId: UUID,
     val runId: UUID,
     val currentBatchId: Long,
+    val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
     val offsetSeqMetadata: OffsetSeqMetadata)
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
@@ -112,6 +113,17 @@ class IncrementalExecution(
       numStateStores)
   }
 
+  // Watermarks to use for late record filtering and state eviction in 
stateful operators.
+  // Using the previous watermark for late record filtering is a Spark 
behavior change so we allow
+  // this to be disabled.
+  val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs
+  val eventTimeWatermarkForLateEvents =
+    if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) {
+      prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs
+    } else {
+      eventTimeWatermarkForEviction
+    }
+
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
 
@@ -158,7 +170,7 @@ class IncrementalExecution(
       case a: UpdatingSessionsExec if a.isStreaming =>
         a.copy(numShufflePartitions = Some(numStateStores))
 
-      case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
+      case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
              UnaryExecNode(agg,
                StateStoreRestoreExec(_, None, _, child))) =>
         val aggStateInfo = nextStatefulOperationStateInfo
@@ -166,7 +178,8 @@ class IncrementalExecution(
           keys,
           Some(aggStateInfo),
           Some(outputMode),
-          Some(offsetSeqMetadata.batchWatermarkMs),
+          eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+          eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
           stateFormatVersion,
           agg.withNewChildren(
             StateStoreRestoreExec(
@@ -175,32 +188,36 @@ class IncrementalExecution(
               stateFormatVersion,
               child) :: Nil))
 
-      case SessionWindowStateStoreSaveExec(keys, session, None, None, None, 
stateFormatVersion,
+      case SessionWindowStateStoreSaveExec(keys, session, None, None, None, 
None,
+        stateFormatVersion,
         UnaryExecNode(agg,
-        SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) =>
+        SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) 
=>
           val aggStateInfo = nextStatefulOperationStateInfo
           SessionWindowStateStoreSaveExec(
             keys,
             session,
             Some(aggStateInfo),
             Some(outputMode),
-            Some(offsetSeqMetadata.batchWatermarkMs),
+            eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+            eventTimeWatermarkForEviction = 
Some(eventTimeWatermarkForEviction),
             stateFormatVersion,
             agg.withNewChildren(
               SessionWindowStateStoreRestoreExec(
                 keys,
                 session,
                 Some(aggStateInfo),
-                Some(offsetSeqMetadata.batchWatermarkMs),
+                eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+                eventTimeWatermarkForEviction = 
Some(eventTimeWatermarkForEviction),
                 stateFormatVersion,
                 child) :: Nil))
 
-      case StreamingDeduplicateExec(keys, child, None, None) =>
+      case StreamingDeduplicateExec(keys, child, None, None, None) =>
         StreamingDeduplicateExec(
           keys,
           child,
           Some(nextStatefulOperationStateInfo),
-          Some(offsetSeqMetadata.batchWatermarkMs))
+          eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+          eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction))
 
       case m: FlatMapGroupsWithStateExec =>
         // We set this to true only for the first batch of the streaming query.
@@ -208,7 +225,8 @@ class IncrementalExecution(
         m.copy(
           stateInfo = Some(nextStatefulOperationStateInfo),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
-          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
+          eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+          eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
           hasInitialState = hasInitialState
         )
 
@@ -216,17 +234,19 @@ class IncrementalExecution(
         m.copy(
           stateInfo = Some(nextStatefulOperationStateInfo),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
-          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)
+          eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+          eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)
         )
 
       case j: StreamingSymmetricHashJoinExec =>
         j.copy(
           stateInfo = Some(nextStatefulOperationStateInfo),
-          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
+          eventTimeWatermarkForLateEvents = 
Some(eventTimeWatermarkForLateEvents),
+          eventTimeWatermarkForEviction = 
Some(eventTimeWatermarkForEviction),
           stateWatermarkPredicates =
             StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
               j.left.output, j.right.output, j.leftKeys, j.rightKeys, 
j.condition.full,
-              Some(offsetSeqMetadata.batchWatermarkMs)))
+              Some(eventTimeWatermarkForEviction)))
 
       case l: StreamingGlobalLimitExec =>
         l.copy(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 5f8fb93827b..7ed19b35114 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -695,6 +695,7 @@ class MicroBatchExecution(
         id,
         runId,
         currentBatchId,
+        offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1),
         offsetSeqMetadata)
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
index 82e50263893..7f00717ea4d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
@@ -102,6 +102,10 @@ class OffsetSeqLog(sparkSession: SparkSession, path: 
String)
       }
     }
   }
+
+  def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = {
+    if (batchId < 0) None else get(batchId).flatMap(_.metadata)
+  }
 }
 
 object OffsetSeqLog {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 4a8f3b18c09..dfde4156812 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -119,7 +119,8 @@ import org.apache.spark.util.{CompletionIterator, 
SerializableConfiguration}
  * @param condition Conditions to filter rows, split by left, right, and 
joined. See
  *                  [[JoinConditionSplitPredicates]]
  * @param stateInfo Version information required to read join state (buffered 
rows)
- * @param eventTimeWatermark Watermark of input event, same for both sides
+ * @param eventTimeWatermarkForLateEvents Watermark for filtering late events, 
same for both sides
+ * @param eventTimeWatermarkForEviction Watermark for state eviction
  * @param stateWatermarkPredicates Predicates for removal of state, see
  *                                 [[JoinStateWatermarkPredicates]]
  * @param left      Left child plan
@@ -131,7 +132,8 @@ case class StreamingSymmetricHashJoinExec(
     joinType: JoinType,
     condition: JoinConditionSplitPredicates,
     stateInfo: Option[StatefulOperatorStateInfo],
-    eventTimeWatermark: Option[Long],
+    eventTimeWatermarkForLateEvents: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
     stateWatermarkPredicates: JoinStateWatermarkPredicates,
     stateFormatVersion: Int,
     left: SparkPlan,
@@ -148,7 +150,8 @@ case class StreamingSymmetricHashJoinExec(
 
     this(
       leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, 
left, right),
-      stateInfo = None, eventTimeWatermark = None,
+      stateInfo = None,
+      eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = 
None,
       stateWatermarkPredicates = JoinStateWatermarkPredicates(), 
stateFormatVersion, left, right)
   }
 
@@ -222,7 +225,8 @@ case class StreamingSymmetricHashJoinExec(
 
     // Latest watermark value is more than that used in this previous executed 
plan
     val watermarkHasChanged =
-      eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > 
eventTimeWatermark.get
+      eventTimeWatermarkForEviction.isDefined &&
+        newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
 
     watermarkUsedForStateCleanup && watermarkHasChanged
   }
@@ -555,7 +559,8 @@ case class StreamingSymmetricHashJoinExec(
 
       val watermarkAttribute = 
inputAttributes.find(_.metadata.contains(delayKey))
       val nonLateRows =
-        WatermarkSupport.watermarkExpression(watermarkAttribute, 
eventTimeWatermark) match {
+        WatermarkSupport.watermarkExpression(
+          watermarkAttribute, eventTimeWatermarkForLateEvents) match {
           case Some(watermarkExpr) =>
             val predicate = Predicate.create(watermarkExpr, inputAttributes)
             applyRemovingRowsOlderThanWatermark(inputIter, predicate)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
index 2f62dbd7ec5..7bf6381e08f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -137,7 +137,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       condition: Option[Expression],
-      eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = {
+      eventTimeWatermarkForEviction: Option[Long]): 
JoinStateWatermarkPredicates = {
 
 
     // Join keys of both sides generate rows of the same fields, that is, same 
sequence of data
@@ -172,7 +172,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
           joinKeyOrdinalForWatermark.get,
           oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
           oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
-        val expr = watermarkExpression(Some(keyExprWithWatermark), 
eventTimeWatermark)
+        val expr = watermarkExpression(Some(keyExprWithWatermark), 
eventTimeWatermarkForEviction)
         expr.map(JoinStateKeyWatermarkPredicate.apply _)
 
       } else if (isWatermarkDefinedOnInput) { // case 2 in the 
StreamingSymmetricHashJoinExec docs
@@ -180,7 +180,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
           attributesToFindStateWatermarkFor = 
AttributeSet(oneSideInputAttributes),
           attributesWithEventWatermark = 
AttributeSet(otherSideInputAttributes),
           condition,
-          eventTimeWatermark)
+          eventTimeWatermarkForEviction)
         val inputAttributeWithWatermark = 
oneSideInputAttributes.find(_.metadata.contains(delayKey))
         val expr = watermarkExpression(inputAttributeWithWatermark, 
stateValueWatermark)
         expr.map(JoinStateValueWatermarkPredicate.apply _)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 5b620eec25f..e8092e072bc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -218,6 +218,7 @@ class ContinuousExecution(
         id,
         runId,
         currentBatchId,
+        None,
         offsetSeqMetadata)
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
index 0a603a3b141..3f474ea533c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
@@ -28,7 +28,7 @@ import 
org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactor
  */
 class MicroBatchWrite(epochId: Long, val writeSupport: StreamingWrite) extends 
BatchWrite {
   override def toString: String = {
-    s"MicroBathWrite[epoch: $epochId, writer: $writeSupport]"
+    s"MicroBatchWrite[epoch: $epochId, writer: $writeSupport]"
   }
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index b540f9f0093..457e5f80ae6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -212,34 +212,73 @@ trait WatermarkSupport extends SparkPlan {
   /** The keys that may have a watermark attribute. */
   def keyExpressions: Seq[Attribute]
 
-  /** The watermark value. */
-  def eventTimeWatermark: Option[Long]
+  /**
+   * The watermark value for filtering late events/records. This should be the 
previous
+   * batch state eviction watermark.
+   */
+  def eventTimeWatermarkForLateEvents: Option[Long]
+  /**
+   * The watermark value for closing aggregates and evicting state.
+   * It is different from the late events filtering watermark. Consider chained aggregators
+   * agg1 -> agg2: the state agg1 evicts is effectively late against the eviction watermark,
+   * but it should not be considered late by agg2's late record filtering. Thus agg1 and agg2
+   * use the current batch watermark for state eviction but the previous batch watermark for
+   * late record filtering.
+   */
+  def eventTimeWatermarkForEviction: Option[Long]
+
+  /** Generate an expression that matches data older than the late event filtering watermark */
+  lazy val watermarkExpressionForLateEvents: Option[Expression] =
+    watermarkExpression(eventTimeWatermarkForLateEvents)
+  /** Generate an expression that matches data older than the state eviction 
watermark */
+  lazy val watermarkExpressionForEviction: Option[Expression] =
+    watermarkExpression(eventTimeWatermarkForEviction)
 
   /** Generate an expression that matches data older than the watermark */
-  lazy val watermarkExpression: Option[Expression] = {
+  private def watermarkExpression(watermark: Option[Long]): Option[Expression] 
= {
     WatermarkSupport.watermarkExpression(
-      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)),
-      eventTimeWatermark)
+      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), 
watermark)
   }
 
-  /** Predicate based on keys that matches data older than the watermark */
-  lazy val watermarkPredicateForKeys: Option[BasePredicate] = 
watermarkExpression.flatMap { e =>
-    if 
(keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
-      Some(Predicate.create(e, keyExpressions))
-    } else {
-      None
+  /** Predicate based on keys that matches data older than the late event 
filtering watermark */
+  lazy val watermarkPredicateForKeysForLateEvents: Option[BasePredicate] =
+    watermarkPredicateForKeys(watermarkExpressionForLateEvents)
+
+  /** Predicate based on keys that matches data older than the state eviction watermark */
+  lazy val watermarkPredicateForKeysForEviction: Option[BasePredicate] =
+    watermarkPredicateForKeys(watermarkExpressionForEviction)
+
+  private def watermarkPredicateForKeys(
+      watermarkExpression: Option[Expression]): Option[BasePredicate] = {
+    watermarkExpression.flatMap { e =>
+      if 
(keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
+        Some(Predicate.create(e, keyExpressions))
+      } else {
+        None
+      }
     }
   }
 
-  /** Predicate based on the child output that matches data older than the 
watermark. */
-  lazy val watermarkPredicateForData: Option[BasePredicate] =
+  /**
+   * Predicate based on the child output that matches data older than the 
watermark for late events
+   * filtering.
+   */
+  lazy val watermarkPredicateForDataForLateEvents: Option[BasePredicate] =
+    watermarkPredicateForData(watermarkExpressionForLateEvents)
+
+  lazy val watermarkPredicateForDataForEviction: Option[BasePredicate] =
+    watermarkPredicateForData(watermarkExpressionForEviction)
+
+  private def watermarkPredicateForData(
+    watermarkExpression: Option[Expression]): Option[BasePredicate] = {
     watermarkExpression.map(Predicate.create(_, child.output))
+  }
 
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
-    if (watermarkPredicateForKeys.nonEmpty) {
+    if (watermarkPredicateForKeysForEviction.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
       store.iterator().foreach { rowPair =>
-        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+        if (watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) {
           store.remove(rowPair.key)
           numRemovedStateRows += 1
         }
@@ -250,10 +289,10 @@ trait WatermarkSupport extends SparkPlan {
   protected def removeKeysOlderThanWatermark(
       storeManager: StreamingAggregationStateManager,
       store: StateStore): Unit = {
-    if (watermarkPredicateForKeys.nonEmpty) {
+    if (watermarkPredicateForKeysForEviction.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
       storeManager.keys(store).foreach { keyRow =>
-        if (watermarkPredicateForKeys.get.eval(keyRow)) {
+        if (watermarkPredicateForKeysForEviction.get.eval(keyRow)) {
           storeManager.remove(store, keyRow)
           numRemovedStateRows += 1
         }
@@ -354,7 +393,8 @@ case class StateStoreSaveExec(
     keyExpressions: Seq[Attribute],
     stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
-    eventTimeWatermark: Option[Long] = None,
+    eventTimeWatermarkForLateEvents: Option[Long] = None,
+    eventTimeWatermarkForEviction: Option[Long] = None,
     stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
@@ -407,7 +447,7 @@ case class StateStoreSaveExec(
           case Some(Append) =>
             allUpdatesTimeMs += timeTakenMs {
               val filteredIter = applyRemovingRowsOlderThanWatermark(iter,
-                watermarkPredicateForData.get)
+                watermarkPredicateForDataForLateEvents.get)
               while (filteredIter.hasNext) {
                 val row = filteredIter.next().asInstanceOf[UnsafeRow]
                 stateManager.put(store, row)
@@ -423,7 +463,7 @@ case class StateStoreSaveExec(
                 var removedValueRow: InternalRow = null
                 while(rangeIter.hasNext && removedValueRow == null) {
                   val rowPair = rangeIter.next()
-                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+                  if 
(watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) {
                     stateManager.remove(store, rowPair.key)
                     numRemovedStateRows += 1
                     removedValueRow = rowPair.value
@@ -453,7 +493,7 @@ case class StateStoreSaveExec(
 
             new NextIterator[InternalRow] {
               // Filter late date using watermark if specified
-              private[this] val baseIterator = watermarkPredicateForData match 
{
+              private[this] val baseIterator = 
watermarkPredicateForDataForLateEvents match {
                 case Some(predicate) => 
applyRemovingRowsOlderThanWatermark(iter, predicate)
                 case None => iter
               }
@@ -507,8 +547,8 @@ case class StateStoreSaveExec(
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean 
= {
     (outputMode.contains(Append) || outputMode.contains(Update)) &&
-      eventTimeWatermark.isDefined &&
-      newMetadata.batchWatermarkMs > eventTimeWatermark.get
+      eventTimeWatermarkForEviction.isDefined &&
+      newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
StateStoreSaveExec =
@@ -525,7 +565,8 @@ case class SessionWindowStateStoreRestoreExec(
     keyWithoutSessionExpressions: Seq[Attribute],
     sessionExpression: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo],
-    eventTimeWatermark: Option[Long],
+    eventTimeWatermarkForLateEvents: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
     stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader with WatermarkSupport {
@@ -555,7 +596,7 @@ case class SessionWindowStateStoreRestoreExec(
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
 
       // We need to filter out outdated inputs
-      val filteredIterator = watermarkPredicateForData match {
+      val filteredIterator = watermarkPredicateForDataForLateEvents match {
         case Some(predicate) => iter.filter((row: InternalRow) => {
           val shouldKeep = !predicate.eval(row)
           if (!shouldKeep) longMetric("numRowsDroppedByWatermark") += 1
@@ -611,7 +652,8 @@ case class SessionWindowStateStoreSaveExec(
     sessionExpression: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
-    eventTimeWatermark: Option[Long] = None,
+    eventTimeWatermarkForLateEvents: Option[Long] = None,
+    eventTimeWatermarkForEviction: Option[Long] = None,
     stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
@@ -667,7 +709,7 @@ case class SessionWindowStateStoreSaveExec(
           val removalStartTimeNs = System.nanoTime
           new NextIterator[InternalRow] {
             private val removedIter = stateManager.removeByValueCondition(
-              store, watermarkPredicateForData.get.eval)
+              store, watermarkPredicateForDataForEviction.get.eval)
 
             override protected def getNext(): InternalRow = {
               if (!removedIter.hasNext) {
@@ -704,8 +746,8 @@ case class SessionWindowStateStoreSaveExec(
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean 
= {
     (outputMode.contains(Append) || outputMode.contains(Update)) &&
-      eventTimeWatermark.isDefined &&
-      newMetadata.batchWatermarkMs > eventTimeWatermark.get
+      eventTimeWatermarkForEviction.isDefined &&
+      newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
   }
 
   private def putToStore(iter: Iterator[InternalRow], store: StateStore): Unit 
= {
@@ -775,7 +817,8 @@ case class StreamingDeduplicateExec(
     keyExpressions: Seq[Attribute],
     child: SparkPlan,
     stateInfo: Option[StatefulOperatorStateInfo] = None,
-    eventTimeWatermark: Option[Long] = None)
+    eventTimeWatermarkForLateEvents: Option[Long] = None,
+    eventTimeWatermarkForEviction: Option[Long] = None)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
   /** Distribute by grouping attributes */
@@ -807,7 +850,7 @@ case class StreamingDeduplicateExec(
       val commitTimeMs = longMetric("commitTimeMs")
       val numDroppedDuplicateRows = longMetric("numDroppedDuplicateRows")
 
-      val baseIterator = watermarkPredicateForData match {
+      val baseIterator = watermarkPredicateForDataForLateEvents match {
         case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, 
predicate)
         case None => iter
       }
@@ -851,7 +894,8 @@ case class StreamingDeduplicateExec(
   override def shortName: String = "dedupe"
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean 
= {
-    eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > 
eventTimeWatermark.get
+    eventTimeWatermarkForEviction.isDefined &&
+      newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
StreamingDeduplicateExec =
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 14f083bbd30..49f4214ac1a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -1048,7 +1048,7 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
             hasInitialState, sga, sda, se, i, c) =>
           FlatMapGroupsWithStateExec(
             f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t,
-            Some(currentBatchTimestamp), Some(currentBatchWatermark),
+            Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark),
             RDDScanExec(g, emptyRdd, "rdd"),
             hasInitialState,
             RDDScanExec(g, emptyRdd, "rdd"))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala
new file mode 100644
index 00000000000..0a3ea40a677
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.functions._
+
+// Tests for support of multiple (chained) stateful operators in a single streaming query.
+class MultiStatefulOperatorsSuite
+  extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+  import testImplicits._
+
+  before {
+    SparkSession.setActiveSession(spark) // set this before initializing the lazy coordinator
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+  }
+
+  test("window agg -> window agg, append mode") {
+    // TODO: SPARK-40940 - Fix the unsupported ops checker to allow chaining 
of stateful ops.
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val inputData = MemoryStream[Int]
+
+      val stream = inputData.toDF()
+        .withColumn("eventTime", timestamp_seconds($"value"))
+        .withWatermark("eventTime", "0 seconds")
+        .groupBy(window($"eventTime", "5 seconds").as("window"))
+        .agg(count("*").as("count"))
+        .groupBy(window($"window", "10 seconds"))
+        .agg(count("*").as("count"), sum("count").as("sum"))
+        .select($"window".getField("start").cast("long").as[Long],
+          $"count".as[Long], $"sum".as[Long])
+
+      testStream(stream)(
+        AddData(inputData, 10 to 21: _*),
+        // op1 W (0, 0)
+        // agg: [10, 15) 5, [15, 20) 5, [20, 25) 2
+        // output: None
+        // state: [10, 15) 5, [15, 20) 5, [20, 25) 2
+        // op2 W (0, 0)
+        // agg: None
+        // output: None
+        // state: None
+
+        // no-data batch triggered
+
+        // op1 W (0, 21)
+        // agg: None
+        // output: [10, 15) 5, [15, 20) 5
+        // state: [20, 25) 2
+        // op2 W (0, 21)
+        // agg: [10, 20) (2, 10)
+        // output: [10, 20) (2, 10)
+        // state: None
+        CheckNewAnswer((10, 2, 10)),
+        assertNumStateRows(Seq(0, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0)),
+
+        AddData(inputData, 10 to 29: _*),
+        // op1 W (21, 21)
+        // agg: [10, 15) 5 - late, [15, 20) 5 - late, [20, 25) 5, [25, 30) 5
+        // output: None
+        // state: [20, 25) 7, [25, 30) 5
+        // op2 W (21, 21)
+        // agg: None
+        // output: None
+        // state: None
+
+        // no-data batch triggered
+
+        // op1 W (21, 29)
+        // agg: None
+        // output: [20, 25) 7
+        // state: [25, 30) 5
+        // op2 W (21, 29)
+        // agg: [20, 30) (1, 7)
+        // output: None
+        // state: [20, 30) (1, 7)
+        CheckNewAnswer(),
+        assertNumStateRows(Seq(1, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 2)),
+
+        // Move the watermark.
+        AddData(inputData, 30, 31),
+        // op1 W (29, 29)
+        // agg: [30, 35) 2
+        // output: None
+        // state: [25, 30) 5 [30, 35) 2
+        // op2 W (29, 29)
+        // agg: None
+        // output: None
+        // state: [20, 30) (1, 7)
+
+        // no-data batch triggered
+
+        // op1 W (29, 31)
+        // agg: None
+        // output: [25, 30) 5
+        // state: [30, 35) 2
+        // op2 W (29, 31)
+        // agg: [20, 30) (2, 12)
+        // output: [20, 30) (2, 12)
+        // state: None
+        CheckNewAnswer((20, 2, 12)),
+        assertNumStateRows(Seq(0, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0))
+      )
+    }
+  }
+
+  test("agg -> agg -> agg, append mode") {
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val inputData = MemoryStream[Int]
+
+      val stream = inputData.toDF()
+        .withColumn("eventTime", timestamp_seconds($"value"))
+        .withWatermark("eventTime", "0 seconds")
+        .groupBy(window($"eventTime", "5 seconds").as("window"))
+        .agg(count("*").as("count"))
+        .groupBy(window(window_time($"window"), "10 seconds"))
+        .agg(count("*").as("count"), sum("count").as("sum"))
+        .groupBy(window(window_time($"window"), "20 seconds"))
+        .agg(count("*").as("count"), sum("sum").as("sum"))
+        .select(
+          $"window".getField("start").cast("long").as[Long],
+          $"window".getField("end").cast("long").as[Long],
+          $"count".as[Long], $"sum".as[Long])
+
+      testStream(stream)(
+        AddData(inputData, 0 to 37: _*),
+        // op1 W (0, 0)
+        // agg: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 
30) 5, [30, 35) 5,
+        //   [35, 40) 3
+        // output: None
+        // state: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, 
[25, 30) 5, [30, 35) 5,
+        //   [35, 40) 3
+        // op2 W (0, 0)
+        // agg: None
+        // output: None
+        // state: None
+        // op3 W (0, 0)
+        // agg: None
+        // output: None
+        // state: None
+
+        // no-data batch triggered
+
+        // op1 W (0, 37)
+        // agg: None
+        // output: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, 
[25, 30) 5, [30, 35) 5
+        // state: [35, 40) 3
+        // op2 W (0, 37)
+        // agg: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10), [30, 40) 
(1, 5)
+        // output: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10)
+        // state: [30, 40) (1, 5)
+        // op3 W (0, 37)
+        // agg: [0, 20) (2, 20), [20, 40) (1, 10)
+        // output: [0, 20) (2, 20)
+        // state: [20, 40) (1, 10)
+        CheckNewAnswer((0, 20, 2, 20)),
+        assertNumStateRows(Seq(1, 1, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0, 0)),
+
+        AddData(inputData, 30 to 60: _*),
+        // op1 W (37, 37)
+        // dropped rows: [30, 35) 1 row - inputs 35, 36, 37 fall into [35, 40), which is not late
+        // agg: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, 
[60, 65) 1
+        // output: None
+        // state: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, 
[60, 65) 1
+        // op2 W (37, 37)
+        // output: None
+        // state: [30, 40) (1, 5)
+        // op3 W (37, 37)
+        // output: None
+        // state: [20, 40) (1, 10)
+
+        // no-data batch
+        // op1 W (37, 60)
+        // output: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5
+        // state: [60, 65) 1
+        // op2 W (37, 60)
+        // agg: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60) (2, 10)
+        // output: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60) (2, 10)
+        // state: None
+        // op3 W (37, 60)
+        // agg: [20, 40) (2, 23), [40, 60) (2, 20)
+        // output: [20, 40) (2, 23), [40, 60) (2, 20)
+        // state: None
+
+        CheckNewAnswer((20, 40, 2, 23), (40, 60, 2, 20)),
+        assertNumStateRows(Seq(0, 0, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0, 1))
+      )
+    }
+  }
+
+  test("stream deduplication -> aggregation, append mode") {
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val inputData = MemoryStream[Int]
+
+      val deduplication = inputData.toDF()
+        .withColumn("eventTime", timestamp_seconds($"value"))
+        .withWatermark("eventTime", "10 seconds")
+        .dropDuplicates("value", "eventTime")
+
+      val windowedAggregation = deduplication
+        .groupBy(window($"eventTime", "5 seconds").as("window"))
+        .agg(count("*").as("count"), sum("value").as("sum"))
+        .select($"window".getField("start").cast("long").as[Long],
+          $"count".as[Long])
+
+      testStream(windowedAggregation)(
+        AddData(inputData, 1 to 15: _*),
+        // op1 W (0, 0)
+        // input: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+        // deduplicated: None
+        // output: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+        // state: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+        // op2 W (0, 0)
+        // agg: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1
+        // output: None
+        // state: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1
+
+        // no-data batch triggered
+
+        // op1 W (0, 5)
+        // agg: None
+        // output: None
+        // state: 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+        // op2 W (0, 5)
+        // agg: None
+        // output: [0, 5) 4
+        // state: [5, 10) 5 [10, 15) 5, [15, 20) 1
+        CheckNewAnswer((0, 4)),
+        assertNumStateRows(Seq(3, 10)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0))
+      )
+    }
+  }
+
+  test("join -> window agg, append mode") {
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val input1 = MemoryStream[Int]
+      val inputDF1 = input1.toDF
+        .withColumnRenamed("value", "value1")
+        .withColumn("eventTime1", timestamp_seconds($"value1"))
+        .withWatermark("eventTime1", "0 seconds")
+
+      val input2 = MemoryStream[Int]
+      val inputDF2 = input2.toDF
+        .withColumnRenamed("value", "value2")
+        .withColumn("eventTime2", timestamp_seconds($"value2"))
+        .withWatermark("eventTime2", "0 seconds")
+
+      val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), 
"inner")
+        .groupBy(window($"eventTime1", "5 seconds").as("window"))
+        .agg(count("*").as("count"))
+        .select($"window".getField("start").cast("long").as[Long], 
$"count".as[Long])
+
+      testStream(stream)(
+        MultiAddData(input1, 1 to 4: _*)(input2, 1 to 4: _*),
+
+        // op1 W (0, 0)
+        // join output: (1, 1), (2, 2), (3, 3), (4, 4)
+        // state: (1, 1), (2, 2), (3, 3), (4, 4)
+        // op2 W (0, 0)
+        // agg: [0, 5) 4
+        // output: None
+        // state: [0, 5) 4
+
+        // no-data batch triggered
+
+        // op1 W (0, 4)
+        // join output: None
+        // state: None
+        // op2 W (0, 4)
+        // agg: None
+        // output: None
+        // state: [0, 5) 4
+        CheckNewAnswer(),
+        assertNumStateRows(Seq(1, 0)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0)),
+
+        // Move the watermark
+        MultiAddData(input1, 5)(input2, 5),
+
+        // op1 W (4, 4)
+        // join output: (5, 5)
+        // state: (5, 5)
+        // op2 W (4, 4)
+        // agg: [5, 10) 1
+        // output: None
+        // state: [0, 5) 4, [5, 10) 1
+
+        // no-data batch triggered
+
+        // op1 W (4, 5)
+        // join output: None
+        // state: None
+        // op2 W (4, 5)
+        // agg: None
+        // output: [0, 5) 4
+        // state: [5, 10) 1
+        CheckNewAnswer((0, 4)),
+        assertNumStateRows(Seq(1, 0)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0))
+      )
+    }
+  }
+
+  test("aggregation -> stream deduplication, append mode") {
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val inputData = MemoryStream[Int]
+
+      val aggStream = inputData.toDF()
+        .withColumn("eventTime", timestamp_seconds($"value"))
+        .withWatermark("eventTime", "0 seconds")
+        .groupBy(window($"eventTime", "5 seconds").as("window"))
+        .agg(count("*").as("count"))
+        .withColumn("windowEnd", expr("window.end"))
+
+      // dropDuplicates from aggStream without event time column for 
dropDuplicates - the
+      // state does not get trimmed due to watermark advancement.
+      val dedupNoEventTime = aggStream
+        .dropDuplicates("count", "windowEnd")
+        .select(
+          $"windowEnd".cast("long").as[Long],
+          $"count".as[Long])
+
+      testStream(dedupNoEventTime)(
+        AddData(inputData, 1, 5, 10, 15),
+
+        // op1 W (0, 0)
+        // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1
+        // output: None
+        // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1
+        // op2 W (0, 0)
+        // output: None
+        // state: None
+
+        // no-data batch triggered
+
+        // op1 W (0, 15)
+        // agg: None
+        // output: [0, 5) 1, [5, 10) 1, [10, 15) 1
+        // state: [15, 20) 1
+        // op2 W (0, 15)
+        // output: (5, 1), (10, 1), (15, 1)
+        // state: (5, 1), (10, 1), (15, 1)
+
+        CheckNewAnswer((5, 1), (10, 1), (15, 1)),
+        assertNumStateRows(Seq(3, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0))
+      )
+
+      // Similar to the above but add event time. The dedup state will get 
trimmed.
+      val dedupWithEventTime = aggStream
+        .withColumn("windowTime", expr("window_time(window)"))
+        .withColumn("windowTimeMicros", expr("unix_micros(windowTime)"))
+        .dropDuplicates("count", "windowEnd", "windowTime")
+        .select(
+          $"windowEnd".cast("long").as[Long],
+          $"windowTimeMicros".cast("long").as[Long],
+          $"count".as[Long])
+
+      testStream(dedupWithEventTime)(
+        AddData(inputData, 1, 5, 10, 15),
+
+        // op1 W (0, 0)
+        // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1
+        // output: None
+        // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1
+        // op2 W (0, 0)
+        // output: None
+        // state: None
+
+        // no-data batch triggered
+
+        // op1 W (0, 15)
+        // agg: None
+        // output: [0, 5) 1, [5, 10) 1, [10, 15) 1
+        // state: [15, 20) 1
+        // op2 W (0, 15)
+        // output: (5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)
+        // state: None - trimmed by watermark
+
+        CheckNewAnswer((5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)),
+        assertNumStateRows(Seq(0, 1)),
+        assertNumRowsDroppedByWatermark(Seq(0, 0))
+      )
+    }
+  }
+
+  private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = 
AssertOnQuery { q =>
+    q.processAllAvailable()
+    val progressWithData = q.recentProgress.lastOption.get
+    val stateOperators = progressWithData.stateOperators
+    assert(stateOperators.size === numTotalRows.size)
+    assert(stateOperators.map(_.numRowsTotal).toSeq === numTotalRows)
+    true
+  }
+
+  private def assertNumRowsDroppedByWatermark(
+      numRowsDroppedByWatermark: Seq[Long]): AssertOnQuery = AssertOnQuery { q 
=>
+    q.processAllAvailable()
+    val progressWithData = q.recentProgress.filterNot { p =>
+      // filter out batches of either of these types (both report zero input rows):
+      // 1) batches that did not execute a run
+      // 2) batches with empty input, e.g. no-data batches
+      p.numInputRows == 0
+    }.lastOption.get
+    val stateOperators = progressWithData.stateOperators
+    assert(stateOperators.size === numRowsDroppedByWatermark.size)
+    assert(stateOperators.map(_.numRowsDroppedByWatermark).toSeq === 
numRowsDroppedByWatermark)
+    true
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 8ef8c21e13a..40868f896f5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -618,7 +618,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
         val numPartitions = 
spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
 
         assert(query.lastExecution.executedPlan.collect {
-          case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _,
+          case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
             ShuffleExchangeExec(opA: HashPartitioning, _, _),
             ShuffleExchangeExec(opB: HashPartitioning, _, _))
               if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to